tools.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import sys
  2. sys.path.append('..')
  3. import pandas as pd
  4. from matplotlib import pyplot as plt
  5. import seaborn as sns
  6. import config
  7. import os
  8. from matplotlib import rcParams
  9. import matplotlib.font_manager as fm
  10. import numpy as np
  11. from typing import Iterable
  12. import json
  13. def group_list(data:list, group_elements_num:int) -> list:
  14. """对输入的列表元素进行分组,分组数量"""
  15. group_num = len(data) // group_elements_num + 1
  16. group_code = []
  17. num = 0
  18. for g in range(group_num):
  19. group_code.append(data[num:num + group_elements_num])
  20. num += group_elements_num
  21. return group_code
  22. def fmt_date(start_year,end_year,
  23. start_month,end_month,
  24. start_day,end_day,
  25. start_hour=0,end_hour=23,
  26. start_minute=0,end_minute=59,
  27. start_second=0,end_second=59):
  28. fmt = lambda x: '0' + str(x) if abs(x) < 10 else str(x)
  29. start_month = fmt(start_month)
  30. end_month = fmt(end_month)
  31. start_day = fmt(start_day)
  32. end_day = fmt(end_day)
  33. start_hour = fmt(start_hour)
  34. end_hour = fmt(end_hour)
  35. start_minute = fmt(start_minute)
  36. end_minute = fmt(end_minute)
  37. start_second = fmt(start_second)
  38. end_second = fmt(end_second)
  39. start_datetime = f'{start_year}-{start_month}-{start_day} {start_hour}:{start_minute}:{start_second}'
  40. end_datetime = f'{end_year}-{end_month}-{end_day} {end_hour}:{end_minute}:{end_second}'
  41. return start_datetime, end_datetime
  42. def create_custom_heatmap(corr_matrix: pd.DataFrame, title:str="相关系数热力图") -> str:
  43. """绘制热力图,输入协方差矩阵,自动生成热力图"""
  44. corr_matrix.replace(0., np.nan, inplace=True)
  45. # 设置图像尺寸(根据矩阵大小动态调整)
  46. size_factor = max(0.5, min(1.5, len(corr_matrix) / 30)) # 缩放因子
  47. fig_width = 9 + len(corr_matrix.columns) * 0.4 * size_factor
  48. fig_height = 7 + len(corr_matrix.index) * 0.4 * size_factor
  49. plt.figure(figsize=(fig_width, fig_height))
  50. # 创建热力图
  51. ax = sns.heatmap(
  52. corr_matrix,
  53. cmap="coolwarm",
  54. center=0,
  55. annot=True, # 显示数值
  56. fmt=".2f",
  57. annot_kws={"size": 13 - len(corr_matrix) / 20}, # 动态调整注释大小
  58. linewidths=0.5,
  59. linecolor="white",
  60. cbar_kws={"shrink": 0.8, "label": "皮尔逊相关系数"}
  61. )
  62. # 旋转x轴标签
  63. plt.xticks(rotation=45, ha='right', fontsize=15)
  64. plt.yticks(fontsize=15,rotation=0, ha='right')
  65. # 设置标题和标签
  66. plt.title(title, fontsize=18, pad=20)
  67. plt.xlabel(f"B序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)
  68. plt.ylabel(f"A序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)
  69. # 添加次要网格线
  70. ax.grid(True, which='minor', color='white', linestyle='-', linewidth=0.5)
  71. # 调整布局
  72. plt.tight_layout()
  73. # 保存图像
  74. output_file = f"{title.replace(' ', '_')}.png"
  75. plt.savefig(output_file, dpi=300, bbox_inches='tight')
  76. plt.close()
  77. print(f"热力图已保存为: {output_file}")
  78. return output_file
  79. def set_chinese_font():
  80. """设置matplotlib中文字体"""
  81. # 1. 清除Matplotlib缓存(关键步骤)
  82. # cache_dir = os.path.expanduser('~/.cache/matplotlib')
  83. # if os.path.exists(cache_dir):
  84. # print(f"清除Matplotlib缓存: {cache_dir}")
  85. # for file in os.listdir(cache_dir):
  86. # if file.endswith('.cache') or file.endswith('.json'):
  87. # os.remove(os.path.join(cache_dir, file))
  88. # 2. 列出所有可用中文字体
  89. chinese_fonts = [
  90. # '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc', # 文泉驿微米黑
  91. # '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', # 文泉驿正黑
  92. # '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc', # 思源黑体
  93. # '/usr/share/fonts/windows/msyh.ttc', # 微软雅黑
  94. '/usr/share/fonts/windows/simsun.ttc' # 宋体
  95. ]
  96. # 3. 选择第一个可用的中文字体
  97. selected_font = None
  98. for font_path in chinese_fonts:
  99. if os.path.exists(font_path):
  100. selected_font = font_path
  101. print(f"使用字体: {font_path}")
  102. break
  103. if selected_font is None:
  104. print("警告: 未找到任何中文字体文件")
  105. # 尝试使用字体名称
  106. rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
  107. rcParams['axes.unicode_minus'] = False
  108. else:
  109. # 手动添加字体到字体管理器
  110. fm.fontManager.addfont(selected_font)
  111. # 获取字体名称
  112. font_prop = fm.FontProperties(fname=selected_font)
  113. font_name = font_prop.get_name()
  114. print(f"字体名称: {font_name}")
  115. # 设置全局字体
  116. rcParams['font.family'] = 'sans-serif'
  117. rcParams['font.sans-serif'] = [font_name]
  118. rcParams['axes.unicode_minus'] = False
  119. def cal_vari_without_zero_nan(data:Iterable, l='None', is_exclude_zero=True)-> tuple:
  120. """统计平局值和标准差,0和Nan不参与计算"""
  121. if not isinstance(data, pd.Series):
  122. raise TypeError("data must be pd.Series")
  123. # 计算平均
  124. tem_value_list = []
  125. for x in data:
  126. if (abs(x - 0.) < 1e-6) and is_exclude_zero: continue
  127. if pd.isna(x): continue
  128. tem_value_list.append(x)
  129. arr = np.array(tem_value_list)
  130. # 检查是否仍存在nan
  131. if np.sum(np.isnan(arr)) > 0:
  132. raise ValueError(f'数据{l}中仍存在nan没有被剔除')
  133. mean = np.mean(arr, dtype=np.float32) # 均值
  134. std_dev = np.std(arr, dtype=np.float32) # 标准差
  135. return mean, std_dev
  136. def cal_vari_without_nan(data:Iterable, l='None')-> tuple:
  137. return cal_vari_without_zero_nan(data, l, is_exclude_zero=False)
  138. def iqr(data:pd.Series)-> None:
  139. """剔除序列中的离群点,采用四分位数距法(IQR)"""
  140. pass
  141. return
def quicksort_part(arr: list, low: int, high: int) -> "int | None":
    """Partition ``arr[low..high]`` (inclusive bounds) in place for quicksort.

    Items are compared by ``abs(item[1])``, the absolute value of each item's
    second element. The pivot is the item initially at index ``low``.

    Returns:
        The index where the pivot item ends up. Items before it have keys
        <= the pivot key; items after it have keys >= the pivot key.
        Returns ``None`` when the range holds fewer than two items
        (``low >= high``).
    """
    if low >= high:
        return None
    # Pivot: the key of the item at `low`.
    left, right = low, high
    pivot = abs(arr[low][1])
    # Move larger keys to the right and smaller keys to the left.
    while left < right:
        # Scan leftwards from `right` for a key smaller than the pivot.
        while left < right and abs(arr[right][1]) >= pivot:
            right -= 1
        # Swap it into the `left` slot, which currently holds the pivot item.
        if left < right:
            arr[left], arr[right] = arr[right], arr[left]
            left += 1
        # Scan rightwards from `left` for a key larger than the pivot.
        while left < right and abs(arr[left][1]) <= pivot:
            left += 1
        # Swap it into the `right` slot, which currently holds the pivot item.
        if left < right:
            arr[left], arr[right] = arr[right], arr[left]
            right -= 1
    return left
  166. def quick_sort(arr:list[tuple], low:int, high:int):
  167. """元组快排算法"""
  168. if low >= high:
  169. return
  170. # 先排一趟
  171. mid = quicksort_part(arr, low, high)
  172. # 排左面
  173. quick_sort(arr, low, mid-1)
  174. # 排右面
  175. quick_sort(arr, mid+1, high)
  176. def df_is_symetry(df_mat:pd.DataFrame) -> bool:
  177. """检查DataFrame类型的矩阵是否为对称"""
  178. if df_mat.shape[0] != df_mat.shape[1]:
  179. return False
  180. # 检查索引和列名是否匹配
  181. if not np.array_equal(df_mat.index, df_mat.columns):
  182. return False
  183. # 转换为 NumPy 数组并检查
  184. return np.allclose(df_mat.values, df_mat.values.T, rtol=1e-5, atol=1e-08)
  185. def load_transfer_file_name_code(path):
  186. if not os.path.exists(path):
  187. raise FileNotFoundError('文件未发现:', path)
  188. with open(path, "r", encoding="utf-8") as f:
  189. json_data = json.load(f)
  190. return json_data.get('name_2_code'), json_data.get('code_2_name')
  191. if __name__ == '__main__':
  192. pass