| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- import sys
- sys.path.append('..')
- import pandas as pd
- from matplotlib import pyplot as plt
- import seaborn as sns
- import config
- import os
- from matplotlib import rcParams
- import matplotlib.font_manager as fm
- import numpy as np
- from typing import Iterable
- import json
def group_list(data: list, group_elements_num: int) -> list:
    """Split ``data`` into consecutive chunks of at most ``group_elements_num`` items.

    Args:
        data: Sequence to split. It is only sliced, never modified.
        group_elements_num: Maximum number of elements per chunk; must be positive.

    Returns:
        A list of slices of ``data`` in their original order. Only the last
        chunk may be shorter than ``group_elements_num``. Empty input yields
        an empty list.

    Raises:
        ValueError: If ``group_elements_num`` is not positive.
    """
    if group_elements_num <= 0:
        raise ValueError("group_elements_num must be a positive integer")
    # Stepping over the start offsets produces exactly ceil(len / size) chunks,
    # so no empty trailing chunk appears when the length divides evenly.
    return [
        data[start:start + group_elements_num]
        for start in range(0, len(data), group_elements_num)
    ]
def fmt_date(start_year, end_year,
             start_month, end_month,
             start_day, end_day,
             start_hour=0, end_hour=23,
             start_minute=0, end_minute=59,
             start_second=0, end_second=59):
    """Build a (start, end) pair of ``YYYY-MM-DD HH:MM:SS`` strings.

    Every component except the year is prefixed with a single ``'0'`` when
    its absolute value is below 10; years are inserted as given. By default
    the start time is 00:00:00 and the end time is 23:59:59.

    Returns:
        A tuple ``(start_datetime, end_datetime)`` of formatted strings.
    """
    def two_digits(value):
        # Prefix a '0' to single-digit components.
        text = str(value)
        if abs(value) < 10:
            return '0' + text
        return text

    def stamp(year, month, day, hour, minute, second):
        month, day, hour, minute, second = (
            two_digits(part) for part in (month, day, hour, minute, second)
        )
        return f'{year}-{month}-{day} {hour}:{minute}:{second}'

    start_datetime = stamp(start_year, start_month, start_day,
                           start_hour, start_minute, start_second)
    end_datetime = stamp(end_year, end_month, end_day,
                         end_hour, end_minute, end_second)
    return start_datetime, end_datetime
