import base64
import os
from io import BytesIO
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Legend placements understood by ``visualize_data``.
_LEGEND_POSITIONS = ('top', 'bottom', 'left', 'right')


def _default_config() -> dict:
    """Return a fresh copy of the default plot configuration."""
    return {
        'title': '多Y轴数据可视化',
        'title_fontsize': 14,
        'figsize': (7.56, 4.54),
        'dpi': 300,
        'colors': [
            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
            '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
            '#c49c94', '#f7b6d2', '#dbdb8d', '#9edae5', '#393b79',
        ],
        'markers': [
            'o', 's', '^', 'D', 'v', 'p', '*', 'h', 'X', 'd',
            'P', '>', '<', '1', '2', '3', '4', '+', 'x', '|',
        ],
        'x_label': None,
        'y_labels': None,
        'grid': True,
        'x_scale': 'linear',
        'y_scale': 'linear',
        'font_family': 'Microsoft YaHei',
        'marker_size': 3,
        'legend_position': 'right',
        'legend_ncol': 1,
        'legend_fontsize': 10,
        'legend_offset': 0.05,
        'tight_layout_pad': 1.5,
        'save_path': None,
        'plot_style': 'line',
        'fit_degree': 2,
        'scatter_alpha': 0.7,
        'line_style': '-',
        'title_pad': 1.0,
        'group_by_units': False,
        'unit_mapping': None,
    }


def _per_series(value, count: int, name: str) -> list:
    """Broadcast a single string to ``count`` entries or validate a sequence."""
    if isinstance(value, str):
        return [value] * count
    values = list(value)
    if len(values) < count:
        raise ValueError(
            f"{name} has {len(values)} entries but {count} y columns were given"
        )
    return values[:count]


def _cycled(values, count: int, name: str) -> list:
    """Return ``count`` entries taken from ``values``, wrapping around."""
    values = list(values)
    if not values:
        raise ValueError(f"{name} must not be empty")
    return [values[i % len(values)] for i in range(count)]


def _stable_group_order(units: list) -> List[int]:
    """Return indices ordered so equal units are adjacent, first-seen first."""
    groups: dict = {}
    for idx, unit in enumerate(units):
        groups.setdefault(unit, []).append(idx)
    return [idx for members in groups.values() for idx in members]


def _prepare_fit_grid(x_values):
    """Build the numeric x data and the 100-point grid used for curve fits.

    Returns ``(x_numeric, grid, grid_plot)`` where ``x_numeric`` is a float
    array (seconds since the earliest timestamp for datetime input),
    ``grid`` is the numeric evaluation grid and ``grid_plot`` is the same
    grid expressed in the units of the x axis.
    """
    is_datetime = pd.api.types.is_datetime64_any_dtype(x_values)
    if is_datetime:
        origin = x_values.min()
        x_numeric = (x_values - origin).dt.total_seconds().to_numpy(dtype=float)
    else:
        origin = None
        x_numeric = np.asarray(x_values, dtype=float)

    finite = x_numeric[np.isfinite(x_numeric)]
    if finite.size == 0:
        raise ValueError("x column has no finite values to fit against")

    grid = np.linspace(finite.min(), finite.max(), 100)
    grid_plot = origin + pd.to_timedelta(grid, unit='s') if is_datetime else grid
    return x_numeric, grid, grid_plot


def _draw_series(ax, x_values, y_values, style, color, marker, label,
                 config, fit) -> None:
    """Draw one series on ``ax`` as a line, a scatter, or a scatter with fit."""
    if style in ('scatter', 'scatter_fit'):
        ax.scatter(x_values, y_values, color=color, marker=marker,
                   s=config['marker_size'] ** 2,
                   alpha=config['scatter_alpha'], label=label)
        if style == 'scatter_fit':
            x_numeric, grid, grid_plot = fit
            y_numeric = np.asarray(y_values, dtype=float)
            # Missing samples would turn every fitted coefficient into NaN.
            usable = np.isfinite(x_numeric) & np.isfinite(y_numeric)
            coeffs = np.polyfit(x_numeric[usable], y_numeric[usable],
                                config['fit_degree'])
            ax.plot(grid_plot, np.poly1d(coeffs)(grid), color=color,
                    linestyle=config['line_style'], linewidth=1.5,
                    label=f'{label} (拟合)')
    else:
        # Any other style value falls back to a line plot.
        ax.plot(x_values, y_values, color=color, marker=marker,
                markersize=config['marker_size'],
                linestyle=config['line_style'], linewidth=1.5, label=label)


