import os
import sys
# The parent directory must be on sys.path before the `temp` package is
# imported below, so this statement has to stay ahead of those imports.
sys.path.append("..")
import csv

from temp import config_analysis
from temp.config_analysis import COLUMN_NAME_2_INDEX as COLUMN_IDX

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.font_manager as fm
from scipy import stats
import numpy as np
import pandas as pd


def label_queue():
    """Yield the labels of the statistics CSV that have enough records.

    Reads ``config_analysis.INPUT_CSV_FILE``, skips its first (header) row and
    yields one dict per remaining row whose record count is at least
    ``config_analysis.DATA_MIN_RECORDS``. To change these parameters, edit the
    ``config_analysis`` module.

    Yields:
        dict: ``{'name': <label name>, 'code': <label code>}``.

    Raises:
        ValueError: if a record-count cell cannot be parsed as an integer.
    """
    # newline='' is what the csv module requires for correct handling of
    # newlines embedded in quoted fields.
    # NOTE(review): the file is opened with the platform default encoding, as
    # before; confirm the CSV's encoding and pass it explicitly if it differs.
    with open(config_analysis.INPUT_CSV_FILE, newline='') as csv_file_handler:
        csv_reader = csv.reader(csv_file_handler)
        # Skip the header row. According to the original author's note it is
        # ['名称', '编码', '单位', '精度', '设备号', '记录数', '最小时间', '最大时间'].
        # The default keeps an empty file from raising StopIteration inside
        # the generator, which Python turns into a RuntimeError.
        next(csv_reader, None)
        for row in csv_reader:  # row: list of str
            if not row:
                # csv.reader returns [] for blank lines; nothing to filter.
                continue
            # Filter by the number of records.
            if int(row[COLUMN_IDX['记录数']]) < config_analysis.DATA_MIN_RECORDS:
                continue
            yield {'name': row[COLUMN_IDX['名称']], 'code': row[COLUMN_IDX['编码']]}


def diff_tool(name: str, frame: pd.DataFrame, col: str):
    """Difference a cumulative-looking column in place.

    If ``name`` contains any of the keywords '累计', '计数' or '运行时间', the
    column ``col`` of ``frame`` is replaced by its first difference and the
    rows where that column is NaN (at least the first row) are dropped.

    The difference is applied at most once, even when ``name`` contains
    several keywords; applying it once per matching keyword would produce a
    second or third difference instead of the per-step increment.

    Args:
        name: Label name that decides whether differencing is needed.
        frame: Data to modify. It is mutated in place.
        col: Name of the column to difference.

    Returns:
        The same ``frame`` object that was passed in.
    """
    words = ['累计', '计数', '运行时间']
    if any(word in name for word in words):
        frame[col] = frame[col].diff()
        # NOTE(review): the source was collapsed onto one line; dropna is
        # reconstructed as part of the differencing branch - confirm.
        frame.dropna(subset=[col], inplace=True)
    return frame


# Keyword pairs (keyword in series A's name, keyword in series B's name) for
# which a pair of series is skipped. Kept exactly as the original rule list,
# including entries that are implied by broader ones.
_SKIP_KEYWORD_PAIRS = (
    ('温度', '温度'),
    ('次数', '次数'),
    ('累计', '累计'),
    ('电流', '电流'),
    ('电压', '电压'),
    ('电流', '温度'),
    ('温度', '电流'),
    ('累计电量', '累计电量'),
    ('运行时间', '累计电量'),
    ('累计电量', '运行时间'),
    ('运行时间', '运行时间'),
    ('时间设定', '时间设定'),
)


def skip_tool(series_a_name: str, series_b_name: str):
    """Tell whether a pair of series matches one of the skip rules.

    Args:
        series_a_name: Name of the first series.
        series_b_name: Name of the second series.

    Returns:
        True if, for some rule, the first keyword occurs in ``series_a_name``
        and the second keyword occurs in ``series_b_name``; otherwise False.
    """
    return any(
        keyword_a in series_a_name and keyword_b in series_b_name
        for keyword_a, keyword_b in _SKIP_KEYWORD_PAIRS
    )


def set_chinese_font():
    """Configure Matplotlib's global rcParams to use a Chinese font.

    The first existing file in the candidate list is registered with the font
    manager and used as the sans-serif font. If no candidate file exists, a
    warning is printed and a list of font names is configured instead.

    The original author marked clearing Matplotlib's cache directory
    (``~/.cache/matplotlib``, the ``*.cache`` and ``*.json`` files) as a key
    step; that step is not performed by this function.
    """
    # Candidate font files; only SimSun is currently enabled.
    chinese_fonts = [
        # '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',  # WenQuanYi Micro Hei
        # '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',  # WenQuanYi Zen Hei
        # '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',  # Noto Sans CJK
        # '/usr/share/fonts/windows/msyh.ttc',  # Microsoft YaHei
        '/usr/share/fonts/windows/simsun.ttc'  # SimSun
    ]

    # Pick the first candidate that exists on disk.
    selected_font = next(
        (font_path for font_path in chinese_fonts if os.path.exists(font_path)),
        None,
    )

    if selected_font is None:
        print("警告: 未找到任何中文字体文件")
        # Fall back to font names and let Matplotlib resolve them.
        sans_serif_fonts = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
    else:
        print(f"使用字体: {selected_font}")
        # Register the file with the font manager, then look up its name.
        fm.fontManager.addfont(selected_font)
        font_name = fm.FontProperties(fname=selected_font).get_name()
        print(f"字体名称: {font_name}")
        sans_serif_fonts = [font_name]

    # Select the sans-serif family in both branches; otherwise the fallback
    # list has no effect when the active family is not sans-serif.
    rcParams['font.family'] = 'sans-serif'
    rcParams['font.sans-serif'] = sans_serif_fonts
    rcParams['axes.unicode_minus'] = False


