# 539 lines · 20 KiB · Python
import base64
import os
import warnings
from io import BytesIO
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


_LEGEND_POSITIONS = ('top', 'bottom', 'left', 'right')


def _per_series(value, n_series: int, name: str) -> list:
    """Return one entry of *value* per series.

    A string is repeated ``n_series`` times. Any other iterable is used
    as-is, but must provide at least ``n_series`` entries; a shorter list
    would otherwise silently drop series (or fail later with IndexError).
    """
    if isinstance(value, str):
        return [value] * n_series
    values = list(value)
    if len(values) < n_series:
        raise ValueError(
            f"'{name}' has {len(values)} entries but {n_series} y_columns were given"
        )
    return values


def _font_available(font_family) -> bool:
    """Return True if matplotlib can resolve *font_family* without falling back."""
    from matplotlib import font_manager  # only needed for this check

    try:
        font_manager.findfont(
            font_manager.FontProperties(family=font_family),
            fallback_to_default=False,
        )
    except ValueError:
        return False
    return True


def _fit_curve(x_numeric, y_values, degree: int, x_min=None, num_points: int = 100):
    """Fit a polynomial to the finite (x, y) pairs and sample it.

    Returns ``(x_plot, y_plot)``, or None when there are no finite points.
    When *x_min* is given, X was converted to seconds since *x_min*; the sampled
    X values are mapped back to timestamps so they share the data's X axis.
    """
    x_arr = np.asarray(x_numeric, dtype=float)
    y_arr = np.asarray(y_values, dtype=float)
    # polyfit fails on NaN/inf, so fit only the usable points.
    finite = np.isfinite(x_arr) & np.isfinite(y_arr)
    x_arr, y_arr = x_arr[finite], y_arr[finite]
    if x_arr.size == 0:
        return None
    coeffs = np.polyfit(x_arr, y_arr, degree)
    x_fit = np.linspace(x_arr.min(), x_arr.max(), num_points)
    y_fit = np.polyval(coeffs, x_fit)
    if x_min is not None:
        return x_min + pd.to_timedelta(x_fit, unit='s'), y_fit
    return x_fit, y_fit


def _plot_series(ax, x_values, x_numeric, x_min, y_values, label, color, marker,
                 style: str, config: dict) -> None:
    """Draw one series on *ax* as 'scatter', 'scatter_fit' or (otherwise) a line."""
    if style in ('scatter', 'scatter_fit'):
        ax.scatter(x_values, y_values,
                   color=color,
                   marker=marker,
                   s=config['marker_size'] ** 2,
                   alpha=config['scatter_alpha'],
                   label=label)
        if style == 'scatter_fit':
            fitted = _fit_curve(x_numeric, y_values, config['fit_degree'], x_min)
            if fitted is not None:
                ax.plot(*fitted,
                        color=color,
                        linestyle=config['line_style'],
                        linewidth=1.5,
                        label=f'{label} (拟合)')
    else:  # 'line' and any unrecognised style
        ax.plot(x_values, y_values,
                color=color,
                marker=marker,
                markersize=config['marker_size'],
                linestyle=config['line_style'],
                linewidth=1.5,
                label=label)


def _legend_layout(position: str, offset: float, n_axes: int) -> dict:
    """Return legend anchoring plus fallback subplot margins for *position*."""
    layouts = {
        'top': {'loc': 'lower center', 'bbox_to_anchor': (0.5 + offset, 1.0),
                'borderaxespad': 0.1, 'ncol': 4,
                'top': 0.85, 'bottom': 0.15, 'left': 0.0, 'right': 1.0},
        'bottom': {'loc': 'upper center', 'bbox_to_anchor': (0.2 + offset, -0.02),
                   'borderaxespad': 0.5, 'ncol': 4,
                   'top': 0.95, 'bottom': 0.2, 'left': 0.1, 'right': 0.9},
        'left': {'loc': 'center left', 'bbox_to_anchor': (-0.12 + offset, 0.5),
                 'borderaxespad': 0.5, 'ncol': 1,
                 'top': 0.95, 'bottom': 0.1, 'left': 0.25,
                 'right': 0.9 - 0.07 * n_axes},
        'right': {'loc': 'center right', 'bbox_to_anchor': (1.15 + offset, 0.5),
                  'borderaxespad': 0.5, 'ncol': 1,
                  'top': 0.95, 'bottom': 0.1, 'left': 0.1,
                  'right': 0.85 - 0.07 * n_axes},
    }
    layout = layouts[position]
    # Cap at 0.9 and keep right > left: with many Y axes the linear formula
    # drops below 'left' and subplots_adjust would raise ValueError.
    layout['right'] = max(min(layout['right'], 0.9), layout['left'] + 0.1)
    return layout