def _legend_layout(position: str, offset: float, n_series: int) -> dict:
    """Return legend anchor and subplot margins for a legend position."""
    layouts = {
        'top': dict(loc='lower center', anchor=(0.5 + offset, 1.0),
                    borderaxespad=0.1, ncol=4,
                    top=0.85, bottom=0.15, left=0.0, right=1.0),
        'bottom': dict(loc='upper center', anchor=(0.2 + offset, -0.02),
                       borderaxespad=0.5, ncol=4,
                       top=0.95, bottom=0.2, left=0.1, right=0.9),
        'left': dict(loc='center left', anchor=(-0.12 + offset, 0.5),
                     borderaxespad=0.5, ncol=1,
                     top=0.95, bottom=0.1, left=0.25,
                     right=0.9 - 0.07 * n_series),
        'right': dict(loc='center right', anchor=(1.15 + offset, 0.5),
                      borderaxespad=0.5, ncol=1,
                      top=0.95, bottom=0.1, left=0.1,
                      right=0.85 - 0.07 * n_series),
    }
    return layouts[position]


def visualize_data(
    df: pd.DataFrame,
    x_column: str,
    y_columns: List[str],
    **kwargs
) -> Tuple[str, Optional[str]]:
    """Plot DataFrame columns against a shared x axis with multiple y axes.

    Each y column gets its own y axis on the right-hand side unless
    ``group_by_units`` is enabled, in which case columns that map to the
    same unit share one axis. The figure is rendered to PNG, returned as a
    base64 string and optionally written to disk.

    Args:
        df: Input data.
        x_column: Name of the column used for the x axis.
        y_columns: Names of the columns to plot (at least one).
        **kwargs: Plot options; unknown keys are ignored.
            title (str): Figure title. Default '多Y轴数据可视化'.
            title_fontsize (int): Title font size. Default 14.
            figsize (Tuple[float, float]): Figure size. Default (7.56, 4.54).
            dpi (int): Resolution. Default 300.
            colors (List[str]): Series colours, cycled if too short.
                Default: 20 built-in colours.
            markers (List[str]): Series markers, cycled if too short.
                Default: 20 built-in markers.
            x_label (Optional[str]): X axis label. Default: ``x_column``.
            y_labels (Optional[List[str]]): One label per y column, in the
                order of ``y_columns``. Default: the column names.
            grid (bool): Draw a grid on the first axis. Default True.
            x_scale (str): 'linear' or 'log'. Default 'linear'.
            y_scale (Union[str, List[str]]): Scale for all series or one
                per series. Default 'linear'.
            font_family (str): Font used for text. Default 'Microsoft YaHei'.
            marker_size (int): Marker size. Default 3.
            legend_position (str): 'top', 'bottom', 'left' or 'right'.
                Default 'right'.
            legend_ncol (int): Legend column count. When not passed, a
                per-position default is used (4 for top/bottom, 1 for
                left/right).
            legend_fontsize (int): Legend font size. Default 10.
            legend_offset (float): Horizontal legend offset. Default 0.05.
            tight_layout_pad (float): Padding for tight layout. Default 1.5.
            save_path (Optional[str]): Where to write the PNG. Default None.
            plot_style (Union[str, List[str]]): 'line', 'scatter' or
                'scatter_fit', for all series or one per series; other
                values are drawn as lines. Default 'line'.
            fit_degree (int): Polynomial degree for 'scatter_fit'. Default 2.
            scatter_alpha (float): Scatter transparency. Default 0.7.
            line_style (str): Matplotlib line style. Default '-'.
            title_pad (float): Gap between title and axes. Default 1.0.
            group_by_units (bool): Share a y axis between columns with the
                same unit. Default False.
            unit_mapping (Optional[Dict[str, str]]): Column -> unit. Columns
                missing from the mapping use their own name as the unit.
                When None, every column is treated as its own unit.

    Returns:
        ``(base64_png, png_path)`` where ``png_path`` is the absolute path of
        the written file, or None when ``save_path`` was not given.

    Raises:
        ValueError: If ``y_columns`` is empty, ``legend_position`` is not
            recognised, a per-series option has fewer entries than
            ``y_columns``, or a fit is requested on an x column without
            finite values.
        KeyError: If a requested column is missing from ``df``.
    """
    config = _default_config()
    config.update(kwargs)

    y_columns = list(y_columns)
    if not y_columns:
        raise ValueError("y_columns must contain at least one column name")
    n_series = len(y_columns)

    if config['legend_position'] not in _LEGEND_POSITIONS:
        raise ValueError(
            f"legend_position must be one of {_LEGEND_POSITIONS}, "
            f"got {config['legend_position']!r}"
        )

    plot_styles = _per_series(config['plot_style'], n_series, 'plot_style')
    y_scales = _per_series(config['y_scale'], n_series, 'y_scale')
    colors = _cycled(config['colors'], n_series, 'colors')
    markers = _cycled(config['markers'], n_series, 'markers')
    if config['y_labels'] is None:
        axis_labels = list(y_columns)
    else:
        axis_labels = _per_series(list(config['y_labels']), n_series,
                                  'y_labels')

    # Resolve the unit of every column and make same-unit columns adjacent,
    # keeping every per-series option aligned with its column.
    units = None
    if config['group_by_units']:
        mapping = config['unit_mapping']
        if mapping is None:
            units = [f"Unit_{i}" for i in range(n_series)]
        else:
            units = [mapping.get(col, col) for col in y_columns]
        order = _stable_group_order(units)
        y_columns = [y_columns[i] for i in order]
        plot_styles = [plot_styles[i] for i in order]
        y_scales = [y_scales[i] for i in order]
        colors = [colors[i] for i in order]
        markers = [markers[i] for i in order]
        axis_labels = [axis_labels[i] for i in order]
        units = [units[i] for i in order]

    try:
        plt.rcParams['font.family'] = config['font_family']
        plt.rcParams['axes.unicode_minus'] = False
    except (ValueError, KeyError):
        print(f"警告: 字体 '{config['font_family']}' 不可用,使用默认字体")

    x_values = df[x_column]
    # The fit grid needs numeric x data, so only build it when a fit is
    # actually requested; plain plots then also work with categorical x.
    fit = _prepare_fit_grid(x_values) if 'scatter_fit' in plot_styles else None

    fig, ax1 = plt.subplots(figsize=config['figsize'], dpi=config['dpi'])
    try:
        axes = []        # distinct axes, in creation order
        unit_axes = {}   # unit -> axis, only used when grouping
        spine_position = 0.9  # the first extra axis lands at 1.0
        for i, col in enumerate(y_columns):
            ax = unit_axes.get(units[i]) if units is not None else None
            is_new_axis = ax is None
            if is_new_axis:
                if i == 0:
                    ax = ax1
                else:
                    ax = ax1.twinx()
                    spine_position += 0.1
                    ax.spines['right'].set_position(('axes', spine_position))
                axes.append(ax)
                if units is not None:
                    unit_axes[units[i]] = ax

            _draw_series(ax, x_values, df[col], plot_styles[i], colors[i],
                         markers[i], col, config, fit)

            # A shared axis keeps the label, tick colour and scale of the
            # first series drawn on it.
            if is_new_axis:
                ax.set_ylabel(axis_labels[i], color=colors[i])
                if ax is not ax1:
                    ax.yaxis.set_label_position('right')
                    ax.yaxis.set_ticks_position('right')
                    ax.yaxis.set_label_coords(spine_position + 0.07, 0.5)
                ax.tick_params(axis='y', labelcolor=colors[i])
                ax.set_yscale(y_scales[i])

        # Twin axes hide their x axis, so the x label belongs on ax1.
        ax1.set_xlabel(
            x_column if config['x_label'] is None else config['x_label']
        )
        ax1.set_xscale(config['x_scale'])
        if config['grid']:
            ax1.grid(True, linestyle=':', alpha=0.5)

        ax1.set_title(config['title'], fontsize=config['title_fontsize'],
                      pad=config['title_pad'], y=1.05)

        handles = []
        labels = []
        for ax in axes:
            ax_handles, ax_labels = ax.get_legend_handles_labels()
            handles.extend(ax_handles)
            labels.extend(ax_labels)

        layout = _legend_layout(config['legend_position'],
                                config['legend_offset'], n_series)
        # Honour an explicit legend_ncol; otherwise keep the per-position
        # default column count.
        ncol = config['legend_ncol'] if 'legend_ncol' in kwargs else layout['ncol']
        fig.legend(
            handles, labels,
            loc=layout['loc'],
            bbox_to_anchor=layout['anchor'],
            ncol=ncol,
            borderaxespad=layout['borderaxespad'],
            frameon=False,
            fontsize=config['legend_fontsize'],
        )

        # With many series the computed right margin can fall left of the
        # left margin, which subplots_adjust rejects; keep a minimum width.
        right = max(min(layout['right'], 0.9), layout['left'] + 0.1)
        fig.subplots_adjust(top=layout['top'], bottom=layout['bottom'],
                            left=layout['left'], right=right)
        fig.tight_layout(pad=config['tight_layout_pad'])

        buffer = BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight',
                    dpi=config['dpi'])
    finally:
        # Always release the figure, even when drawing fails.
        plt.close(fig)

    png_bytes = buffer.getvalue()
    base64_data = base64.b64encode(png_bytes).decode('utf-8')

    png_path = None
    if config['save_path']:
        os.makedirs(os.path.dirname(config['save_path']) or '.', exist_ok=True)
        with open(config['save_path'], 'wb') as f:
            f.write(png_bytes)
        png_path = os.path.abspath(config['save_path'])

    return base64_data, png_path


