# -*- coding: utf-8 -*-
"""test.py: online diagnosis interface."""
import os
import pandas as pd
import numpy as np
import torch
import argparse
from config import config
class WaterPlantDiagnoser:
    """Online diagnosis facade for a water plant.

    Wires together the project's three layers — anomaly scoring
    (``DataAnomalyProcessor``), causal graph construction
    (``CausalStructureBuilder``) and RL-based root-cause tracing
    (``RLTrainer`` / ``CausalTracingEnv``) — behind a single
    ``api_predict`` entry point.
    """

    def __init__(self):
        # Function-scope imports: keeps this module importable even when the
        # heavy project dependencies are only needed at construction time.
        from data_processing import DataAnomalyProcessor
        from causal_structure import CausalStructureBuilder
        from rl_tracing import RLTrainer, CausalTracingEnv
        self.processor = DataAnomalyProcessor()
        self.sensor_list = self.processor.sensor_list
        self.threshold_df = self.processor.threshold_df
        self.builder = CausalStructureBuilder(self.threshold_df)
        self.causal_graph = self.builder.build()
        # Placeholder score matrix (1 window x n_sensors) just to construct the
        # trainer; real per-request scores are computed in api_predict.
        dummy_scores = np.zeros((1, len(self.sensor_list)), dtype=np.float32)
        self.trainer = RLTrainer(self.causal_graph, dummy_scores, self.threshold_df)
        self.CausalTracingEnv = CausalTracingEnv  # cache the class for later use
        # Use the path already assembled by the config module.
        if not os.path.exists(config.MODEL_FILE_PATH):
            print(f"[Warning] 未找到模型文件: {config.MODEL_FILE_PATH}。请先执行 main.py 进行训练。")
        else:
            # weights_only=True: load only the state dict, never pickled code.
            state_dict = torch.load(config.MODEL_FILE_PATH, map_location=torch.device('cpu'), weights_only=True)
            self.trainer.model.load_state_dict(state_dict)
            print(f"[*] 成功加载模型: {config.MODEL_FILE_PATH}")
        # Inference mode regardless of whether weights were found.
        self.trainer.model.eval()

    def _preprocess_dataframe(self, df_input):
        """Strictly reproduce the training-time data-processing pipeline.

        :param df_input: raw DataFrame whose FIRST column is the timestamp.
        :return: float32 DataFrame of known sensor columns, time-indexed and
                 resampled to ``config.TARGET_SAMPLE_INTERVAL`` seconds,
                 or ``None`` on any failure.
        """
        try:
            df = df_input.copy()
            # 1. Parse timestamps; rows that fail to parse are dropped.
            time_col = df.columns[0]
            df[time_col] = pd.to_datetime(df[time_col], format='mixed', errors='coerce')
            df = df.dropna(subset=[time_col]).set_index(time_col)
            # 2. Keep only columns that are known sensors.
            valid_cols = [c for c in df.columns if c in self.sensor_list]
            if not valid_cols: return None
            df_valid = df[valid_cols]
            # 3. Down-sample (4s -> 20s) by mean.
            df_resampled = df_valid.resample(f"{config.TARGET_SAMPLE_INTERVAL}s").mean()
            df_resampled = df_resampled.astype(np.float32)
            return df_resampled
        except Exception as e:
            print(f"数据预处理错误: {e}")
            return None

    def api_predict(self, df_raw, mode='auto', manual_trigger=None):
        """Public entry point: feed ~2 hours of raw data, diagnose the newest 40 min.

        :param df_raw: raw DataFrame (timestamp first column, sensor columns).
        :param mode: 'auto' = scan configured trigger sensors against a score
                     threshold; 'manual' = force-trace one user-specified trigger.
        :param manual_trigger: sensor ID to trace when mode == 'manual'.
        :return: dict with a "status" key ('error' / 'warning' / 'normal' /
                 'abnormal' / 'manual_traced') and mode-dependent payload.
        """
        # --- Layer 1: preprocessing ---
        df_resampled = self._preprocess_dataframe(df_raw)
        if df_resampled is None or df_resampled.empty:
            # NOTE(review): this branch uses key "msg" while every other error
            # path uses "message" — looks inconsistent; confirm with callers.
            return {"status": "error", "msg": "数据预处理失败或数据为空"}
        # --- Layer 1: anomaly scoring (combined absolute + MAD score) ---
        scores_dict = {}
        for sensor in self.sensor_list:
            if sensor in df_resampled.columns:
                scores_dict[sensor] = self.processor._calculate_point_score_vectorized(
                    df_resampled[sensor], sensor
                )
            else:
                # Strict training alignment: missing sensors become np.nan series.
                scores_dict[sensor] = pd.Series(np.nan, index=df_resampled.index, dtype=np.float32)
        # Reorder columns to the canonical sensor order.
        point_score_df = pd.DataFrame(scores_dict)[self.sensor_list]
        # --- Layer 1: take the newest 40 minutes for window aggregation and
        # raw-value means ---
        req_points = config.POINTS_PER_WINDOW  # 40 min = 120 points at 20 s
        total_points = len(point_score_df)
        if total_points < req_points:
            # The *20 below assumes the 20 s resample interval.
            return {
                "status": "warning",
                "message": f"传入数据不足40分钟 (当前: {total_points*20}s, 需要: {req_points*20}s)"
            }
        # Slices: last 40 min of scores (anomaly aggregation) and of raw data
        # (deviation percentages for display).
        latest_40min_scores = point_score_df.iloc[-req_points:].values
        latest_40min_raw = df_resampled.iloc[-req_points:]
        # Per-sensor mean over the window, shown as the reference value.
        real_time_values = latest_40min_raw.mean(skipna=True)
        # Valid-data ratio per sensor and 95th-percentile window score.
        valid_counts = np.sum(~np.isnan(latest_40min_scores), axis=0)
        valid_ratios = valid_counts / req_points
        current_window_scores = np.zeros(latest_40min_scores.shape[1], dtype=np.float32)
        valid_mask = valid_ratios >= config.VALID_DATA_RATIO
        if np.any(valid_mask):
            # Sensors below the valid-ratio floor keep a score of 0.
            current_window_scores[valid_mask] = np.nanquantile(latest_40min_scores[:, valid_mask], 0.95, axis=0)
        current_window_scores = np.clip(current_window_scores, 0, 1)
        # --- Layers 2 & 3: trigger detection and root-cause tracing ---
        results = []
        active_triggers = []
        # Build the trigger list according to the requested mode.
        if mode == 'auto':
            for t_name in config.TRIGGER_SENSORS:
                if t_name in self.causal_graph['sensor_to_idx']:
                    idx = self.causal_graph['sensor_to_idx'][t_name]
                    score = current_window_scores[idx]
                    if score > config.TRIGGER_SCORE_THRESH:
                        active_triggers.append((t_name, idx, score))
        elif mode == 'manual':
            if not manual_trigger:
                return {"status": "error", "message": "手动模式下必须提供 manual_trigger 参数"}
            if manual_trigger not in self.causal_graph['sensor_to_idx']:
                return {"status": "error", "message": f"指定的诱发变量 {manual_trigger} 不在设备/传感器列表中"}
            # Manual mode: force the trigger in, skipping the score threshold.
            idx = self.causal_graph['sensor_to_idx'][manual_trigger]
            score = current_window_scores[idx]
            active_triggers.append((manual_trigger, idx, score))
        else:
            return {"status": "error", "message": f"未知的诊断模式: {mode}"}
        if not active_triggers:
            return {
                "status": "normal",
                "message": "系统运行正常",
            }
        # Build a throwaway environment over the single aggregated window.
        env_scores = current_window_scores.reshape(1, -1)
        temp_env = self.CausalTracingEnv(self.causal_graph, env_scores, self.threshold_df, self.trainer.expert_knowledge)
        for t_name, t_idx, t_score in active_triggers:
            state_data = temp_env.reset(force_window_idx=0, force_trigger=t_name)
            path_idxs = [t_idx]
            done = False
            # Greedy rollout: mask invalid actions with -1e9 and take argmax.
            while not done:
                s_int = state_data[0].unsqueeze(0)
                s_float = state_data[1].unsqueeze(0)
                valid = temp_env.get_valid_actions(path_idxs[-1])
                if len(valid) == 0: break
                logits, _ = self.trainer.model(s_int, s_float)
                mask = torch.full_like(logits, -1e9)
                mask[0, valid] = 0
                act = torch.argmax(logits + mask, dim=1).item()
                state_data, _, done, _ = temp_env.step(act)
                path_idxs.append(act)
                if len(path_idxs) >= config.MAX_PATH_LENGTH: done = True
            path_names = [self.trainer.idx_to_sensor[i] for i in path_idxs]
            # --- 1 & 2: path validity check and dual deviation quantification ---
            path_details = []
            abnormal_other_nodes_count = 0
            for node_idx, node_name in zip(path_idxs, path_names):
                node_score = current_window_scores[node_idx]
                is_node_abnormal = node_score > config.WINDOW_ANOMALY_THRESHOLD
                # Count abnormal nodes other than the trigger itself.
                if node_idx != t_idx and is_node_abnormal:
                    abnormal_other_nodes_count += 1
                # 1. Current reference value for this node.
                raw_val = real_time_values.get(node_name, np.nan)
                # 2. Recompute the node's dynamic baseline and bounds by
                #    replaying the rolling-MAD history on the resampled data.
                dyn_lower, dyn_upper, dyn_med = np.nan, np.nan, np.nan
                if node_name in df_resampled.columns:
                    vals = df_resampled[node_name].astype(np.float32)
                    r_med = vals.rolling(window=config.MAD_HISTORY_WINDOW, min_periods=1).median()
                    r_mad = (vals - r_med).abs().rolling(window=config.MAD_HISTORY_WINDOW, min_periods=1).median()
                    # Take the values at the end of the series as the baseline.
                    last_med = r_med.iloc[-1]
                    last_mad = r_mad.iloc[-1]
                    dyn_lower = last_med - config.MAD_THRESHOLD * last_mad
                    dyn_upper = last_med + config.MAD_THRESHOLD * last_mad
                    dyn_med = last_med
                # 3. Build the display strings for the physical range and the
                #    dynamic (MAD) range separately.
                phy_str = "未定义"
                dyn_str = "未定义"
                node_name_zh = "未知"
                # Human-readable (Chinese) name, when configured.
                if node_name in self.processor.threshold_dict:
                    info = self.processor.threshold_dict[node_name]
                    node_name_zh = info.get('Name', '未知')
                if not pd.isna(raw_val) and node_name in self.processor.threshold_dict:
                    info = self.processor.threshold_dict[node_name]
                    g_min, g_max = info['Good_min'], info['Good_max']
                    # -- physical-range verdict --
                    if raw_val > g_max and g_max != np.inf:
                        pct = (raw_val - g_max) / (abs(g_max) + 1e-5) * 100
                        phy_str = f"偏高 {pct:.1f}% (物理允许上限: {g_max})"
                    elif raw_val < g_min and g_min != -np.inf:
                        pct = (g_min - raw_val) / (abs(g_min) + 1e-5) * 100
                        phy_str = f"偏低 {pct:.1f}% (物理允许下限: {g_min})"
                    else:
                        phy_str = f"正常 (物理范围: [{g_min}, {g_max}])"
                    # -- dynamic-range verdict --
                    # NOTE(review): nesting inferred — assumed to require a
                    # non-NaN raw_val (same as the physical verdict); confirm.
                    if not pd.isna(dyn_lower) and not pd.isna(dyn_upper):
                        if raw_val > dyn_upper:
                            dyn_str = f"异常突增 (近期基线: {dyn_med:.2f}, 动态上限: {dyn_upper:.2f})"
                        elif raw_val < dyn_lower:
                            dyn_str = f"异常突降 (近期基线: {dyn_med:.2f}, 动态下限: {dyn_lower:.2f})"
                        else:
                            dyn_str = f"平稳波动 (近期基线: {dyn_med:.2f}, 动态区间: [{dyn_lower:.2f}, {dyn_upper:.2f}])"
                # Assemble a single structured display line.
                dev_str = f"当前值: {raw_val:.2f} | 物理工况: {phy_str} | 动态趋势: {dyn_str}"
                path_details.append({
                    "node": node_name,
                    "name": node_name_zh,
                    "anomaly_score": round(float(node_score), 4),
                    "is_abnormal": bool(is_node_abnormal),
                    "deviation": dev_str
                })
            # Decide whether to include this path in the response:
            # - auto mode: require >= 1 abnormal node besides the trigger;
            # - manual mode: always return the path for inspection.
            if mode == 'manual' or abnormal_other_nodes_count >= 1:
                results.append({
                    "trigger": t_name,
                    "path": " -> ".join(path_names),
                    "root_cause": path_names[-1],
                    "details": path_details
                })
        # All traced paths were filtered out (possible in auto mode only).
        if not results:
            return {
                "status": "normal",
                "message": "系统运行正常",
            }
        return {
            "status": "manual_traced" if mode == 'manual' else "abnormal",
            "results": results
        }
  222. if __name__ == "__main__":
  223. # 模拟外部调用机制:每次固定传 2 小时的数据
  224. parser = argparse.ArgumentParser(description="水厂在线诊断测试")
  225. parser.add_argument('-p', '--plant', type=str, required=True, help="测试的水厂名称")
  226. args = parser.parse_args()
  227. config.load(args.plant)
  228. diagnoser = WaterPlantDiagnoser()
  229. test_file = os.path.join(config.DATASET_SENSOR_DIR, f"{config.SENSOR_FILE_PREFIX}1.csv")
  230. if os.path.exists(test_file):
  231. print(">>> 正在启动在线诊断引擎测试 (模式:读2小时,查末尾40分钟)...")
  232. # 2小时 = 7200秒 = 1800 行原始数据
  233. CHUNK_SIZE = 1800
  234. df_full = pd.read_csv(test_file, low_memory=False)
  235. if len(df_full) >= CHUNK_SIZE * 3:
  236. import json
  237. # 模拟 1:传入 0~2 小时的数据(自动模式测试)
  238. df_sim1 = df_full.iloc[0 : CHUNK_SIZE]
  239. print("\n[测试 1] 传入 0~2 小时数据 (自动模式):")
  240. print(json.dumps(diagnoser.api_predict(df_sim1, mode='auto'), ensure_ascii=False, indent=2))
  241. # 模拟 2:传入 2~4 小时的数据(手动模式测试)
  242. df_sim2 = df_full.iloc[CHUNK_SIZE : CHUNK_SIZE * 2]
  243. print("\n[测试 2] 传入 2~4 小时数据 (手动模式,指定触发):")
  244. # 从 config 自动抓取一个诱发变量,或者固定想测试的,如 'UF1Per'
  245. test_trigger = config.TRIGGER_SENSORS[0] if len(config.TRIGGER_SENSORS) > 0 else "Unknown"
  246. print(json.dumps(diagnoser.api_predict(df_sim2, mode='manual', manual_trigger=test_trigger), ensure_ascii=False, indent=2))
  247. # 模拟 3:传入末尾 2 小时的数据(自动模式测试)
  248. df_sim3 = df_full.iloc[-CHUNK_SIZE:]
  249. print("\n[测试 3] 传入 4~6 小时数据 (自动模式):")
  250. print(json.dumps(diagnoser.api_predict(df_sim3, mode='auto'), ensure_ascii=False, indent=2))
  251. else:
  252. import json
  253. print("\n数据量不足以做三次两小时模拟,直接进行单次全量手动模拟:")
  254. test_trigger = config.TRIGGER_SENSORS[0] if len(config.TRIGGER_SENSORS) > 0 else "Unknown"
  255. print(json.dumps(diagnoser.api_predict(df_full, mode='manual', manual_trigger=test_trigger), ensure_ascii=False, indent=2))
  256. else:
  257. print(f"测试文件 {test_file} 不存在,无法执行本地测试。")