# show.py — Pearson-correlation-matrix visualization and OLS reporting utilities.
  1. import os.path
  2. import sys
  3. sys.path.append("..")
  4. import pandas as pd
  5. import config
  6. import pickle
  7. from utils.tools import create_custom_heatmap, set_chinese_font, group_list, quick_sort, load_transfer_file_name_code
  8. import csv
  9. from sklearn.preprocessing import StandardScaler
  10. from sklearn.linear_model import LinearRegression
  11. from sklearn.metrics import mean_squared_error, mean_absolute_error
  12. from sklearn.metrics import r2_score
  13. import numpy as np
  14. from openpyxl import load_workbook
  15. set_chinese_font()
  16. def load_pearsonr_mat():
  17. with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
  18. results = pickle.load(f)
  19. return results
  20. def show_all_results():
  21. """展示所有结果"""
  22. # 加载计算结果
  23. with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
  24. results = pickle.load(f)
  25. label_list = results.columns.tolist()
  26. # 行列分组显示热力图
  27. row_group_elements_num = 25 # 行分组
  28. row_group_name = group_list(label_list, row_group_elements_num)
  29. col_group_elements_num = 50 # 列分组
  30. col_group_name = group_list(label_list, col_group_elements_num)
  31. for i, row_group in enumerate(row_group_name):
  32. for j, col_group in enumerate(col_group_name):
  33. corr_matrix = results.loc[row_group, col_group]
  34. create_custom_heatmap(corr_matrix, title=f'{config.PROJECT_ID}_水厂数据相关系数热力图{i}-{j}')
  35. query_name = input('是否继续查询结果所在位置(y/n):')
  36. if query_name == 'y' or query_name == 'Y':
  37. while query_name != '退出':
  38. query_name = input('查询:')
  39. flag = False
  40. for i, row_group in enumerate(row_group_name):
  41. if query_name in row_group:
  42. flag = True
  43. print(f'位置:{i}-*.png')
  44. break
  45. if not flag:
  46. print(f'位置:{query_name}不在统计范围')
  47. def save_txt(path):
  48. """按照某一个格式写入txt文件"""
  49. with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
  50. results = pickle.load(f)
  51. label_list = results.columns.tolist()
  52. with open(path, 'w', encoding='utf-8') as f:
  53. for i in range(len(label_list)):
  54. for j in range(len(label_list)):
  55. r = results.iloc[i,j]
  56. if abs(r-1) < 1e-6 or r < 0.2: continue
  57. f.write(f'{label_list[i]}-{label_list[j]}:{r:.2f};')
