"""Shared helpers: list chunking, date formatting, heatmap plotting, font
setup for matplotlib, simple statistics, an in-place quicksort on tuples,
a DataFrame symmetry check and a JSON name/code mapping loader."""
import json
import os
import sys
from typing import Iterable

# The project-level ``config`` module is imported below; the parent directory
# is appended to the search path before the non-stdlib imports, as in the
# original file.
sys.path.append('..')

import matplotlib.font_manager as fm
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import rcParams

import config


def group_list(data: list, group_elements_num: int) -> list:
    """Split *data* into consecutive chunks of at most *group_elements_num* items.

    Args:
        data: The list to split.
        group_elements_num: Maximum number of elements per chunk; must be > 0.

    Returns:
        A list of slices of *data*. The last chunk may be shorter than the
        others. No empty chunk is produced, so an empty *data* yields ``[]``.

    Raises:
        ValueError: If *group_elements_num* is not positive.
    """
    if group_elements_num <= 0:
        raise ValueError("group_elements_num must be a positive integer")
    # Stepping by the chunk size avoids the trailing empty chunk that
    # ``len(data) // size + 1`` iterations would produce when the length is an
    # exact multiple of the chunk size.
    return [
        data[start:start + group_elements_num]
        for start in range(0, len(data), group_elements_num)
    ]


def fmt_date(start_year, end_year,
             start_month, end_month,
             start_day, end_day,
             start_hour=0, end_hour=23,
             start_minute=0, end_minute=59,
             start_second=0, end_second=59) -> tuple:
    """Build ``'Y-MM-DD HH:MM:SS'`` strings for a start and an end timestamp.

    Month, day, hour, minute and second are left-padded with one ``'0'`` when
    their absolute value is below 10; the year is inserted as given. No
    calendar validation is performed.

    Returns:
        ``(start_datetime, end_datetime)`` as two strings.
    """
    def _two_digits(value) -> str:
        # Prefix a single zero for one-digit values, otherwise plain str().
        return '0' + str(value) if abs(value) < 10 else str(value)

    start_datetime = (
        f'{start_year}-{_two_digits(start_month)}-{_two_digits(start_day)} '
        f'{_two_digits(start_hour)}:{_two_digits(start_minute)}:{_two_digits(start_second)}'
    )
    end_datetime = (
        f'{end_year}-{_two_digits(end_month)}-{_two_digits(end_day)} '
        f'{_two_digits(end_hour)}:{_two_digits(end_minute)}:{_two_digits(end_second)}'
    )
    return start_datetime, end_datetime


