utils_analysis.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import os
  2. import sys
  3. sys.path.append("..")
  4. import csv
  5. from temp import config_analysis
  6. from temp.config_analysis import COLUMN_NAME_2_INDEX as COLUMN_IDX
  7. import seaborn as sns
  8. import matplotlib.pyplot as plt
  9. from matplotlib import rcParams
  10. import matplotlib.font_manager as fm
  11. from scipy import stats
  12. import numpy as np
  13. import pandas as pd
  14. def label_queue():
  15. """
  16. 从统计文件中筛选标签,返回标签数据,如果需要修改内部参数请对应修改config_analysis文件
  17. """
  18. with open(config_analysis.INPUT_CSV_FILE) as csv_file_handler:
  19. csv_reader = csv.reader(csv_file_handler)
  20. next(csv_reader) # ['名称', '编码', '单位', '精度', '设备号', '记录数', '最小时间', '最大时间']
  21. for row in csv_reader: # row: list
  22. # 通过记录数量筛选
  23. if int(row[COLUMN_IDX['记录数']]) < config_analysis.DATA_MIN_RECORDS: continue
  24. yield {'name': row[COLUMN_IDX['名称']], 'code': row[COLUMN_IDX['编码']]}
  25. def diff_tool(name:str, frame: pd.DataFrame, col:str):
  26. words = ['累计', '计数', '运行时间']
  27. for word in words:
  28. if word in name:
  29. frame[col] = frame[col].diff()
  30. frame.dropna(subset=[col], inplace=True)
  31. return frame
  32. def skip_tool(series_a_name:str, series_b_name:str):
  33. if '温度' in series_a_name and '温度' in series_b_name: return True
  34. if '次数' in series_a_name and '次数' in series_b_name: return True
  35. if '累计' in series_a_name and '累计' in series_b_name: return True
  36. if '电流' in series_a_name and '电流' in series_b_name: return True
  37. if '电压' in series_a_name and '电压' in series_b_name: return True
  38. if '电流' in series_a_name and '温度' in series_b_name: return True
  39. if '温度' in series_a_name and '电流' in series_b_name: return True
  40. if '累计电量' in series_a_name and '累计电量' in series_b_name: return True
  41. if '运行时间' in series_a_name and '累计电量' in series_b_name: return True
  42. if '累计电量' in series_a_name and '运行时间' in series_b_name: return True
  43. if '运行时间' in series_a_name and '运行时间' in series_b_name: return True
  44. if '时间设定' in series_a_name and '时间设定' in series_b_name: return True
  45. return False
  46. def set_chinese_font():
  47. # 1. 清除Matplotlib缓存(关键步骤)
  48. # cache_dir = os.path.expanduser('~/.cache/matplotlib')
  49. # if os.path.exists(cache_dir):
  50. # print(f"清除Matplotlib缓存: {cache_dir}")
  51. # for file in os.listdir(cache_dir):
  52. # if file.endswith('.cache') or file.endswith('.json'):
  53. # os.remove(os.path.join(cache_dir, file))
  54. # 2. 列出所有可用中文字体
  55. chinese_fonts = [
  56. # '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc', # 文泉驿微米黑
  57. # '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', # 文泉驿正黑
  58. # '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc', # 思源黑体
  59. # '/usr/share/fonts/windows/msyh.ttc', # 微软雅黑
  60. '/usr/share/fonts/windows/simsun.ttc' # 宋体
  61. ]
  62. # 3. 选择第一个可用的中文字体
  63. selected_font = None
  64. for font_path in chinese_fonts:
  65. if os.path.exists(font_path):
  66. selected_font = font_path
  67. print(f"使用字体: {font_path}")
  68. break
  69. if selected_font is None:
  70. print("警告: 未找到任何中文字体文件")
  71. # 尝试使用字体名称
  72. rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
  73. rcParams['axes.unicode_minus'] = False
  74. else:
  75. # 手动添加字体到字体管理器
  76. fm.fontManager.addfont(selected_font)
  77. # 获取字体名称
  78. font_prop = fm.FontProperties(fname=selected_font)
  79. font_name = font_prop.get_name()
  80. print(f"字体名称: {font_name}")
  81. # 设置全局字体
  82. rcParams['font.family'] = 'sans-serif'
  83. rcParams['font.sans-serif'] = [font_name]
  84. rcParams['axes.unicode_minus'] = False
  85. def create_custom_heatmap(corr_matrix: pd.DataFrame, title:str="相关系数热力图") -> str:
  86. # 设置图像尺寸(根据矩阵大小动态调整)
  87. size_factor = max(0.5, min(1.5, len(corr_matrix) / 30)) # 缩放因子
  88. fig_width = 9 + len(corr_matrix.columns) * 0.4 * size_factor
  89. fig_height = 7 + len(corr_matrix.index) * 0.4 * size_factor
  90. plt.figure(figsize=(fig_width, fig_height))
  91. # 创建热力图
  92. ax = sns.heatmap(
  93. corr_matrix,
  94. cmap="coolwarm",
  95. center=0,
  96. annot=True, # 显示数值
  97. fmt=".2f",
  98. annot_kws={"size": 13 - len(corr_matrix) / 20}, # 动态调整注释大小
  99. linewidths=0.5,
  100. linecolor="white",
  101. cbar_kws={"shrink": 0.8, "label": "皮尔逊相关系数"}
  102. )
  103. # 旋转x轴标签
  104. plt.xticks(rotation=45, ha='right', fontsize=15)
  105. plt.yticks(fontsize=15,rotation=0, ha='right')
  106. # 设置标题和标签
  107. plt.title(title, fontsize=18, pad=20)
  108. plt.xlabel(f"B序列影响因素(显著性p值<{config_analysis.P_VALUE_THRESHOLD})", fontsize=15)
  109. plt.ylabel(f"A序列影响因素(显著性p值<{config_analysis.P_VALUE_THRESHOLD})", fontsize=15)
  110. # 添加次要网格线
  111. ax.grid(True, which='minor', color='white', linestyle='-', linewidth=0.5)
  112. # 调整布局
  113. plt.tight_layout()
  114. # 保存图像
  115. output_file = f"{title.replace(' ', '_')}.png"
  116. plt.savefig(output_file, dpi=300, bbox_inches='tight')
  117. plt.close()
  118. print(f"热力图已保存为: {output_file}")
  119. return output_file
  120. def cross_corr(group_a:list, group_b:list, all_data:pd.DataFrame, code_2_name_dict:dict) -> pd.DataFrame:
  121. # 创建交叉协方差矩阵
  122. corr_matrix = pd.DataFrame(index=group_a, columns=group_b, dtype=np.float32)
  123. for a in group_a:
  124. for b in group_b:
  125. r, p_value = stats.pearsonr(all_data.loc[:, a], all_data.loc[:, b])
  126. if p_value < config_analysis.P_VALUE_THRESHOLD:
  127. corr_matrix.loc[a, b] = np.float32(r)
  128. # 行列标签中文化
  129. a_code_2_name = {code: code_2_name_dict.get(code) for code in group_a }
  130. b_code_2_name = {code: code_2_name_dict.get(code) for code in group_b }
  131. corr_matrix.rename(index=a_code_2_name, columns=b_code_2_name, inplace=True)
  132. return corr_matrix
  133. def group_list(data:list, group_elements_num:int) -> list:
  134. """对输入的列表元素进行分组"""
  135. group_num = len(data) // group_elements_num + 1
  136. group_code = []
  137. num = 0
  138. for g in range(group_num):
  139. group_code.append(data[num:num + group_elements_num])
  140. num += group_elements_num
  141. return group_code
  142. if __name__ == '__main__':
  143. label_q1 = label_queue()
  144. label_q2 = label_queue()
  145. # for i in label_q1:
  146. # print(i['name'], i['code'])