def create_custom_heatmap(corr_matrix: pd.DataFrame, title: str = "相关系数热力图") -> str:
    """Render a correlation matrix as an annotated heatmap and save it as PNG.

    The figure size and annotation font size scale with the matrix size. The
    image is written to ``<title with spaces replaced by '_'>.png`` relative
    to the current working directory.

    Args:
        corr_matrix: Matrix of coefficients to plot.
        title: Plot title; also determines the output file name.

    Returns:
        The name of the file that was written.
    """
    row_count = len(corr_matrix)

    # Figure size grows with the matrix; the scale factor is limited to
    # the range [0.5, 1.5].
    size_factor = max(0.5, min(1.5, row_count / 30))
    fig_width = 9 + len(corr_matrix.columns) * 0.4 * size_factor
    fig_height = 7 + len(corr_matrix.index) * 0.4 * size_factor

    # Annotation text shrinks as the matrix grows; keep it positive so very
    # large matrices do not produce an invalid font size.
    annot_size = max(1.0, 13 - row_count / 20)

    fig = plt.figure(figsize=(fig_width, fig_height))
    try:
        ax = sns.heatmap(
            corr_matrix,
            cmap="coolwarm",
            center=0,
            annot=True,  # show the values in the cells
            fmt=".2f",
            annot_kws={"size": annot_size},
            linewidths=0.5,
            linecolor="white",
            cbar_kws={"shrink": 0.8, "label": "皮尔逊相关系数"}
        )

        # Tick labels: x rotated for readability, y horizontal.
        plt.xticks(rotation=45, ha='right', fontsize=15)
        plt.yticks(fontsize=15, rotation=0, ha='right')

        # Title and axis labels.
        plt.title(title, fontsize=18, pad=20)
        plt.xlabel(f"B序列影响因素(显著性p值<{config_analysis.P_VALUE_THRESHOLD})", fontsize=15)
        plt.ylabel(f"A序列影响因素(显著性p值<{config_analysis.P_VALUE_THRESHOLD})", fontsize=15)

        # Minor grid lines.
        ax.grid(True, which='minor', color='white', linestyle='-', linewidth=0.5)

        plt.tight_layout()

        output_file = f"{title.replace(' ', '_')}.png"
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    print(f"热力图已保存为: {output_file}")
    return output_file


def cross_corr(group_a: list, group_b: list, all_data: pd.DataFrame, code_2_name_dict: dict) -> pd.DataFrame:
    """Build the cross-correlation matrix between two groups of columns.

    For every pair (a, b) the Pearson correlation coefficient of
    ``all_data[a]`` and ``all_data[b]`` is computed. It is stored only when
    its p-value is below ``config_analysis.P_VALUE_THRESHOLD``; all other
    cells stay NaN.

    Args:
        group_a: Column labels of ``all_data`` used as matrix rows.
        group_b: Column labels of ``all_data`` used as matrix columns.
        all_data: Frame containing every column named in both groups.
        code_2_name_dict: Mapping from column code to display name. Codes
            missing from the mapping keep the code itself as their label.

    Returns:
        A float32 DataFrame whose row and column labels are display names.
    """
    corr_matrix = pd.DataFrame(index=group_a, columns=group_b, dtype=np.float32)

    for a in group_a:
        series_a = all_data.loc[:, a]  # looked up once per row
        for b in group_b:
            r, p_value = stats.pearsonr(series_a, all_data.loc[:, b])
            if p_value < config_analysis.P_VALUE_THRESHOLD:
                corr_matrix.loc[a, b] = np.float32(r)

    # Replace codes by display names. Falling back to the code avoids
    # relabelling unmapped rows/columns as None.
    a_code_2_name = {code: code_2_name_dict.get(code, code) for code in group_a}
    b_code_2_name = {code: code_2_name_dict.get(code, code) for code in group_b}
    corr_matrix.rename(index=a_code_2_name, columns=b_code_2_name, inplace=True)

    return corr_matrix


def group_list(data: list, group_elements_num: int) -> list:
    """Split a list into consecutive groups of at most a given size.

    Args:
        data: The elements to group.
        group_elements_num: Maximum number of elements per group; must be
            positive.

    Returns:
        A list of slices of ``data``. Every group is non-empty, so input whose
        length is an exact multiple of the group size gets no trailing empty
        group, and empty input yields an empty list.

    Raises:
        ValueError: if ``group_elements_num`` is not positive.
    """
    if group_elements_num <= 0:
        raise ValueError("group_elements_num must be a positive integer")
    return [
        data[start:start + group_elements_num]
        for start in range(0, len(data), group_elements_num)
    ]


if __name__ == '__main__':
    label_q1 = label_queue()
    label_q2 = label_queue()

    # for i in label_q1:
    #     print(i['name'], i['code'])