def save_csv(path):
    # TODO(review): unfinished stub — loads the Pearson matrix and collects
    # its labels, but never writes anything to `path`. Either implement the
    # CSV export or remove this function.
    with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
        results = pickle.load(f)
    label_list = results.columns.tolist()
  62. def rank(path):
  63. """按照相关系数进行排序,排序后写入文件"""
  64. if os.path.exists(path):
  65. os.remove(path)
  66. rmat = load_pearsonr_mat() # 对称矩阵
  67. label_list = rmat.columns.tolist()
  68. with open(path, 'w', newline='', encoding='utf-8') as f:
  69. csv_writer = csv.writer(f)
  70. for col_label in label_list:
  71. # 从皮尔逊矩阵挑选出1列元素
  72. elements = []
  73. for row_label in label_list:
  74. elements.append((row_label, rmat.loc[row_label, col_label]))
  75. # 按照皮尔逊相关系数的绝对值进行升序排序
  76. quick_sort(elements, 0, len(elements) - 1)
  77. # 反转list,由大到小排序
  78. elements = elements[::-1]
  79. # 写入csv
  80. csv_line_content = [f'{tup[0]} | {tup[1]:.2f}' for tup in elements if abs(tup[1]) > 0]
  81. csv_writer.writerow([col_label] + csv_line_content)
  82. def directed_heatmap(series_a_name, series_b_name):
  83. # 定向绘制皮尔逊系数矩阵
  84. rmat = load_pearsonr_mat()
  85. series_a_name = [i for i in series_a_name if i in rmat.columns.tolist()]
  86. series_b_name = [i for i in series_b_name if i in rmat.columns.tolist()]
  87. corr_matrix = rmat.loc[series_a_name, series_b_name]
  88. create_custom_heatmap(corr_matrix, title=f'{config.PROJECT_ID}_PearsonMat-{'_'.join(series_a_name[:3])}等-VS-{'_'.join(series_b_name[:3])}等')
  89. def free_ols(target_name, x_name):
  90. """自由最小二乘"""
  91. # 剔除自身字段
  92. if target_name in x_name:
  93. x_name.remove(target_name)
  94. # 获取数据
  95. with open(config.DF_MERGE_FILE_PATH, 'rb') as f:
  96. df_merge_mat = pickle.load(f)
  97. name_2_code_dict, code_2_name_dict = load_transfer_file_name_code(os.path.join(config.ALL_ITEMS_FILE_DIR, config.TRANSFER_JSON_NAME))
  98. if target_name not in name_2_code_dict.keys():
  99. raise RuntimeError('输入的target字段与数据不匹配', target_name)
  100. x_name = [i for i in x_name if i in name_2_code_dict.keys()]
  101. target_code = name_2_code_dict.get(target_name)
  102. if target_code not in df_merge_mat.columns.tolist():
  103. return
  104. x_code = [name_2_code_dict.get(i) for i in x_name if name_2_code_dict.get(i) in df_merge_mat.columns.tolist()]
  105. if len(x_name) == 0 or len(x_code) == 0:
  106. raise RuntimeError('输入的x字段与数据不匹配', x_name)
  107. #ols
  108. # 标准化
  109. x = df_merge_mat.loc[:, x_code].copy()
  110. y = df_merge_mat.loc[:, target_code]
  111. scaler = StandardScaler()
  112. x = scaler.fit_transform(x)
  113. ols_model = LinearRegression()
  114. ols_model.fit(x, y)
  115. # OLS模型诊断
  116. print_info = []
  117. print('\n===========OLS训练结果==================')
  118. print(f'Y:{target_name}')
  119. print(f"OLS 截距: {ols_model.intercept_}")
  120. print(f"OLS 系数:")
  121. for feat, coef in zip(x_name, ols_model.coef_):
  122. print(f" {feat}: {coef:.4f}")
  123. print_info.append(f'{coef:.4f}*{feat}')
  124. print(f"OLS R² (训练集): {ols_model.score(x,y):.4f}")
  125. print_info = ['+'+i if i[0]!='-' else i for i in print_info]
  126. print(f"{target_name}="+''.join(print_info) + f'+{ols_model.intercept_:.4}' if str(ols_model.intercept_)[0] != '-' else f'{ols_model.intercept_:.4}')
  127. # 基本指标评价
  128. y_pred = ols_model.predict(x)
  129. residuals = y - y_pred
  130. mse = mean_squared_error(y, y_pred)
  131. rmse = np.sqrt(mse)
  132. mae = mean_absolute_error(y, y_pred)
  133. r2 = r2_score(y, y_pred)
  134. # 调整R²
  135. n = len(y)
  136. p = x.shape[1]
  137. adj_r2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)
  138. print("\n===========模型性能指标==================:")
  139. print(f"均方误差 (MSE): {mse:.4f}")
  140. print(f"均方根误差 (RMSE): {rmse:.4f}")
  141. print(f"平均绝对误差 (MAE): {mae:.4f}")
  142. print(f"决定系数 (R²): {r2:.4f}")
  143. print(f"调整R²: {adj_r2:.4f}")
  144. if __name__ == '__main__':
  145. # 对所有结果进行展示
  146. # show_all_results()
  147. # 按照格式写入txt
  148. # save_txt('./tem.txt')
  149. # 皮尔逊排序
  150. #rank(f'./{config.PROJECT_ID}_rank.csv')
  151. # 定向绘制皮尔逊分布图
  152. # 加载Excel工作簿
  153. # workbook = load_workbook(f'./{config.PROJECT_ID}_field_combination.xlsx')
  154. # # 获取所有sheet的名称
  155. # sheet_names = workbook.sheetnames
  156. # print("文件中包含的sheet有:", sheet_names)
  157. # # 遍历每一个sheet
  158. # for sheet_name in sheet_names:
  159. # sheet = workbook[sheet_name]
  160. # # 获取A列数据(从第一行开始)
  161. # series_a_name = [cell.value for cell in sheet['A'] if cell.value is not None]
  162. # # 获取B列数据(从第一行开始)
  163. # series_b_name = [cell.value for cell in sheet['B'] if cell.value is not None]
  164. # print(f"Sheet名称: {sheet_name}")
  165. # print(f" A列 {series_a_name} ")
  166. # print(f" B列 {series_b_name} ")
  167. # directed_heatmap(series_a_name, series_b_name)
  168. # 定向自由回归
  169. # 加载Excel工作簿
  170. workbook = load_workbook(f'./{config.PROJECT_ID}_field_ols.xlsx')
  171. # 获取所有sheet的名称
  172. sheet_names = workbook.sheetnames
  173. print("文件中包含的sheet有:", sheet_names)
  174. # 遍历每一个sheet
  175. for sheet_name in sheet_names:
  176. sheet = workbook[sheet_name]
  177. # 获取A列数据(从第一行开始)
  178. series_a_name = [cell.value for cell in sheet['A'] if cell.value is not None]
  179. # 获取B列数据(从第一行开始)
  180. series_b_name = [cell.value for cell in sheet['B'] if cell.value is not None]
  181. free_ols(series_b_name[0], series_a_name)