main_analysis.py

import sys
sys.path.append("..")
from Database.database_ import Database, DatabaseParam
import pandas as pd
from scipy import stats
from utils_analysis import label_queue, diff_tool, skip_tool
import config_analysis
import os
import json
import time

# Print the run configuration for confirmation
print(f"""
Query database: ws_data
Query table:    {config_analysis.DB_SHEET_NAME}
Start date:     {config_analysis.CHECK_YEAR_START}-{config_analysis.CHECK_MONTH_START}-{config_analysis.CHECK_DAY_START}
End date:       {config_analysis.CHECK_YEAR_END}-{config_analysis.CHECK_MONTH_END}-{config_analysis.CHECK_DAY_END}
Project ID:     {config_analysis.PROJECT_ID}
""")
time.sleep(6)  # pause so the operator can review the configuration before the run starts
# Build the database connection parameters
db_param = DatabaseParam(
    db_host='192.168.50.4',
    db_user='root',
    db_password='*B-@p2b+97D5xAF1e6',
    db_name='ws_data',
    db_port=4000)
# Accumulators for the full series lists
total_name_list = []
total_code_list = []

# All database work happens inside this block; Database defines a context
# manager that releases the connection and cursor automatically.
with Database(db_param) as db:
    # Exclude constant series.
    # Prefer loading the precomputed list from file.
    if os.path.exists(config_analysis.TOTAL_LIST_JSON_FILE):
        with open(config_analysis.TOTAL_LIST_JSON_FILE, "r", encoding="utf-8") as f:
            loaded_data = json.load(f)
        print(f'Loading the analysis list from {config_analysis.TOTAL_LIST_JSON_FILE}...')
        total_name_list = loaded_data['total_name_list']
        total_code_list = loaded_data['total_code_list']
    # If the file does not exist, build the list on the fly
    else:
        for lab in label_queue():
            time_series_name = lab.get('name')
            time_series_code = lab.get('code')
            df = db.query_sql_time_series2data_frame(project_id=config_analysis.PROJECT_ID,
                                                     sheet_name=config_analysis.DB_SHEET_NAME,
                                                     data_code=time_series_code,
                                                     start_year=config_analysis.CHECK_YEAR_START,
                                                     end_year=config_analysis.CHECK_YEAR_END,
                                                     start_month=config_analysis.CHECK_MONTH_START,
                                                     end_month=config_analysis.CHECK_MONTH_END,
                                                     start_day=config_analysis.CHECK_DAY_START,
                                                     end_day=config_analysis.CHECK_DAY_END)
            if df is None:
                continue
            # Filter out (near-)constant series
            if df[df.columns[1]].nunique() <= 2:
                print(f'Filtered out constant series {time_series_name} ({time_series_code})!')
                continue
            else:
                total_name_list.append(time_series_name)
                total_code_list.append(time_series_code)
        # Save the list to file
        saved_data = {
            'total_name_list': total_name_list,
            'total_code_list': total_code_list,
        }
        with open(config_analysis.TOTAL_LIST_JSON_FILE, "w", encoding="utf-8") as f:
            json.dump(saved_data, f, ensure_ascii=False, indent=4)
        print(f'Analysis list saved to {config_analysis.TOTAL_LIST_JSON_FILE}')
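    # The cached list file looks like this (illustrative values only):
    # {
    #     "total_name_list": ["series name", ...],
    #     "total_code_list": ["series code", ...]
    # }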
    # Accumulate all computation results
    result = []
    """
    result: [dict, dict, ...]
    Each dict has the form:
    {
        'A': {'name': ..., 'code': ...},
        'B': {'name': ..., 'code': ...},
        'res': [{'k': value, 'r': value, 'p': value}, ...]
    }
    """
    # Walk the data to analyse; the field lists were read from the file above.
    # Series A
    for a_idx in range(len(total_code_list)):
        time_series_a_name = total_name_list[a_idx]
        time_series_a_code = total_code_list[a_idx]
        # Fetch series A
        df_a = db.query_sql_time_series2data_frame(project_id=config_analysis.PROJECT_ID,
                                                   sheet_name=config_analysis.DB_SHEET_NAME,
                                                   data_code=time_series_a_code,
                                                   start_year=config_analysis.CHECK_YEAR_START,
                                                   end_year=config_analysis.CHECK_YEAR_END,
                                                   start_month=config_analysis.CHECK_MONTH_START,
                                                   end_month=config_analysis.CHECK_MONTH_END,
                                                   start_day=config_analysis.CHECK_DAY_START,
                                                   end_day=config_analysis.CHECK_DAY_END)
        if df_a is None:
            continue
        # Filter out (near-)constant series
        if df_a[df_a.columns[1]].nunique() <= 2:
            print(f'Series A.{time_series_a_name} ({time_series_a_code}) is a constant series, skipping!')
            continue
        # Stationarize
        df_a = diff_tool(time_series_a_name, df_a, df_a.columns[1])
        # Series B
        for b_idx in range(a_idx, len(total_code_list)):
            time_series_b_name = total_name_list[b_idx]
            time_series_b_code = total_code_list[b_idx]
            if skip_tool(time_series_a_name, time_series_b_name):
                print(f'Skipping pair: {time_series_a_name} vs. {time_series_b_name}')
                continue
            # Fetch series B
            df_b = db.query_sql_time_series2data_frame(project_id=config_analysis.PROJECT_ID,
                                                       sheet_name=config_analysis.DB_SHEET_NAME,
                                                       data_code=time_series_b_code,
                                                       start_year=config_analysis.CHECK_YEAR_START,
                                                       end_year=config_analysis.CHECK_YEAR_END,
                                                       start_month=config_analysis.CHECK_MONTH_START,
                                                       end_month=config_analysis.CHECK_MONTH_END,
                                                       start_day=config_analysis.CHECK_DAY_START,
                                                       end_day=config_analysis.CHECK_DAY_END)
            if df_b is None:
                continue
            # if abs(len(df_a) - len(df_b)) > 20:
            #     raise ValueError('Time series lengths differ too much: len(A), len(B)', len(df_a), len(df_b))
            # Filter out (near-)constant series: some series are constant, their variance
            # is close to 0, and the covariance cannot be computed.
            if df_b[df_b.columns[1]].nunique() <= 2:
                print(f'Series B.{time_series_b_name} ({time_series_b_code}) is a constant series, skipping!')
                continue
            # Stationarize: based on the name, diff_tool selects the series that need
            # stationarizing and applies a first difference.
            df_b = diff_tool(time_series_b_name, df_b, df_b.columns[1])
            # Merge series A and B on the shared time axis
            df_merge = pd.merge(df_a, df_b, how='inner', on='time').sort_values('time', kind='mergesort')
            _, time_series_a_column, time_series_b_column = df_merge.columns
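            # Note (an assumption of the unpack above): each input frame carries exactly
            # one value column, so df_merge has three columns: time, A values, B values.
            # If the two value columns share a name, pandas disambiguates them with its
            # default '_x'/'_y' suffixes; the positional unpack handles either case.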
            # Cross-correlation analysis
            series_a = df_merge[time_series_a_column]
            series_b = df_merge[time_series_b_column]
            lags = config_analysis.MAX_LAG  # maximum lag
            step = 1
            print(f'Running cross-correlation analysis: A.{time_series_a_name} ({time_series_a_code}) | B.{time_series_b_name} ({time_series_b_code})')
            tem_dict = {'A': {'name': time_series_a_name, 'code': time_series_a_code},
                        'B': {'name': time_series_b_name, 'code': time_series_b_code},
                        'res': []}
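            # Alignment convention (illustrative, lag = -2, n = 6):
            #   series_a_shifted = series_a[2:]   -> a[2], a[3], a[4], a[5]
            #   series_b_shifted = series_b[:-2]  -> b[0], b[1], b[2], b[3]
            # Each pair is (a[t], b[t - 2]), i.e. A lags B by 2 time steps.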
            for lag in range(-lags, lags + 1, step):  # stop is exclusive, so +1 includes +MAX_LAG
                if lag < 0:  # A lags B
                    series_a_shifted = series_a[-lag:]
                    series_b_shifted = series_b[:lag]
                elif lag > 0:  # B lags A
                    series_a_shifted = series_a[:-lag]
                    series_b_shifted = series_b[lag:]
                else:  # zero lag
                    series_a_shifted = series_a
                    series_b_shifted = series_b
                # Compute the Pearson coefficient and its significance
                if len(series_a_shifted) < 24 or len(series_b_shifted) < 24:
                    print('Too few overlapping points, skipping this lag')
                    continue
                r, p_value = stats.pearsonr(series_a_shifted, series_b_shifted)
                # Filter out non-significant results
                if p_value > config_analysis.P_VALUE_THRESHOLD:
                    continue
                if abs(r) < config_analysis.R_THRESHOLD:
                    continue
                tem_dict['res'].append({'k': lag, 'r': r, 'p': p_value})
                # if lag < 0:
                #     print(f'A lags B by {abs(lag)} time units, k={lag}, r={r:.4f}, significance p={p_value:.4f}')
                # elif lag > 0:
                #     print(f'B lags A by {abs(lag)} time units, k={lag}, r={r:.4f}, significance p={p_value:.4f}')
                # else:
                #     print(f'No lag between A and B, k={lag}, r={r:.4f}, significance p={p_value:.6f}')
            if tem_dict['res']:
                result.append(tem_dict)

print(f'Computation finished; total number of results: {len(result)}')
# Save the results to file
if os.path.exists(config_analysis.OUTPUT_JSON_FILE):
    print(f'Deleting old file {config_analysis.OUTPUT_JSON_FILE}')
    os.remove(config_analysis.OUTPUT_JSON_FILE)
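# Output layout (illustrative):
# {
#     "data": [...one dict per retained A/B pair...],
#     "len": <number of retained pairs>,
#     "r_threshold": <config_analysis.R_THRESHOLD>,
#     "p_threshold": <config_analysis.P_VALUE_THRESHOLD>
# }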
data = {'data': result,
        'len': len(result),
        'r_threshold': config_analysis.R_THRESHOLD,
        'p_threshold': config_analysis.P_VALUE_THRESHOLD}
with open(config_analysis.OUTPUT_JSON_FILE, 'w', encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=4)
print(f'Data saved to {config_analysis.OUTPUT_JSON_FILE}')