def visualize_data(
    df: pd.DataFrame,
    x_column: str,
    y_columns: List[str],
    **kwargs
) -> Tuple[str, Optional[str]]:
    """
    Plot DataFrame columns against a shared X axis, each on its own Y axis,
    and return the PNG as base64 together with the optional saved-file path.

    Args:
        df: input DataFrame.
        x_column: column used for the X axis (datetime columns are supported).
        y_columns: columns to plot; each gets its own Y axis unless
            ``group_by_units`` merges columns that share a unit.
        **kwargs: plot options (default in parentheses):
            - title: str ('多Y轴数据可视化')
            - title_fontsize: int (14)
            - figsize: Tuple[float, float] ((7.56, 4.54))
            - dpi: int (300)
            - colors: List[str] (20 colours, cycled when there are more series)
            - markers: List[str] (20 markers, cycled when there are more series)
            - x_label: Optional[str] (None -> x_column)
            - y_labels: Optional[List[str]] (None) - one label per entry of
              y_columns, in the caller's order; missing entries fall back to
              the column name. A shared axis shows its first column's label.
            - grid: bool (True)
            - x_scale: str ('linear') - 'linear' or 'log'
            - y_scale: str | List[str] ('linear') - scale per series; a shared
              axis uses the scale of its first column
            - font_family: str ('Microsoft YaHei') - ignored with a warning if
              matplotlib cannot find it
            - marker_size: int (3)
            - legend_position: 'top' | 'bottom' | 'left' | 'right' ('right')
            - legend_ncol: Optional[int] (None -> 4 for top/bottom, 1 for left/right)
            - legend_fontsize: int (10)
            - legend_offset: float (0.05) - horizontal shift of the legend anchor
            - tight_layout_pad: float (1.5)
            - save_path: Optional[str] (None) - also write the PNG here
            - plot_style: str | List[str] ('line') - 'line', 'scatter' or 'scatter_fit'
            - fit_degree: int (2) - polynomial degree for 'scatter_fit'
            - scatter_alpha: float (0.7)
            - line_style: str ('-')
            - title_pad: float (1.0)
            - group_by_units: bool (False) - put columns with the same unit on one axis
            - unit_mapping: Optional[Dict[str, str]] (None) - column -> unit;
              unmapped columns are their own unit

    Returns:
        (base64-encoded PNG, absolute path of the saved PNG or None)

    Raises:
        ValueError: if y_columns is empty, legend_position is unknown, or a
            per-series list option is shorter than y_columns.
        KeyError: if a referenced column is missing from df.
    """
    # Default plot configuration
    config = {
        'title': '多Y轴数据可视化',
        'title_fontsize': 14,
        'figsize': (7.56, 4.54),
        'dpi': 300,
        'colors': [
            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
            '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
            '#c49c94', '#f7b6d2', '#dbdb8d', '#9edae5', '#393b79'
        ],
        'markers': [
            'o', 's', '^', 'D', 'v', 'p', '*', 'h', 'X', 'd',
            'P', '>', '<', '1', '2', '3', '4', '+', 'x', '|'
        ],
        'x_label': None,
        'y_labels': None,
        'grid': True,
        'x_scale': 'linear',
        'y_scale': 'linear',
        'font_family': 'Microsoft YaHei',
        'marker_size': 3,
        'legend_position': 'right',
        'legend_ncol': None,
        'legend_fontsize': 10,
        'legend_offset': 0.05,
        'tight_layout_pad': 1.5,
        'save_path': None,
        'plot_style': 'line',
        'fit_degree': 2,
        'scatter_alpha': 0.7,
        'line_style': '-',
        'title_pad': 1.0,
        'group_by_units': False,
        'unit_mapping': None
    }
    config.update(kwargs)

    # Validate before creating any figure.
    if not y_columns:
        raise ValueError("y_columns must contain at least one column")
    legend_position = config['legend_position']
    if legend_position not in _LEGEND_POSITIONS:
        raise ValueError(
            f"legend_position must be one of {_LEGEND_POSITIONS}, got {legend_position!r}"
        )

    n_series = len(y_columns)
    plot_styles = _per_series(config['plot_style'], n_series, 'plot_style')
    y_scales = _per_series(config['y_scale'], n_series, 'y_scale')
    palette = list(config['colors'])
    marker_cycle = list(config['markers'])
    y_labels = config['y_labels']

    # Drawing order as indices into y_columns. All per-series options are
    # looked up by the original index, so colours, markers, styles and labels
    # stay attached to their column when grouping reorders the series.
    order = list(range(n_series))
    units = None
    if config['group_by_units']:
        mapping = config['unit_mapping']
        if mapping is None:
            # No mapping supplied: every column is its own unit.
            mapping = {col: f"Unit_{i}" for i, col in enumerate(y_columns)}
        units = [mapping.get(col, col) for col in y_columns]
        groups: Dict[str, List[int]] = {}
        for idx, unit in enumerate(units):
            groups.setdefault(unit, []).append(idx)
        # dicts preserve insertion order: groups appear in order of first use.
        order = [idx for members in groups.values() for idx in members]

    x_values = df[x_column]
    if pd.api.types.is_datetime64_any_dtype(x_values):
        # Fit on seconds since the first timestamp; _fit_curve maps back.
        x_min = x_values.min()
        x_numeric = (x_values - x_min).dt.total_seconds()
    else:
        x_min = None
        x_numeric = x_values

    # Font settings are applied through rc_context so the caller's global
    # rcParams are left untouched.
    rc = {'axes.unicode_minus': False}
    font_family = config['font_family']
    if font_family and _font_available(font_family):
        rc['font.family'] = font_family
    else:
        warnings.warn(f"警告: 字体 '{font_family}' 不可用,使用默认字体", stacklevel=2)

    with plt.rc_context(rc):
        fig, ax1 = plt.subplots(figsize=config['figsize'], dpi=config['dpi'])
        try:
            # All twin axes share ax1's X axis, so X settings go on ax1 only.
            ax1.set_xscale(config['x_scale'])
            ax1.set_xlabel(x_column if config['x_label'] is None else config['x_label'])
            if config['grid']:
                ax1.grid(True, linestyle=':', alpha=0.5)

            axes = [ax1]
            axis_for_unit = {}
            spine_position = 0.9  # the first extra Y axis lands at 1.0 (right edge)
            for pos, idx in enumerate(order):
                column = y_columns[idx]
                color = palette[idx % len(palette)]
                marker = marker_cycle[idx % len(marker_cycle)]

                ax = axis_for_unit.get(units[idx]) if units is not None else None
                new_axis = ax is None
                if new_axis:
                    if pos == 0:
                        ax = ax1
                    else:
                        ax = ax1.twinx()
                        spine_position += 0.1
                        ax.spines['right'].set_position(('axes', spine_position))
                        axes.append(ax)
                    if units is not None:
                        axis_for_unit[units[idx]] = ax

                _plot_series(ax, x_values, x_numeric, x_min, df[column], column,
                             color, marker, plot_styles[idx], config)

                if new_axis:
                    has_label = y_labels is not None and idx < len(y_labels)
                    ax.set_ylabel(y_labels[idx] if has_label else column, color=color)
                    ax.tick_params(axis='y', labelcolor=color)
                    ax.set_yscale(y_scales[idx])
                    if ax is not ax1:
                        ax.yaxis.set_label_position('right')
                        ax.yaxis.set_ticks_position('right')
                        # Place the label just outside this axis' offset spine.
                        ax.yaxis.set_label_coords(spine_position + 0.07, 0.5)

            # Title on the base axes (plt.title would target the last twin axis).
            ax1.set_title(config['title'], fontsize=config['title_fontsize'],
                          pad=config['title_pad'], y=1.05)

            # One legend for every axis; each axis is listed once, so shared
            # axes do not produce duplicate entries.
            handles, labels = [], []
            for axis in axes:
                axis_handles, axis_labels = axis.get_legend_handles_labels()
                handles.extend(axis_handles)
                labels.extend(axis_labels)

            layout = _legend_layout(legend_position, config['legend_offset'], len(axes))
            ncol = layout['ncol'] if config['legend_ncol'] is None else config['legend_ncol']
            fig.legend(
                handles, labels,
                loc=layout['loc'],
                bbox_to_anchor=layout['bbox_to_anchor'],
                ncol=ncol,
                borderaxespad=layout['borderaxespad'],
                frameon=False,
                fontsize=config['legend_fontsize']
            )

            fig.subplots_adjust(top=layout['top'], bottom=layout['bottom'],
                                left=layout['left'], right=layout['right'])
            # tight_layout recomputes the subplot margins; the values above only
            # remain in effect if matplotlib cannot apply the tight layout.
            fig.tight_layout(pad=config['tight_layout_pad'])

            buffer = BytesIO()
            fig.savefig(buffer, format='png', bbox_inches='tight', dpi=config['dpi'])
        finally:
            # Always release the figure, even if drawing failed.
            plt.close(fig)

    image_bytes = buffer.getvalue()
    base64_data = base64.b64encode(image_bytes).decode('utf-8')

    # Optionally write the same PNG bytes to disk.
    png_path = None
    save_path = config['save_path']
    if save_path:
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(image_bytes)
        png_path = os.path.abspath(save_path)

    return base64_data, png_path