def _run_examples() -> None:
    """Render the demo figures into the ``output`` directory."""
    np.random.seed(42)
    data = {
        # 'MS' (month start) is accepted by every pandas version; the
        # month-end alias 'M' is deprecated in recent releases.
        '时间': pd.date_range('2023-01-01', periods=10, freq='MS'),
        '销售额': np.random.randint(1000, 2000, 10) + np.arange(10) * 10,
        '用户数': np.random.randint(5000, 8000, 10) + np.arange(10) * 50,
        '转化率': np.random.uniform(0.001, 0.003, 10) + np.arange(10) * 0.02,
        'ZDU1': np.random.uniform(0.01, 0.03, 10) + np.arange(10) * 0.02,
        'ZDU2': np.random.uniform(0.0001, 0.0003, 10) + np.arange(10) * 0.02,
    }
    df = pd.DataFrame(data)

    unit_mapping = {
        '销售额': '元',
        '用户数': '个',
        '转化率': '百分比',
        'ZDU1': '百分比',
        'ZDU2': '百分比',
    }

    five_columns = ['销售额', '用户数', '转化率', 'ZDU1', 'ZDU2']
    five_styles = ['line', 'scatter_fit', 'scatter_fit', 'line', 'line']
    five_unit_labels = ['销售额(元)', '用户数(个)', '转化率(%)',
                        'ZDU1(%)', 'ZDU2(%)']
    three_columns = ['销售额', '用户数', '转化率']
    three_styles = ['line', 'scatter_fit', 'scatter_fit']
    three_labels = ['销售额(万)', '用户数', '转化率(%)']

    scenarios = [
        # Columns sharing a unit are drawn on one y axis.
        dict(y_columns=five_columns, title='业务指标分析(按单位分组)',
             plot_style=five_styles, y_labels=five_unit_labels,
             fit_degree=2, legend_position='top', legend_ncol=3,
             save_path='output/grouped_units.png', title_pad=15,
             group_by_units=True, unit_mapping=unit_mapping),
        # One y axis per column.
        dict(y_columns=five_columns, title='业务指标分析(不分组)',
             plot_style=five_styles, y_labels=five_unit_labels,
             fit_degree=2, legend_position='top', legend_ncol=3,
             save_path='output/non_grouped.png', title_pad=15,
             group_by_units=False),
        # Legend above the plot.
        dict(y_columns=five_columns, title='业务指标分析(图例在上方)',
             plot_style=five_styles,
             y_labels=['销售额(万)', '用户数', '转化率(%)', 'ZDU1', 'ZDU2'],
             fit_degree=2, legend_position='top', legend_ncol=3,
             save_path='output/legend_top.png', title_pad=15),
        # Legend below the plot.
        dict(y_columns=three_columns, title='业务指标分析(图例在下方)',
             plot_style=three_styles, y_labels=three_labels,
             legend_position='bottom',
             save_path='output/legend_bottom.png', title_pad=10),
        # Legend on the left.
        dict(y_columns=['销售额', '用户数'], title='业务指标分析(图例在左侧)',
             plot_style=['line', 'scatter_fit'],
             y_labels=['销售额(万)', '用户数'],
             legend_position='left',
             save_path='output/legend_left.png', title_pad=10),
        # Legend on the right.
        dict(y_columns=three_columns, title='业务指标分析(图例在右侧)',
             plot_style=three_styles, y_labels=three_labels,
             legend_position='right',
             save_path='output/legend_right.png', title_pad=10),
    ]

    for options in scenarios:
        visualize_data(df=df, x_column='时间', figsize=(12, 6), **options)


if __name__ == "__main__":
    _run_examples()