def create_custom_heatmap(corr_matrix: pd.DataFrame, title: str = "相关系数热力图") -> str:
    """Render an annotated heatmap of ``corr_matrix`` and save it as a PNG file.

    Cells equal to 0 are shown blank (plotted as NaN). Figure size and the
    annotation font size scale with the matrix dimensions. The image is saved
    to the relative path ``<title>.png`` with spaces in the title replaced by
    underscores.

    Args:
        corr_matrix: Matrix of coefficients to plot. It is not modified.
        title: Plot title, also used to derive the output file name.

    Returns:
        The file name the image was saved under.
    """
    # Work on a copy: replacing in place would silently turn the caller's
    # zeros into NaN as a side effect of drawing a picture.
    plot_data = corr_matrix.replace(0., np.nan)

    # Scale the figure with the matrix size, clamping the factor to [0.5, 1.5].
    size_factor = max(0.5, min(1.5, len(plot_data) / 30))
    fig_width = 9 + len(plot_data.columns) * 0.4 * size_factor
    fig_height = 7 + len(plot_data.index) * 0.4 * size_factor
    fig = plt.figure(figsize=(fig_width, fig_height))
    try:
        ax = sns.heatmap(
            plot_data,
            cmap="coolwarm",
            center=0,
            annot=True,  # print the value inside each cell
            fmt=".2f",
            annot_kws={"size": 13 - len(plot_data) / 20},  # smaller text for larger matrices
            linewidths=0.5,
            linecolor="white",
            cbar_kws={"shrink": 0.8, "label": "皮尔逊相关系数"},
        )
        # Tilt the x tick labels; keep the y tick labels horizontal.
        plt.xticks(rotation=45, ha='right', fontsize=15)
        plt.yticks(fontsize=15, rotation=0, ha='right')
        # Title and axis labels.
        plt.title(title, fontsize=18, pad=20)
        plt.xlabel(f"B序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)
        plt.ylabel(f"A序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)
        # Minor grid lines.
        ax.grid(True, which='minor', color='white', linestyle='-', linewidth=0.5)
        plt.tight_layout()
        output_file = f"{title.replace(' ', '_')}.png"
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f"热力图已保存为: {output_file}")
    return output_file
def set_chinese_font():
    """Configure matplotlib's global rcParams to use a Chinese font.

    The first existing file among the candidate font paths is registered with
    matplotlib's font manager and made the sole sans-serif font. If no
    candidate file exists, a warning is printed and a list of font family
    names is set instead. In both cases ``axes.unicode_minus`` is disabled.
    """
    # NOTE: deleting the '*.cache' / '*.json' files under ~/.cache/matplotlib
    # was previously tried as a first step and is currently disabled.

    # Candidate font files in priority order. Disabled alternatives:
    #   /usr/share/fonts/truetype/wqy/wqy-microhei.ttc          (WenQuanYi Micro Hei)
    #   /usr/share/fonts/truetype/wqy/wqy-zenhei.ttc            (WenQuanYi Zen Hei)
    #   /usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc  (Noto Sans CJK)
    #   /usr/share/fonts/windows/msyh.ttc                       (Microsoft YaHei)
    candidate_paths = (
        '/usr/share/fonts/windows/simsun.ttc',  # SimSun
    )
    font_file = next(
        (path for path in candidate_paths if os.path.exists(path)),
        None,
    )

    if font_file is not None:
        print(f"使用字体: {font_file}")
        # Register the file so matplotlib can resolve the font by name.
        fm.fontManager.addfont(font_file)
        font_name = fm.FontProperties(fname=font_file).get_name()
        print(f"字体名称: {font_name}")
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.sans-serif'] = [font_name]
    else:
        print("警告: 未找到任何中文字体文件")
        # Fall back to family names and let matplotlib resolve them.
        rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
    rcParams['axes.unicode_minus'] = False
def cal_vari_without_zero_nan(data: pd.Series, l='None', is_exclude_zero=True) -> tuple:
    """Return the mean and standard deviation of ``data``, skipping missing values.

    Missing values (anything ``pd.isna`` reports, e.g. NaN, None, pd.NA) are
    always ignored. Values whose absolute value is below 1e-6 are ignored as
    well when ``is_exclude_zero`` is true.

    Args:
        data: Series of numeric values.
        l: Label for ``data``, used only in the error message.
        is_exclude_zero: Whether (near-)zero values are left out.

    Returns:
        ``(mean, std_dev)`` computed with ``np.mean`` / ``np.std`` in float32
        (``np.std`` defaults to the population standard deviation). Both are
        NaN when no value remains after filtering.

    Raises:
        TypeError: If ``data`` is not a ``pd.Series``.
        ValueError: If a NaN is still present after filtering.
    """
    if not isinstance(data, pd.Series):
        raise TypeError("data must be pd.Series")
    # Test for missing values first: arithmetic on None raises TypeError, and
    # a comparison involving pd.NA cannot be used as a boolean.
    kept = [
        x for x in data
        if not pd.isna(x) and not (is_exclude_zero and abs(x - 0.) < 1e-6)
    ]
    if not kept:
        # Nothing to aggregate; return NaN explicitly rather than relying on
        # numpy's empty-slice warnings.
        nan = np.float32(np.nan)
        return nan, nan
    arr = np.array(kept)
    # Defensive invariant: the filter above should have removed every NaN.
    if np.sum(np.isnan(arr)) > 0:
        raise ValueError(f'数据{l}中仍存在nan没有被剔除')
    mean = np.mean(arr, dtype=np.float32)
    std_dev = np.std(arr, dtype=np.float32)
    return mean, std_dev
def cal_vari_without_nan(data: Iterable, l='None') -> tuple:
    """Delegate to ``cal_vari_without_zero_nan`` with zero exclusion turned off.

    Args:
        data: Values to aggregate, passed through unchanged.
        l: Label for ``data``, passed through unchanged.

    Returns:
        Whatever ``cal_vari_without_zero_nan`` returns for these arguments.
    """
    exclude_zero = False
    return cal_vari_without_zero_nan(data, l, is_exclude_zero=exclude_zero)
def iqr(data: pd.Series) -> None:
    """Placeholder for outlier removal using the interquartile range (IQR).

    Not implemented yet: ``data`` is ignored and ``None`` is returned.
    """
    return None
def quicksort_part(arr: list, low: int, high: int):
    """Partition ``arr[low..high]`` in place around its first element.

    Items are compared by ``abs(item[1])``. The pivot is the item initially
    at ``low``. Afterwards the pivot sits at the returned index, with items
    of magnitude <= pivot to its left and items of magnitude >= pivot to its
    right.

    Returns:
        The final index of the pivot, or ``None`` when ``low >= high``.
    """
    if low >= high:
        return None

    def magnitude(idx):
        return abs(arr[idx][1])

    pivot = magnitude(low)
    i, j = low, high
    while i < j:
        # Scan from the right for an item smaller than the pivot.
        while i < j and magnitude(j) >= pivot:
            j -= 1
        if i < j:
            arr[i], arr[j] = arr[j], arr[i]
            i += 1
        # Scan from the left for an item larger than the pivot.
        while i < j and magnitude(i) <= pivot:
            i += 1
        if i < j:
            arr[i], arr[j] = arr[j], arr[i]
            j -= 1
    return i
def quick_sort(arr: list[tuple], low: int, high: int):
    """Sort ``arr[low..high]`` (inclusive) in place, ascending by ``abs(item[1])``.

    Args:
        arr: List of indexable items; each item's element 1 is the sort key.
        low: Index of the first item of the range to sort.
        high: Index of the last item of the range to sort.

    Returns:
        None. ``arr`` is reordered in place; nothing happens if ``low >= high``.

    Raises:
        IndexError: If the range is not within ``0 .. len(arr) - 1``.
    """
    if low >= high:
        return
    if low < 0 or high >= len(arr):
        raise IndexError("sort range is outside the list")
    # Use the built-in sort rather than a hand-rolled recursive quicksort:
    # with a first-element pivot, already-sorted input recurses once per item
    # and exceeds the interpreter's recursion limit on long lists. The
    # built-in sort is stable, so items with equal magnitude keep their order.
    arr[low:high + 1] = sorted(arr[low:high + 1], key=lambda item: abs(item[1]))
def df_is_symetry(df_mat: pd.DataFrame) -> bool:
    """Return whether ``df_mat`` is a symmetric, identically labelled matrix.

    The frame must be square, its index must equal its columns, and its
    values must match their transpose within ``np.allclose`` tolerances
    (rtol=1e-5, atol=1e-8).
    """
    n_rows, n_cols = df_mat.shape[0], df_mat.shape[1]
    if n_rows != n_cols:
        return False
    # Row and column labels must line up for the comparison to be meaningful.
    labels_match = np.array_equal(df_mat.index, df_mat.columns)
    if not labels_match:
        return False
    values = df_mat.values
    return np.allclose(values, values.T, rtol=1e-5, atol=1e-08)
def load_transfer_file_name_code(path):
    """Load the two lookup tables stored in a UTF-8 JSON file.

    Args:
        path: Path of the JSON file to read.

    Returns:
        A tuple ``(name_2_code, code_2_name)`` holding the values under those
        two top-level keys; an entry is ``None`` when its key is absent.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        # Build one message string: passing the text and the path as two
        # positional arguments makes OSError treat them as (errno, strerror)
        # and yields a garbled "[Errno ...]" message.
        raise FileNotFoundError(f'文件未发现: {path}')
    with open(path, "r", encoding="utf-8") as f:
        json_data = json.load(f)
    return json_data.get('name_2_code'), json_data.get('code_2_name')
# Script entry point: intentionally a no-op; this module only provides helpers.
if __name__ == '__main__':
    pass
|