def create_custom_heatmap(corr_matrix: pd.DataFrame, title: str = "相关系数热力图") -> str:
    """Draw an annotated heatmap of *corr_matrix* and save it as a PNG file.

    Cells equal to 0 are replaced by NaN before plotting. The replacement is
    done on a copy, so the caller's DataFrame is left untouched. The colour
    bar labels the values as Pearson correlation coefficients.

    Args:
        corr_matrix: Matrix of coefficients to plot.
        title: Figure title; also used (spaces replaced by ``_``) as the
            output file name.

    Returns:
        The relative path of the saved image, ``"<title>.png"``.
    """
    # Work on a copy: an in-place replace would silently alter the caller's data.
    matrix = corr_matrix.replace(0., np.nan)

    # Figure size grows with the matrix; the scale factor is clamped to [0.5, 1.5].
    size_factor = max(0.5, min(1.5, len(matrix) / 30))
    fig_width = 9 + len(matrix.columns) * 0.4 * size_factor
    fig_height = 7 + len(matrix.index) * 0.4 * size_factor
    fig = plt.figure(figsize=(fig_width, fig_height))

    output_file = f"{title.replace(' ', '_')}.png"
    try:
        ax = sns.heatmap(
            matrix,
            cmap="coolwarm",
            center=0,
            annot=True,  # show the numeric value in every cell
            fmt=".2f",
            annot_kws={"size": 13 - len(matrix) / 20},  # smaller text for larger matrices
            linewidths=0.5,
            linecolor="white",
            cbar_kws={"shrink": 0.8, "label": "皮尔逊相关系数"},
        )

        # Tilt the x tick labels so long names do not overlap.
        plt.xticks(rotation=45, ha='right', fontsize=15)
        plt.yticks(fontsize=15, rotation=0, ha='right')

        # Title and axis labels.
        plt.title(title, fontsize=18, pad=20)
        plt.xlabel(f"B序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)
        plt.ylabel(f"A序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)

        # Minor grid lines.
        ax.grid(True, which='minor', color='white', linestyle='-', linewidth=0.5)

        plt.tight_layout()
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    print(f"热力图已保存为: {output_file}")
    return output_file


def set_chinese_font():
    """Configure matplotlib to render Chinese text.

    Uses the first existing font file from a candidate list; if none exists,
    falls back to a list of font family names. In both cases
    ``axes.unicode_minus`` is disabled.
    """
    # NOTE: if a newly installed font is not picked up, clearing the
    # matplotlib cache directory (~/.cache/matplotlib) may be required.

    # Candidate font files, in order of preference.
    chinese_fonts = [
        # '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',  # WenQuanYi Micro Hei
        # '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',  # WenQuanYi Zen Hei
        # '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',  # Noto Sans CJK
        # '/usr/share/fonts/windows/msyh.ttc',  # Microsoft YaHei
        '/usr/share/fonts/windows/simsun.ttc',  # SimSun
    ]

    # Pick the first candidate that exists on disk.
    selected_font = None
    for font_path in chinese_fonts:
        if os.path.exists(font_path):
            selected_font = font_path
            print(f"使用字体: {font_path}")
            break

    if selected_font is None:
        print("警告: 未找到任何中文字体文件")
        # Fall back to font family names and let matplotlib resolve them.
        rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
        rcParams['axes.unicode_minus'] = False
    else:
        # Register the font file with the font manager.
        fm.fontManager.addfont(selected_font)

        # Resolve the family name stored in the font file.
        font_prop = fm.FontProperties(fname=selected_font)
        font_name = font_prop.get_name()
        print(f"字体名称: {font_name}")

        # Make it the global default.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.sans-serif'] = [font_name]
        rcParams['axes.unicode_minus'] = False


def cal_vari_without_zero_nan(data: Iterable, l='None', is_exclude_zero=True) -> tuple:
    """Return the mean and standard deviation of *data*, skipping missing values.

    Args:
        data: Must be a ``pd.Series``.
        l: Label inserted into the error message.
        is_exclude_zero: When true, values whose absolute value is below 1e-6
            are skipped as well.

    Returns:
        ``(mean, std_dev)`` computed with ``np.float32`` precision. Both are
        NaN when no value remains after filtering.

    Raises:
        TypeError: If *data* is not a ``pd.Series``.
        ValueError: If NaN is still present after filtering.
    """
    if not isinstance(data, pd.Series):
        raise TypeError("data must be pd.Series")

    kept_values = []
    for x in data:
        # Test for missing values first: arithmetic on None / pd.NA would
        # otherwise raise before they can be skipped.
        if pd.isna(x):
            continue
        if is_exclude_zero and abs(x - 0.) < 1e-6:
            continue
        kept_values.append(x)

    arr = np.array(kept_values)
    if arr.size == 0:
        # Same result np.mean / np.std give for an empty array, without the
        # RuntimeWarning they emit.
        return np.float32('nan'), np.float32('nan')

    # Defensive check that no NaN slipped through the filter.
    if np.sum(np.isnan(arr)) > 0:
        raise ValueError(f'数据{l}中仍存在nan没有被剔除')

    mean = np.mean(arr, dtype=np.float32)
    std_dev = np.std(arr, dtype=np.float32)
    return mean, std_dev


def cal_vari_without_nan(data: Iterable, l='None') -> tuple:
    """Return the mean and standard deviation of *data*, skipping only missing values.

    Same as :func:`cal_vari_without_zero_nan` with ``is_exclude_zero=False``.
    """
    return cal_vari_without_zero_nan(data, l, is_exclude_zero=False)


def iqr(data: pd.Series) -> None:
    """Placeholder for outlier removal by the interquartile-range (IQR) method.

    Not implemented yet: the function currently does nothing and returns
    ``None``.
    """
    # TODO: implement IQR-based outlier removal.
    return


def quicksort_part(arr: list, low: int, high: int):
    """Partition ``arr[low..high]`` in place around ``abs(arr[low][1])``.

    Elements are compared by the absolute value of their item at index 1.
    After the call, elements with a smaller key are left of the pivot and
    elements with a larger key are right of it.

    Returns:
        The final index of the pivot element, or ``None`` when
        ``low >= high``.
    """
    if low >= high:
        return None

    left, right = low, high
    pivot = abs(arr[low][1])

    # The pivot element alternates between position ``left`` and ``right``
    # with every swap until the two cursors meet on it.
    while left < right:
        # Scan from the right for an element smaller than the pivot.
        while left < right and abs(arr[right][1]) >= pivot:
            right -= 1
        if left < right:
            arr[left], arr[right] = arr[right], arr[left]
            left += 1

        # Scan from the left for an element larger than the pivot.
        while left < right and abs(arr[left][1]) <= pivot:
            left += 1
        if left < right:
            arr[left], arr[right] = arr[right], arr[left]
            right -= 1
    return left


def quick_sort(arr: list[tuple], low: int, high: int):
    """Sort ``arr[low..high]`` in place, ascending by ``abs(item[1])``.

    Uses :func:`quicksort_part` for partitioning. Pending sub-ranges are kept
    on an explicit stack instead of the call stack, so already-sorted or
    otherwise unbalanced input cannot exceed the recursion limit.
    """
    pending = [(low, high)]
    while pending:
        lo, hi = pending.pop()
        if lo >= hi:
            continue
        mid = quicksort_part(arr, lo, hi)
        # The two sub-ranges are disjoint, so processing order is irrelevant.
        pending.append((mid + 1, hi))
        pending.append((lo, mid - 1))


def df_is_symetry(df_mat: pd.DataFrame) -> bool:
    """Return whether *df_mat* is a symmetric matrix.

    The frame must be square, its index must equal its columns, and its
    values must match their transpose within ``rtol=1e-5`` / ``atol=1e-08``.
    NaN cells are not treated as equal to each other.
    """
    if df_mat.shape[0] != df_mat.shape[1]:
        return False

    # Row labels must match column labels, in the same order.
    if not np.array_equal(df_mat.index, df_mat.columns):
        return False

    return np.allclose(df_mat.values, df_mat.values.T, rtol=1e-5, atol=1e-08)


def load_transfer_file_name_code(path):
    """Load the name/code mappings from a UTF-8 JSON file.

    Returns:
        ``(name_2_code, code_2_name)`` taken from the top-level keys of the
        same names; a missing key yields ``None`` for that element.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    if not os.path.exists(path):
        # A single formatted message: passing two positional arguments would
        # be interpreted by OSError as (errno, strerror).
        raise FileNotFoundError(f'文件未发现: {path}')

    with open(path, "r", encoding="utf-8") as f:
        json_data = json.load(f)
    return json_data.get('name_2_code'), json_data.get('code_2_name')


if __name__ == '__main__':
    pass