# -*- coding: utf-8 -*-
"""rl_tracing.py: 强化学习链路级异常溯源"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from config import config
# ----------------- 1. 环境 -----------------
  14. class CausalTracingEnv:
  15. def __init__(self, causal_graph, window_scores, threshold_df, expert_knowledge=None):
  16. self.sensor_list = causal_graph['sensor_list']
  17. self.map = causal_graph['sensor_to_idx']
  18. self.idx_to_sensor = {v: k for k, v in self.map.items()}
  19. self.adj = causal_graph['adj_matrix']
  20. self.scores = window_scores
  21. self.expert_knowledge = expert_knowledge if expert_knowledge else {}
  22. self.num_sensors = len(self.sensor_list)
  23. # 解析属性
  24. self.node_props = {}
  25. col_one_layer = self._find_col(threshold_df, config.KEYWORD_LAYER)
  26. col_device = self._find_col(threshold_df, config.KEYWORD_DEVICE)
  27. df_indexed = threshold_df.set_index('ID')
  28. dict_one = df_indexed[col_one_layer].to_dict() if col_one_layer else {}
  29. dict_dev = df_indexed[col_device].to_dict() if col_device else {}
  30. for name, idx in self.map.items():
  31. l_val = dict_one.get(name, -1)
  32. try: l_val = int(l_val)
  33. except: l_val = 0
  34. d_val = dict_dev.get(name, None)
  35. d_val = str(d_val).strip() if pd.notna(d_val) and str(d_val).strip() != '' else None
  36. self.node_props[idx] = {'one_layer': l_val, 'device': d_val}
  37. self.current_window_idx = 0
  38. self.current_node_idx = 0
  39. self.prev_node_idx = 0
  40. self.trigger_node_idx = 0
  41. self.path = []
  42. self.current_expert_paths = []
  43. self.target_roots = set()
  44. def _find_col(self, df, keyword):
  45. if keyword in df.columns: return keyword
  46. for c in df.columns:
  47. if c.lower() == keyword.lower(): return c
  48. return None
  49. def reset(self, force_window_idx=None, force_trigger=None):
  50. if force_window_idx is not None:
  51. self.current_window_idx = force_window_idx
  52. t_name = force_trigger
  53. else:
  54. found = False
  55. for _ in range(100):
  56. w_idx = np.random.randint(len(self.scores))
  57. win_scores = self.scores[w_idx]
  58. candidates = []
  59. for t_name in config.TRIGGER_SENSORS:
  60. if t_name in self.map:
  61. idx = self.map[t_name]
  62. if win_scores[idx] > config.TRIGGER_SCORE_THRESH:
  63. candidates.append(t_name)
  64. if candidates:
  65. self.current_window_idx = w_idx
  66. t_name = np.random.choice(candidates)
  67. found = True
  68. break
  69. if not found:
  70. self.current_window_idx = np.random.randint(len(self.scores))
  71. t_name = list(self.map.keys())[0]
  72. self.current_node_idx = self.map.get(t_name, 0)
  73. self.trigger_node_idx = self.current_node_idx
  74. self.prev_node_idx = self.current_node_idx
  75. self.path = [self.current_node_idx]
  76. self.target_roots = set()
  77. self.current_expert_paths = []
  78. if self.current_node_idx in self.expert_knowledge:
  79. entry = self.expert_knowledge[self.current_node_idx]
  80. self.target_roots = entry['roots']
  81. self.current_expert_paths = entry['paths']
  82. return self._get_state()
  83. def _get_state(self):
  84. curr_score = self.scores[self.current_window_idx, self.current_node_idx]
  85. prev_score = self.scores[self.current_window_idx, self.prev_node_idx]
  86. curr_layer = self.node_props[self.current_node_idx]['one_layer'] / 20.0
  87. return (
  88. torch.LongTensor([self.current_node_idx, self.prev_node_idx, self.trigger_node_idx]),
  89. torch.FloatTensor([curr_score, prev_score, curr_layer])
  90. )
  91. def get_valid_actions(self, curr_idx):
  92. neighbors = np.where(self.adj[curr_idx] == 1)[0]
  93. curr_props = self.node_props[curr_idx]
  94. curr_l, curr_d = curr_props['one_layer'], curr_props['device']
  95. valid = []
  96. for n in neighbors:
  97. if n in self.path: continue
  98. tgt_props = self.node_props[n]
  99. tgt_l, tgt_d = tgt_props['one_layer'], tgt_props['device']
  100. if curr_l != 0 and tgt_l != 0:
  101. if not ((tgt_l == curr_l) or (tgt_l == curr_l - 1)): continue
  102. if (curr_d is not None) and (tgt_d is not None):
  103. if curr_d != tgt_d: continue
  104. valid.append(n)
  105. return np.array(valid)
  106. def step(self, action_idx):
  107. prev = self.current_node_idx
  108. self.prev_node_idx = prev
  109. self.current_node_idx = action_idx
  110. self.path.append(action_idx)
  111. score_curr = self.scores[self.current_window_idx, self.current_node_idx]
  112. reward = 0.0
  113. done = False
  114. # 奖励机制 (Imitation > Root > Gradient)
  115. in_expert_nodes = False
  116. for e_path in self.current_expert_paths:
  117. if action_idx in e_path:
  118. in_expert_nodes = True
  119. break
  120. if in_expert_nodes: reward += 2.0
  121. else: reward -= 0.2
  122. if action_idx in self.target_roots:
  123. reward += 10.0
  124. done = True
  125. score_prev = self.scores[self.current_window_idx, prev]
  126. diff = score_curr - score_prev
  127. if diff > 0: reward += diff * 3.0
  128. else: reward -= 0.5
  129. if len(self.path) >= config.MAX_PATH_LENGTH:
  130. done = True
  131. if action_idx not in self.target_roots: reward -= 5.0
  132. if score_curr < 0.15 and len(self.path) > 3:
  133. done = True
  134. reward -= 2.0
  135. return self._get_state(), reward, done, {}
# ----------------- 2. 网络 -----------------
  137. class TargetDrivenActorCritic(nn.Module):
  138. def __init__(self, num_sensors, embedding_dim=64, hidden_dim=256):
  139. super().__init__()
  140. self.node_emb = nn.Embedding(num_sensors, embedding_dim)
  141. input_dim = (embedding_dim * 3) + 3
  142. self.shared_net = nn.Sequential(
  143. nn.Linear(input_dim, hidden_dim),
  144. nn.ReLU(),
  145. nn.LayerNorm(hidden_dim),
  146. nn.Linear(hidden_dim, hidden_dim),
  147. nn.ReLU()
  148. )
  149. self.actor = nn.Linear(hidden_dim, num_sensors)
  150. self.critic = nn.Linear(hidden_dim, 1)
  151. def forward(self, int_data, float_data):
  152. curr_emb = self.node_emb(int_data[:, 0])
  153. prev_emb = self.node_emb(int_data[:, 1])
  154. trig_emb = self.node_emb(int_data[:, 2])
  155. x = torch.cat([curr_emb, prev_emb, trig_emb, float_data], dim=1)
  156. feat = self.shared_net(x)
  157. return self.actor(feat), self.critic(feat)
# ----------------- 3. 训练器 -----------------
  159. class RLTrainer:
  160. def __init__(self, causal_graph, train_scores, threshold_df):
  161. self.sensor_map = causal_graph['sensor_to_idx']
  162. self.idx_to_sensor = {v: k for k, v in self.sensor_map.items()}
  163. self.threshold_df = threshold_df
  164. self.causal_graph = causal_graph
  165. self.expert_knowledge, self.bc_samples, _ = self._load_expert_data()
  166. self.env = CausalTracingEnv(causal_graph, train_scores, threshold_df, self.expert_knowledge)
  167. self.model = TargetDrivenActorCritic(self.env.num_sensors, config.EMBEDDING_DIM, config.HIDDEN_DIM)
  168. self.optimizer = optim.Adam(self.model.parameters(), lr=config.PPO_LR)
  169. def _load_expert_data(self):
  170. path = os.path.join(config.BASE_DIR, config.ABNORMAL_LINK_FILENAME)
  171. kb_data = {}
  172. bc_data = []
  173. if not os.path.exists(path): return kb_data, bc_data, None
  174. df = pd.read_excel(path)
  175. for _, row in df.iterrows():
  176. link = str(row.get('Link Path', ''))
  177. if not link: continue
  178. nodes_str = [n.strip() for n in link.replace('→', '->').split('->')]
  179. path_nodes = nodes_str[::-1]
  180. ids = []
  181. valid = True
  182. for n in path_nodes:
  183. if n in self.sensor_map: ids.append(self.sensor_map[n])
  184. else: valid = False; break
  185. if not valid or len(ids)<2: continue
  186. trigger_id = ids[0]
  187. root_id = ids[-1]
  188. if trigger_id not in kb_data:
  189. kb_data[trigger_id] = {'paths': [], 'roots': set(), 'logic': row.get('Process Logic Basis', '')}
  190. kb_data[trigger_id]['paths'].append(ids)
  191. kb_data[trigger_id]['roots'].add(root_id)
  192. for i in range(len(ids) - 1):
  193. curr = ids[i]
  194. prev = ids[max(0, i-1)]
  195. nxt = ids[i+1]
  196. bc_data.append(((curr, prev, trigger_id), nxt))
  197. return kb_data, bc_data, df
  198. def pretrain_bc(self):
  199. if not self.bc_samples: return
  200. print(f"\n>>> [Step 3.1] 启动BC预训练 ({config.BC_EPOCHS}轮)...")
  201. states_int = torch.LongTensor([list(s) for s, a in self.bc_samples])
  202. actions = torch.LongTensor([a for s, a in self.bc_samples])
  203. states_float = torch.zeros((len(states_int), 3))
  204. states_float[:, 0] = 0.9
  205. states_float[:, 1] = 0.8
  206. loss_fn = nn.CrossEntropyLoss()
  207. pbar = tqdm(range(config.BC_EPOCHS), desc="BC Training")
  208. for epoch in pbar:
  209. logits, _ = self.model(states_int, states_float)
  210. loss = loss_fn(logits, actions)
  211. self.optimizer.zero_grad()
  212. loss.backward()
  213. self.optimizer.step()
  214. if epoch%100==0: pbar.set_postfix({'Loss': f"{loss.item():.4f}"})
  215. def train_ppo(self):
  216. print(f"\n>>> [Step 3.2] 启动PPO训练 ({config.RL_EPISODES}轮)...")
  217. pbar = tqdm(range(config.RL_EPISODES), desc="PPO Training")
  218. rewards_hist = []
  219. for _ in pbar:
  220. state_data = self.env.reset()
  221. done = False
  222. ep_r = 0
  223. b_int, b_float, b_act, b_lp, b_rew, b_mask = [], [], [], [], [], []
  224. while not done:
  225. s_int = state_data[0].unsqueeze(0)
  226. s_float = state_data[1].unsqueeze(0)
  227. valid = self.env.get_valid_actions(s_int[0, 0].item())
  228. if len(valid) == 0: break
  229. logits, _ = self.model(s_int, s_float)
  230. mask = torch.full_like(logits, -1e9)
  231. mask[0, valid] = 0
  232. dist = Categorical(F.softmax(logits+mask, dim=-1))
  233. action = dist.sample()
  234. next_s, r, done, _ = self.env.step(action.item())
  235. b_int.append(s_int); b_float.append(s_float)
  236. b_act.append(action); b_lp.append(dist.log_prob(action))
  237. b_rew.append(r); b_mask.append(1-done)
  238. state_data = next_s
  239. ep_r += r
  240. if len(b_rew) > 1:
  241. self._update_ppo(b_int, b_float, b_act, b_lp, b_rew, b_mask)
  242. rewards_hist.append(ep_r)
  243. if len(rewards_hist)>50: rewards_hist.pop(0)
  244. pbar.set_postfix({'AvgR': f"{np.mean(rewards_hist):.2f}"})
  245. def _update_ppo(self, b_int, b_float, b_act, b_lp, b_rew, b_mask):
  246. returns = []
  247. R = 0
  248. for r, m in zip(reversed(b_rew), reversed(b_mask)):
  249. R = r + config.PPO_GAMMA * R * m
  250. returns.insert(0, R)
  251. returns = torch.tensor(returns)
  252. if returns.numel() > 1 and returns.std() > 1e-5:
  253. returns = (returns - returns.mean()) / (returns.std() + 1e-5)
  254. elif returns.numel() > 1:
  255. returns = returns - returns.mean()
  256. s_int = torch.cat(b_int)
  257. s_float = torch.cat(b_float)
  258. act = torch.stack(b_act)
  259. old_lp = torch.stack(b_lp).detach()
  260. for _ in range(config.PPO_K_EPOCHS):
  261. logits, vals = self.model(s_int, s_float)
  262. dist = Categorical(logits=logits)
  263. new_lp = dist.log_prob(act)
  264. ratio = torch.exp(new_lp - old_lp)
  265. surr1 = ratio * returns
  266. surr2 = torch.clamp(ratio, 1-config.PPO_EPS_CLIP, 1+config.PPO_EPS_CLIP) * returns
  267. v_pred = vals.squeeze()
  268. if v_pred.shape != returns.shape:
  269. v_pred = v_pred.view(-1)
  270. returns = returns.view(-1)
  271. loss = -torch.min(surr1, surr2).mean() + 0.5 * F.mse_loss(v_pred, returns)
  272. self.optimizer.zero_grad()
  273. loss.backward()
  274. self.optimizer.step()
  275. def evaluate(self, test_scores):
  276. print("\n>>> [Step 4] 评估测试集...")
  277. self.model.eval()
  278. results = []
  279. cnt_detected = 0
  280. cnt_kb_covered = 0
  281. cnt_path_match = 0
  282. cnt_root_match = 0
  283. cnt_new = 0
  284. env = CausalTracingEnv(self.causal_graph, test_scores, self.threshold_df, self.expert_knowledge)
  285. for win_idx in range(len(test_scores)):
  286. scores = test_scores[win_idx]
  287. active = []
  288. for t_name in config.TRIGGER_SENSORS:
  289. if t_name in self.sensor_map:
  290. idx = self.sensor_map[t_name]
  291. if scores[idx] > config.TRIGGER_SCORE_THRESH:
  292. active.append((t_name, idx))
  293. for t_name, t_idx in active:
  294. cnt_detected += 1
  295. state_data = env.reset(force_window_idx=win_idx, force_trigger=t_name)
  296. path_idxs = [t_idx]
  297. done = False
  298. while not done:
  299. s_int = state_data[0].unsqueeze(0)
  300. s_float = state_data[1].unsqueeze(0)
  301. valid = env.get_valid_actions(path_idxs[-1])
  302. if len(valid) == 0: break
  303. logits, _ = self.model(s_int, s_float)
  304. mask = torch.full_like(logits, -1e9)
  305. mask[0, valid] = 0
  306. act = torch.argmax(logits + mask, dim=1).item()
  307. state_data, _, done, _ = env.step(act)
  308. path_idxs.append(act)
  309. if len(path_idxs) >= config.MAX_PATH_LENGTH: done = True
  310. path_names = [self.idx_to_sensor[i] for i in path_idxs]
  311. root = path_names[-1]
  312. root_score = scores[self.sensor_map[root]]
  313. match_status = "未定义"
  314. logic = ""
  315. if t_idx in self.expert_knowledge:
  316. cnt_kb_covered += 1
  317. entry = self.expert_knowledge[t_idx]
  318. logic = entry.get('logic', '')
  319. real_roots = [self.idx_to_sensor[r] for r in entry['roots']]
  320. rm = False
  321. for p_node in path_names:
  322. if p_node in real_roots:
  323. rm = True
  324. break
  325. pm = False
  326. path_set = set(path_idxs)
  327. for exp_p in entry['paths']:
  328. exp_set = set(exp_p)
  329. intersection = len(path_set.intersection(exp_set))
  330. union = len(path_set.union(exp_set))
  331. if union > 0 and (intersection / union) >= 0.6:
  332. pm = True
  333. break
  334. if pm:
  335. match_status = "路径吻合"
  336. cnt_path_match += 1
  337. cnt_root_match += 1
  338. elif rm:
  339. match_status = "仅根因吻合"
  340. cnt_root_match += 1
  341. else:
  342. match_status = "不吻合"
  343. else:
  344. match_status = "新链路"
  345. cnt_new += 1
  346. results.append({
  347. "窗口ID": win_idx,
  348. "诱发变量": t_name,
  349. "溯源路径": "->".join(path_names),
  350. "根因变量": root,
  351. "根因异常分": f"{root_score:.3f}",
  352. "是否知识库": "是" if t_idx in self.expert_knowledge else "否",
  353. "匹配情况": match_status,
  354. "机理描述": logic
  355. })
  356. denom = max(cnt_kb_covered, 1)
  357. summary = [
  358. {"指标": "检测到的总异常样本数", "数值": cnt_detected},
  359. {"指标": "知识库覆盖的样本数", "数值": cnt_kb_covered},
  360. {"指标": "异常链路准确率", "数值": f"{cnt_path_match/denom:.2%}"},
  361. {"指标": "根因准确率", "数值": f"{cnt_root_match/denom:.2%}"},
  362. {"指标": "新发现异常模式数", "数值": cnt_new}
  363. ]
  364. save_path = os.path.join(config.RESULT_SAVE_DIR, config.TEST_RESULT_FILENAME)
  365. with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
  366. pd.DataFrame(summary).to_excel(writer, sheet_name='Sheet1_概览指标', index=False)
  367. pd.DataFrame(results).to_excel(writer, sheet_name='Sheet2_测试集详情', index=False)
  368. print("\n" + "="*50)
  369. print(pd.DataFrame(summary).to_string(index=False))
  370. print(f"\n文件已保存: {save_path}")
  371. print("="*50)
  372. def save_model(self):
  373. path = os.path.join(config.MODEL_SAVE_DIR, "ppo_tracing_model.pth")
  374. torch.save(self.model.state_dict(), path)