# Usage examples
if __name__ == "__main__":
    # Sample data: ten monthly observations of several metrics.
    np.random.seed(42)
    data = {
        # 'MS' (month start) is accepted by every pandas version; the former
        # 'M' alias is deprecated since pandas 2.2.
        '时间': pd.date_range('2023-01-01', periods=10, freq='MS'),
        '销售额': np.random.randint(1000, 2000, 10) + np.arange(10) * 10,
        '用户数': np.random.randint(5000, 8000, 10) + np.arange(10) * 50,
        '转化率': np.random.uniform(0.001, 0.003, 10) + np.arange(10) * 0.02,
        'ZDU1': np.random.uniform(0.01, 0.03, 10) + np.arange(10) * 0.02,
        'ZDU2': np.random.uniform(0.0001, 0.0003, 10) + np.arange(10) * 0.02
    }
    df = pd.DataFrame(data)

    # Column -> unit mapping used by the grouped example
    unit_mapping = {
        '销售额': '元',
        '用户数': '个',
        '转化率': '百分比',
        'ZDU1': '百分比',
        'ZDU2': '百分比'
    }

    # Example 1: columns sharing a unit share one Y axis
    img_data1, img_path1 = visualize_data(
        df=df,
        x_column='时间',
        y_columns=['销售额', '用户数', '转化率', 'ZDU1', 'ZDU2'],
        title='业务指标分析(按单位分组)',
        plot_style=['line', 'scatter_fit', 'scatter_fit', 'line', 'line'],
        y_labels=['销售额(元)', '用户数(个)', '转化率(%)', 'ZDU1(%)', 'ZDU2(%)'],
        figsize=(12, 6),
        fit_degree=2,
        legend_position='top',
        legend_ncol=3,
        save_path='output/grouped_units.png',
        title_pad=15,
        group_by_units=True,
        unit_mapping=unit_mapping
    )

    # Example 2: one Y axis per column (no grouping)
    img_data2, img_path2 = visualize_data(
        df=df,
        x_column='时间',
        y_columns=['销售额', '用户数', '转化率', 'ZDU1', 'ZDU2'],
        title='业务指标分析(不分组)',
        plot_style=['line', 'scatter_fit', 'scatter_fit', 'line', 'line'],
        y_labels=['销售额(元)', '用户数(个)', '转化率(%)', 'ZDU1(%)', 'ZDU2(%)'],
        figsize=(12, 6),
        fit_degree=2,
        legend_position='top',
        legend_ncol=3,
        save_path='output/non_grouped.png',
        title_pad=15,
        group_by_units=False
    )

    # Example 3: legend above the plot
    img_data3, img_path3 = visualize_data(
        df=df,
        x_column='时间',
        y_columns=['销售额', '用户数', '转化率', 'ZDU1', 'ZDU2'],
        title='业务指标分析(图例在上方)',
        plot_style=['line', 'scatter_fit', 'scatter_fit', 'line', 'line'],
        y_labels=['销售额(万)', '用户数', '转化率(%)', 'ZDU1', 'ZDU2'],
        figsize=(12, 6),
        fit_degree=2,
        legend_position='top',
        legend_ncol=3,
        save_path='output/legend_top.png',
        title_pad=15
    )

    # Example 4: legend below the plot
    img_data4, img_path4 = visualize_data(
        df=df,
        x_column='时间',
        y_columns=['销售额', '用户数', '转化率'],
        title='业务指标分析(图例在下方)',
        plot_style=['line', 'scatter_fit', 'scatter_fit'],
        y_labels=['销售额(万)', '用户数', '转化率(%)'],
        figsize=(12, 6),
        legend_position='bottom',
        save_path='output/legend_bottom.png',
        title_pad=10
    )

    # Example 5: legend on the left
    img_data5, img_path5 = visualize_data(
        df=df,
        x_column='时间',
        y_columns=['销售额', '用户数'],
        title='业务指标分析(图例在左侧)',
        plot_style=['line', 'scatter_fit'],
        y_labels=['销售额(万)', '用户数'],
        figsize=(12, 6),
        legend_position='left',
        save_path='output/legend_left.png',
        title_pad=10
    )

    # Example 6: legend on the right
    img_data6, img_path6 = visualize_data(
        df=df,
        x_column='时间',
        y_columns=['销售额', '用户数', '转化率'],
        title='业务指标分析(图例在右侧)',
        plot_style=['line', 'scatter_fit', 'scatter_fit'],
        y_labels=['销售额(万)', '用户数', '转化率(%)'],
        figsize=(12, 6),
        legend_position='right',
        save_path='output/legend_right.png',
        title_pad=10
    )