# -*- coding: utf-8 -*-
"""test.py: online diagnosis API"""
import json
import os
import pandas as pd
import numpy as np
import torch
from config import config
from data_processing import DataAnomalyProcessor
from causal_structure import CausalStructureBuilder
from rl_tracing import RLTrainer, CausalTracingEnv


class WaterPlantDiagnoser:
    def __init__(self):
        # 1. Initialize the data processor (loads thresholds and computes scores;
        #    the anomaly representation is identical to training)
        self.processor = DataAnomalyProcessor()
        self.sensor_list = self.processor.sensor_list
        self.threshold_df = self.processor.threshold_df
        # 2. Build the causal graph (Layer 2)
        self.builder = CausalStructureBuilder(self.threshold_df)
        self.causal_graph = self.builder.build()
        # 3. Load the reinforcement-learning model (Layer 3)
        dummy_scores = np.zeros((1, len(self.sensor_list)), dtype=np.float32)
        self.trainer = RLTrainer(self.causal_graph, dummy_scores, self.threshold_df)
        model_path = os.path.join(config.MODEL_SAVE_DIR, "ppo_tracing_model.pth")
        if not os.path.exists(model_path):
            print(f"[Warning] Model file not found: {model_path}. "
                  f"In a test environment, make sure a pretrained model is available.")
        else:
            state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
            self.trainer.model.load_state_dict(state_dict)
            self.trainer.model.eval()
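            # weights_only=True restricts torch.load to tensor/primitive data
            # rather than arbitrary pickled objects, and eval() freezes
            # dropout/batch-norm behavior so inference is deterministic.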
    def _preprocess_dataframe(self, df_input):
        """
        Strictly reproduce the training-time data-processing logic.
        """
        try:
            df = df_input.copy()
            # 1. Parse timestamps (first column), dropping rows that fail to parse
            time_col = df.columns[0]
            df[time_col] = pd.to_datetime(df[time_col], format='mixed', errors='coerce')
            df = df.dropna(subset=[time_col]).set_index(time_col)
            # 2. Keep only columns that are known sensors
            valid_cols = [c for c in df.columns if c in self.sensor_list]
            if not valid_cols:
                return None
            df_valid = df[valid_cols]
            # 3. Downsample (4 s -> 20 s); empty bins become NaN and are handled
            #    later by the valid-data-ratio check
            df_resampled = df_valid.resample(f"{config.TARGET_SAMPLE_INTERVAL}s").mean()
            df_resampled = df_resampled.astype(np.float32)
            return df_resampled
        except Exception as e:
            print(f"Data preprocessing error: {e}")
            return None
    def api_predict(self, df_raw):
        """
        Public API: each call receives the past 2 hours of data and diagnoses
        the latest 40 minutes.
        """
        # --- Layer 1: data preprocessing ---
        df_resampled = self._preprocess_dataframe(df_raw)
        if df_resampled is None or df_resampled.empty:
            return {"status": "error", "msg": "Data preprocessing failed or data is empty"}
        # --- Layer 1: anomaly score computation (combined absolute + MAD score) ---
        scores_dict = {}
        for sensor in self.sensor_list:
            if sensor in df_resampled.columns:
                scores_dict[sensor] = self.processor._calculate_point_score_vectorized(
                    df_resampled[sensor], sensor
                )
            else:
                # Strict alignment: use np.nan so the layout matches training exactly
                scores_dict[sensor] = pd.Series(np.nan, index=df_resampled.index, dtype=np.float32)
        point_score_df = pd.DataFrame(scores_dict)[self.sensor_list]
        # --- Layer 1: take the latest 40 minutes for window aggregation and raw-value means ---
        req_points = config.POINTS_PER_WINDOW  # 40 minutes = 120 points at a 20 s interval
        total_points = len(point_score_df)
        if total_points < req_points:
            return {
                "status": "warning",
                "message": f"Less than 40 minutes of data received "
                           f"(got: {total_points*20}s, need: {req_points*20}s)"
            }
        # Slice: the last 40 minutes of scores (for anomaly scoring) and of raw
        # data (for deviation percentages)
        latest_40min_scores = point_score_df.iloc[-req_points:].values
        latest_40min_raw = df_resampled.iloc[-req_points:]
        # Mean of the raw sensor values over these 40 minutes, used as the
        # reference value for display
        real_time_values = latest_40min_raw.mean(skipna=True)
        # Valid-data ratios and the 95th percentile
        valid_counts = np.sum(~np.isnan(latest_40min_scores), axis=0)
        valid_ratios = valid_counts / req_points
        current_window_scores = np.zeros(latest_40min_scores.shape[1], dtype=np.float32)
        valid_mask = valid_ratios >= config.VALID_DATA_RATIO
        if np.any(valid_mask):
            current_window_scores[valid_mask] = np.nanquantile(latest_40min_scores[:, valid_mask], 0.95, axis=0)
        current_window_scores = np.clip(current_window_scores, 0, 1)
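        # Aggregation above: a sensor's window score is the 95th percentile of
        # its per-point scores, clipped to [0, 1]. With 120 points per window,
        # an excursion has to persist through roughly the top 5% of the window
        # (about 2 minutes) before it can lift the aggregated score, which
        # keeps isolated spikes from raising alarms.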
        # --- Layers 2 & 3: trigger detection and root-cause tracing ---
        results = []
        active_triggers = []
        for t_name in config.TRIGGER_SENSORS:
            if t_name in self.causal_graph['sensor_to_idx']:
                idx = self.causal_graph['sensor_to_idx'][t_name]
                score = current_window_scores[idx]
                if score > config.TRIGGER_SCORE_THRESH:
                    active_triggers.append((t_name, idx, score))
        if not active_triggers:
            return {
                "status": "normal",
                "message": "System operating normally",
            }
        # Build a temporary environment for tracing
        env_scores = current_window_scores.reshape(1, -1)
        temp_env = CausalTracingEnv(self.causal_graph, env_scores, self.threshold_df, self.trainer.expert_knowledge)
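        # The environment wraps the live scores as a single window (index 0),
        # so the trained policy can be rolled out on current data exactly as it
        # was on training windows.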
        for t_name, t_idx, t_score in active_triggers:
            state_data = temp_env.reset(force_window_idx=0, force_trigger=t_name)
            path_idxs = [t_idx]
            done = False
            while not done:
                s_int = state_data[0].unsqueeze(0)
                s_float = state_data[1].unsqueeze(0)
                valid = temp_env.get_valid_actions(path_idxs[-1])
                if len(valid) == 0:
                    break
                with torch.no_grad():
                    logits, _ = self.trainer.model(s_int, s_float)
                mask = torch.full_like(logits, -1e9)
                mask[0, valid] = 0
                act = torch.argmax(logits + mask, dim=1).item()
                state_data, _, done, _ = temp_env.step(act)
                path_idxs.append(act)
                if len(path_idxs) >= config.MAX_PATH_LENGTH:
                    done = True
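            # The rollout above is greedy decoding with invalid-action masking:
            # adding -1e9 to the logits of nodes that are not valid successors
            # drives their probability to effectively zero, so argmax can only
            # step along edges of the causal graph.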
            path_names = [self.trainer.idx_to_sensor[i] for i in path_idxs]
            # --- Steps 1 & 2: path validity and dual deviation quantification ---
            path_details = []
            abnormal_other_nodes_count = 0
            for node_idx, node_name in zip(path_idxs, path_names):
                node_score = current_window_scores[node_idx]
                is_node_abnormal = node_score > config.WINDOW_ANOMALY_THRESHOLD
                # Count every node other than the trigger itself that meets the
                # anomaly criterion
                if node_idx != t_idx and is_node_abnormal:
                    abnormal_other_nodes_count += 1
                # 1. Current reference value
                raw_val = real_time_values.get(node_name, np.nan)
                # 2. Recompute the node's dynamic baseline and bounds on the fly
                dyn_lower, dyn_upper, dyn_med = np.nan, np.nan, np.nan
                if node_name in df_resampled.columns:
                    vals = df_resampled[node_name].astype(np.float32)
                    # Replay the rolling window to reconstruct the historical MAD state
                    r_med = vals.rolling(window=config.MAD_HISTORY_WINDOW, min_periods=1).median()
                    r_mad = (vals - r_med).abs().rolling(window=config.MAD_HISTORY_WINDOW, min_periods=1).median()
                    # Use the values at the final timestamp as the baseline
                    last_med = r_med.iloc[-1]
                    last_mad = r_mad.iloc[-1]
                    dyn_lower = last_med - config.MAD_THRESHOLD * last_mad
                    dyn_upper = last_med + config.MAD_THRESHOLD * last_mad
                    dyn_med = last_med
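                    # Dynamic bounds follow the usual MAD rule:
                    #     bound = rolling_median ± MAD_THRESHOLD * rolling_MAD
                    # where MAD = median(|x - rolling_median|) over the same
                    # window, a robust substitute for the standard deviation.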
                # 3. Build display strings for the physical range and the dynamic range
                phy_str = "undefined"
                dyn_str = "undefined"
                node_name_zh = "unknown"
                # Look up the Chinese display name
                if node_name in self.processor.threshold_dict:
                    info = self.processor.threshold_dict[node_name]
                    node_name_zh = info.get('Name', 'unknown')
                if not pd.isna(raw_val) and node_name in self.processor.threshold_dict:
                    info = self.processor.threshold_dict[node_name]
                    g_min, g_max = info['Good_min'], info['Good_max']
                    # -- Physical-range check --
                    if raw_val > g_max and g_max != np.inf:
                        pct = (raw_val - g_max) / (abs(g_max) + 1e-5) * 100
                        phy_str = f"high by {pct:.1f}% (physical upper limit: {g_max})"
                    elif raw_val < g_min and g_min != -np.inf:
                        pct = (g_min - raw_val) / (abs(g_min) + 1e-5) * 100
                        phy_str = f"low by {pct:.1f}% (physical lower limit: {g_min})"
                    else:
                        phy_str = f"normal (physical range: [{g_min}, {g_max}])"
                    # -- Dynamic-range check --
                    if not pd.isna(dyn_lower) and not pd.isna(dyn_upper):
                        if raw_val > dyn_upper:
                            dyn_str = f"abnormal surge (recent baseline: {dyn_med:.2f}, dynamic upper bound: {dyn_upper:.2f})"
                        elif raw_val < dyn_lower:
                            dyn_str = f"abnormal drop (recent baseline: {dyn_med:.2f}, dynamic lower bound: {dyn_lower:.2f})"
                        else:
                            dyn_str = f"stable (recent baseline: {dyn_med:.2f}, dynamic range: [{dyn_lower:.2f}, {dyn_upper:.2f}])"
                # Assemble a clear structured description
                dev_str = f"current value: {raw_val:.2f} | physical condition: {phy_str} | dynamic trend: {dyn_str}"
                path_details.append({
                    "node": node_name,
                    "name": node_name_zh,
                    "anomaly_score": round(float(node_score), 4),
                    "is_abnormal": bool(is_node_abnormal),
                    "deviation": dev_str
                })
            # Raise a tracing alarm as soon as at least one node besides the
            # trigger is abnormal
            if abnormal_other_nodes_count >= 1:
                results.append({
                    "trigger": t_name,
                    "path": " -> ".join(path_names),
                    "root_cause": path_names[-1],
                    "details": path_details
                })
        # No triggered path had >= 1 extra abnormal node
        if not results:
            return {
                "status": "normal",
                "message": "System operating normally",
            }
        return {
            "status": "abnormal",
            "results": results
        }
if __name__ == "__main__":
    # Simulate the external caller: always pass a fixed 2-hour chunk of data
    test_file = os.path.join(config.DATASET_SENSOR_DIR, f"{config.SENSOR_FILE_PREFIX}1.csv")
    if os.path.exists(test_file):
        print(">>> Starting online diagnosis engine test (mode: read 2 hours, inspect the last 40 minutes)...")
        diagnoser = WaterPlantDiagnoser()
        # Assume the raw data is sampled once every 4 seconds:
        # 2 hours = 7200 seconds = 1800 rows of raw data
        CHUNK_SIZE = 1800
        df_full = pd.read_csv(test_file, low_memory=False)
        if len(df_full) >= CHUNK_SIZE * 3:
            # Simulation 1: pass hours 0-2
            df_sim1 = df_full.iloc[0:CHUNK_SIZE]
            print("\n[Test 1] Passing hours 0-2:")
            print(json.dumps(diagnoser.api_predict(df_sim1), ensure_ascii=False, indent=2))
            # Simulation 2: pass hours 2-4
            df_sim2 = df_full.iloc[CHUNK_SIZE : CHUNK_SIZE * 2]
            print("\n[Test 2] Passing hours 2-4:")
            print(json.dumps(diagnoser.api_predict(df_sim2), ensure_ascii=False, indent=2))
            # Simulation 3: pass the last 2 hours
            df_sim3 = df_full.iloc[-CHUNK_SIZE:]
            print("\n[Test 3] Passing the last 2 hours:")
            print(json.dumps(diagnoser.api_predict(df_sim3), ensure_ascii=False, indent=2))
        else:
            print("\nNot enough data for three 2-hour simulations; running a single full-data pass instead:")
            print(json.dumps(diagnoser.api_predict(df_full), ensure_ascii=False, indent=2))
    else:
        print(f"Test file {test_file} does not exist; cannot run the local test.")