# -*- coding: utf-8 -*-
"""rl_tracing.py: Reinforcement-learning-based link-level anomaly root-cause tracing."""
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from tqdm import tqdm

from config import config


# ----------------- 1. Environment -----------------
class CausalTracingEnv:
    def __init__(self, causal_graph, window_scores, threshold_df, expert_knowledge=None):
        self.sensor_list = causal_graph['sensor_list']
        self.map = causal_graph['sensor_to_idx']
        self.idx_to_sensor = {v: k for k, v in self.map.items()}
        self.adj = causal_graph['adj_matrix']
        self.scores = window_scores
        self.expert_knowledge = expert_knowledge if expert_knowledge else {}
        self.num_sensors = len(self.sensor_list)

        # Parse per-node attributes (layer level and owning device) from the threshold table.
        self.node_props = {}
        col_one_layer = self._find_col(threshold_df, config.KEYWORD_LAYER)
        col_device = self._find_col(threshold_df, config.KEYWORD_DEVICE)
        df_indexed = threshold_df.set_index('ID')
        dict_one = df_indexed[col_one_layer].to_dict() if col_one_layer else {}
        dict_dev = df_indexed[col_device].to_dict() if col_device else {}
        for name, idx in self.map.items():
            l_val = dict_one.get(name, -1)
            try:
                l_val = int(l_val)
            except (TypeError, ValueError):
                l_val = 0
            d_val = dict_dev.get(name, None)
            d_val = str(d_val).strip() if pd.notna(d_val) and str(d_val).strip() != '' else None
            self.node_props[idx] = {'one_layer': l_val, 'device': d_val}

        self.current_window_idx = 0
        self.current_node_idx = 0
        self.prev_node_idx = 0
        self.trigger_node_idx = 0
        self.path = []
        self.current_expert_paths = []
        self.target_roots = set()

    def _find_col(self, df, keyword):
        if keyword in df.columns:
            return keyword
        for c in df.columns:
            if c.lower() == keyword.lower():
                return c
        return None

    def reset(self, force_window_idx=None, force_trigger=None):
        if force_window_idx is not None:
            self.current_window_idx = force_window_idx
            t_name = force_trigger
        else:
            # Sample a window that contains at least one triggered sensor.
            found = False
            t_name = None
            for _ in range(100):
                w_idx = np.random.randint(len(self.scores))
                win_scores = self.scores[w_idx]
                candidates = []
                for cand in config.TRIGGER_SENSORS:
                    if cand in self.map and win_scores[self.map[cand]] > config.TRIGGER_SCORE_THRESH:
                        candidates.append(cand)
                if candidates:
                    self.current_window_idx = w_idx
                    t_name = np.random.choice(candidates)
                    found = True
                    break
            if not found:
                self.current_window_idx = np.random.randint(len(self.scores))
                t_name = list(self.map.keys())[0]

        self.current_node_idx = self.map.get(t_name, 0)
        self.trigger_node_idx = self.current_node_idx
        self.prev_node_idx = self.current_node_idx
        self.path = [self.current_node_idx]
        self.target_roots = set()
        self.current_expert_paths = []
        if self.current_node_idx in self.expert_knowledge:
            entry = self.expert_knowledge[self.current_node_idx]
            self.target_roots = entry['roots']
            self.current_expert_paths = entry['paths']
        return self._get_state()

    def _get_state(self):
        curr_score = self.scores[self.current_window_idx, self.current_node_idx]
        prev_score = self.scores[self.current_window_idx, self.prev_node_idx]
        curr_layer = self.node_props[self.current_node_idx]['one_layer'] / 20.0
        return (
            torch.LongTensor([self.current_node_idx, self.prev_node_idx, self.trigger_node_idx]),
            torch.FloatTensor([curr_score, prev_score, curr_layer])
        )

    def get_valid_actions(self, curr_idx):
        neighbors = np.where(self.adj[curr_idx] == 1)[0]
        curr_props = self.node_props[curr_idx]
        curr_l, curr_d = curr_props['one_layer'], curr_props['device']
        valid = []
        for n in neighbors:
            if n in self.path:
                continue  # no revisiting: traced paths stay simple
            tgt_props = self.node_props[n]
            tgt_l, tgt_d = tgt_props['one_layer'], tgt_props['device']
            # Only walk within the same layer or one layer upstream.
            if curr_l != 0 and tgt_l != 0:
                if not ((tgt_l == curr_l) or (tgt_l == curr_l - 1)):
                    continue
            # When both nodes declare a device, it must be the same one.
            if (curr_d is not None) and (tgt_d is not None):
                if curr_d != tgt_d:
                    continue
            valid.append(n)
        return np.array(valid)

    def step(self, action_idx):
        prev = self.current_node_idx
        self.prev_node_idx = prev
        self.current_node_idx = action_idx
        self.path.append(action_idx)
        score_curr = self.scores[self.current_window_idx, self.current_node_idx]
        reward = 0.0
        done = False

        # Reward scheme (imitation > root hit > score gradient).
        in_expert_nodes = any(action_idx in e_path for e_path in self.current_expert_paths)
        if in_expert_nodes:
            reward += 2.0
        else:
            reward -= 0.2
        if action_idx in self.target_roots:
            reward += 10.0
            done = True
        score_prev = self.scores[self.current_window_idx, prev]
        diff = score_curr - score_prev
        if diff > 0:
            reward += diff * 3.0  # encourage climbing the anomaly-score gradient
        else:
            reward -= 0.5
        if len(self.path) >= config.MAX_PATH_LENGTH:
            done = True
            if action_idx not in self.target_roots:
                reward -= 5.0  # budget exhausted without reaching a known root
        if score_curr < 0.15 and len(self.path) > 3:
            done = True
            reward -= 2.0  # wandered into a quiet (low-anomaly) region
        return self._get_state(), reward, done, {}
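
# --- Sketch: driving the environment by hand ---------------------------------
# A minimal smoke test, assuming a hypothetical 3-sensor graph A -> B -> C with
# layers (2, 2, 1) and a shared device. The graph, threshold table, and score
# matrix below are made up for illustration; only the config column keywords
# come from the real project. Not called anywhere; run it manually if needed.
def _demo_env_rollout():
    graph = {
        'sensor_list': ['A', 'B', 'C'],
        'sensor_to_idx': {'A': 0, 'B': 1, 'C': 2},
        # adj[i, j] == 1 means node j is reachable from node i in one hop
        'adj_matrix': np.array([[0, 1, 0],
                                [0, 0, 1],
                                [0, 0, 0]]),
    }
    thresholds = pd.DataFrame({
        'ID': ['A', 'B', 'C'],
        config.KEYWORD_LAYER: [2, 2, 1],   # A,B same layer; C one layer upstream
        config.KEYWORD_DEVICE: ['pump', 'pump', 'pump'],
    })
    scores = np.random.rand(5, 3)  # 5 windows x 3 sensors of anomaly scores
    env = CausalTracingEnv(graph, scores, thresholds)
    env.reset(force_window_idx=0, force_trigger='A')
    while True:
        valid = env.get_valid_actions(env.current_node_idx)
        if len(valid) == 0:
            break
        _, reward, done, _ = env.step(int(np.random.choice(valid)))
        if done:
            break
    return env.path  # e.g. [0, 1, 2] for a full A -> B -> C walk
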
# ----------------- 2. Network -----------------
class TargetDrivenActorCritic(nn.Module):
    def __init__(self, num_sensors, embedding_dim=64, hidden_dim=256):
        super().__init__()
        self.node_emb = nn.Embedding(num_sensors, embedding_dim)
        # Three node embeddings (current, previous, trigger) plus three scalar features.
        input_dim = (embedding_dim * 3) + 3
        self.shared_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.actor = nn.Linear(hidden_dim, num_sensors)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, int_data, float_data):
        curr_emb = self.node_emb(int_data[:, 0])
        prev_emb = self.node_emb(int_data[:, 1])
        trig_emb = self.node_emb(int_data[:, 2])
        x = torch.cat([curr_emb, prev_emb, trig_emb, float_data], dim=1)
        feat = self.shared_net(x)
        return self.actor(feat), self.critic(feat)
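
# --- Sketch: shape contract of the actor-critic ------------------------------
# A minimal forward-pass check with a hypothetical pool of 10 sensors and a
# batch of 4 states. int_data carries (current, previous, trigger) node ids;
# float_data carries (current score, previous score, scaled layer).
def _demo_network_shapes():
    net = TargetDrivenActorCritic(num_sensors=10, embedding_dim=8, hidden_dim=16)
    int_data = torch.randint(0, 10, (4, 3))   # node indices must be LongTensor
    float_data = torch.rand(4, 3)
    logits, value = net(int_data, float_data)
    assert logits.shape == (4, 10)   # one logit per candidate sensor
    assert value.shape == (4, 1)     # one state-value per batch entry
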
# ----------------- 3. Trainer -----------------
class RLTrainer:
    def __init__(self, causal_graph, train_scores, threshold_df):
        self.sensor_map = causal_graph['sensor_to_idx']
        self.idx_to_sensor = {v: k for k, v in self.sensor_map.items()}
        self.threshold_df = threshold_df
        self.causal_graph = causal_graph
        self.expert_knowledge, self.bc_samples, _ = self._load_expert_data()
        self.env = CausalTracingEnv(causal_graph, train_scores, threshold_df, self.expert_knowledge)
        self.model = TargetDrivenActorCritic(self.env.num_sensors, config.EMBEDDING_DIM, config.HIDDEN_DIM)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.PPO_LR)

    def _load_expert_data(self):
        path = os.path.join(config.BASE_DIR, config.ABNORMAL_LINK_FILENAME)
        kb_data = {}
        bc_data = []
        if not os.path.exists(path):
            return kb_data, bc_data, None
        df = pd.read_excel(path)
        for _, row in df.iterrows():
            link = str(row.get('Link Path', ''))
            if not link or link == 'nan':  # skip empty or NaN cells
                continue
            nodes_str = [n.strip() for n in link.replace('→', '->').split('->')]
            # Expert links are written root-first; reverse so the trigger comes first.
            path_nodes = nodes_str[::-1]
            ids = []
            valid = True
            for n in path_nodes:
                if n in self.sensor_map:
                    ids.append(self.sensor_map[n])
                else:
                    valid = False
                    break
            if not valid or len(ids) < 2:
                continue
            trigger_id = ids[0]
            root_id = ids[-1]
            if trigger_id not in kb_data:
                kb_data[trigger_id] = {'paths': [], 'roots': set(),
                                       'logic': row.get('Process Logic Basis', '')}
            kb_data[trigger_id]['paths'].append(ids)
            kb_data[trigger_id]['roots'].add(root_id)
            # Behavior-cloning samples: (current, previous, trigger) -> next node.
            for i in range(len(ids) - 1):
                curr = ids[i]
                prev = ids[max(0, i - 1)]
                nxt = ids[i + 1]
                bc_data.append(((curr, prev, trigger_id), nxt))
        return kb_data, bc_data, df

    def pretrain_bc(self):
        if not self.bc_samples:
            return
        print(f"\n>>> [Step 3.1] Starting BC pre-training ({config.BC_EPOCHS} epochs)...")
        states_int = torch.LongTensor([list(s) for s, a in self.bc_samples])
        actions = torch.LongTensor([a for s, a in self.bc_samples])
        # Expert samples carry no window context, so use nominal anomaly scores
        # (current 0.9, previous 0.8) and leave the layer feature at 0.
        states_float = torch.zeros((len(states_int), 3))
        states_float[:, 0] = 0.9
        states_float[:, 1] = 0.8
        loss_fn = nn.CrossEntropyLoss()
        pbar = tqdm(range(config.BC_EPOCHS), desc="BC Training")
        for epoch in pbar:
            logits, _ = self.model(states_int, states_float)
            loss = loss_fn(logits, actions)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if epoch % 100 == 0:
                pbar.set_postfix({'Loss': f"{loss.item():.4f}"})

    def train_ppo(self):
        print(f"\n>>> [Step 3.2] Starting PPO training ({config.RL_EPISODES} episodes)...")
        pbar = tqdm(range(config.RL_EPISODES), desc="PPO Training")
        rewards_hist = []
        for _ in pbar:
            state_data = self.env.reset()
            done = False
            ep_r = 0
            b_int, b_float, b_act, b_lp, b_rew, b_mask, b_logit_mask = [], [], [], [], [], [], []
            while not done:
                s_int = state_data[0].unsqueeze(0)
                s_float = state_data[1].unsqueeze(0)
                valid = self.env.get_valid_actions(s_int[0, 0].item())
                if len(valid) == 0:
                    break
                logits, _ = self.model(s_int, s_float)
                # Mask invalid actions with -1e9 so the softmax ignores them.
                mask = torch.full_like(logits, -1e9)
                mask[0, valid] = 0
                dist = Categorical(F.softmax(logits + mask, dim=-1))
                action = dist.sample()
                next_s, r, done, _ = self.env.step(action.item())
                b_int.append(s_int)
                b_float.append(s_float)
                b_act.append(action)
                b_lp.append(dist.log_prob(action))
                b_rew.append(r)
                b_mask.append(1 - done)
                b_logit_mask.append(mask)  # keep the mask so the update sees the same action space
                state_data = next_s
                ep_r += r
            if len(b_rew) > 1:
                self._update_ppo(b_int, b_float, b_act, b_lp, b_rew, b_mask, b_logit_mask)
            rewards_hist.append(ep_r)
            if len(rewards_hist) > 50:
                rewards_hist.pop(0)
            pbar.set_postfix({'AvgR': f"{np.mean(rewards_hist):.2f}"})

    def _update_ppo(self, b_int, b_float, b_act, b_lp, b_rew, b_mask, b_logit_mask):
        # Discounted returns, computed backwards; the mask cuts the bootstrap at episode ends.
        returns = []
        R = 0
        for r, m in zip(reversed(b_rew), reversed(b_mask)):
            R = r + config.PPO_GAMMA * R * m
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float32)
        # Normalized returns double as the advantage estimate.
        if returns.numel() > 1 and returns.std() > 1e-5:
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        elif returns.numel() > 1:
            returns = returns - returns.mean()
        s_int = torch.cat(b_int)
        s_float = torch.cat(b_float)
        # cat (not stack) flattens the per-step (1,) tensors to shape (T,), so that
        # log_prob below broadcasts per-step instead of producing a (T, T) matrix.
        act = torch.cat(b_act)
        old_lp = torch.cat(b_lp).detach()
        logit_mask = torch.cat(b_logit_mask)
        for _ in range(config.PPO_K_EPOCHS):
            logits, vals = self.model(s_int, s_float)
            # Re-apply the rollout masks so new log-probs use the same valid-action sets.
            dist = Categorical(logits=logits + logit_mask)
            new_lp = dist.log_prob(act)
            ratio = torch.exp(new_lp - old_lp)
            surr1 = ratio * returns
            surr2 = torch.clamp(ratio, 1 - config.PPO_EPS_CLIP, 1 + config.PPO_EPS_CLIP) * returns
            v_pred = vals.squeeze()
            if v_pred.shape != returns.shape:
                v_pred = v_pred.view(-1)
                returns = returns.view(-1)
            loss = -torch.min(surr1, surr2).mean() + 0.5 * F.mse_loss(v_pred, returns)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
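
    # --- Sketch: the discounted-return recursion used above -----------------
    # A worked example of R_t = r_t + gamma * R_{t+1} * mask_t. gamma = 0.9 is
    # for illustration only (the trainer uses config.PPO_GAMMA). For rewards
    # [1, 0, 2] with no early termination the returns are [2.62, 1.8, 2.0].
    # Illustrative helper; not called by the training loop.
    @staticmethod
    def _demo_discounted_returns(rewards=(1.0, 0.0, 2.0), masks=(1, 1, 1), gamma=0.9):
        returns, R = [], 0.0
        for r, m in zip(reversed(rewards), reversed(masks)):
            R = r + gamma * R * m  # mask zeroes the bootstrap across episode ends
            returns.insert(0, R)
        return returns  # [2.62, 1.8, 2.0] for the defaults above
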
    def evaluate(self, test_scores):
        print("\n>>> [Step 4] Evaluating the test set...")
        self.model.eval()
        results = []
        cnt_detected = 0
        cnt_kb_covered = 0
        cnt_path_match = 0
        cnt_root_match = 0
        cnt_new = 0
        env = CausalTracingEnv(self.causal_graph, test_scores, self.threshold_df, self.expert_knowledge)
        with torch.no_grad():
            for win_idx in range(len(test_scores)):
                scores = test_scores[win_idx]
                # Collect trigger sensors whose anomaly score crosses the threshold.
                active = []
                for t_name in config.TRIGGER_SENSORS:
                    if t_name in self.sensor_map:
                        idx = self.sensor_map[t_name]
                        if scores[idx] > config.TRIGGER_SCORE_THRESH:
                            active.append((t_name, idx))
                for t_name, t_idx in active:
                    cnt_detected += 1
                    state_data = env.reset(force_window_idx=win_idx, force_trigger=t_name)
                    path_idxs = [t_idx]
                    done = False
                    while not done:
                        s_int = state_data[0].unsqueeze(0)
                        s_float = state_data[1].unsqueeze(0)
                        valid = env.get_valid_actions(path_idxs[-1])
                        if len(valid) == 0:
                            break
                        logits, _ = self.model(s_int, s_float)
                        mask = torch.full_like(logits, -1e9)
                        mask[0, valid] = 0
                        act = torch.argmax(logits + mask, dim=1).item()  # greedy decoding
                        state_data, _, done, _ = env.step(act)
                        path_idxs.append(act)
                        if len(path_idxs) >= config.MAX_PATH_LENGTH:
                            done = True
                    path_names = [self.idx_to_sensor[i] for i in path_idxs]
                    root = path_names[-1]
                    root_score = scores[self.sensor_map[root]]
                    match_status = "undefined"
                    logic = ""
                    if t_idx in self.expert_knowledge:
                        cnt_kb_covered += 1
                        entry = self.expert_knowledge[t_idx]
                        logic = entry.get('logic', '')
                        real_roots = [self.idx_to_sensor[r] for r in entry['roots']]
                        rm = any(p_node in real_roots for p_node in path_names)
                        # Path match: node-set Jaccard similarity >= 0.6 against any expert path.
                        pm = False
                        path_set = set(path_idxs)
                        for exp_p in entry['paths']:
                            exp_set = set(exp_p)
                            intersection = len(path_set.intersection(exp_set))
                            union = len(path_set.union(exp_set))
                            if union > 0 and (intersection / union) >= 0.6:
                                pm = True
                                break
                        if pm:
                            match_status = "path match"
                            cnt_path_match += 1
                            cnt_root_match += 1
                        elif rm:
                            match_status = "root-only match"
                            cnt_root_match += 1
                        else:
                            match_status = "no match"
                    else:
                        match_status = "new link"
                        cnt_new += 1
                    results.append({
                        "Window ID": win_idx,
                        "Trigger Variable": t_name,
                        "Traced Path": "->".join(path_names),
                        "Root Cause Variable": root,
                        "Root Anomaly Score": f"{root_score:.3f}",
                        "In Knowledge Base": "Yes" if t_idx in self.expert_knowledge else "No",
                        "Match Status": match_status,
                        "Mechanism Description": logic
                    })
        denom = max(cnt_kb_covered, 1)
        summary = [
            {"Metric": "Total detected anomaly samples", "Value": cnt_detected},
            {"Metric": "Samples covered by the knowledge base", "Value": cnt_kb_covered},
            {"Metric": "Abnormal-link accuracy", "Value": f"{cnt_path_match / denom:.2%}"},
            {"Metric": "Root-cause accuracy", "Value": f"{cnt_root_match / denom:.2%}"},
            {"Metric": "Newly discovered anomaly patterns", "Value": cnt_new}
        ]
        os.makedirs(config.RESULT_SAVE_DIR, exist_ok=True)
        save_path = os.path.join(config.RESULT_SAVE_DIR, config.TEST_RESULT_FILENAME)
        with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
            pd.DataFrame(summary).to_excel(writer, sheet_name='Sheet1_Summary', index=False)
            pd.DataFrame(results).to_excel(writer, sheet_name='Sheet2_TestDetails', index=False)
        print("\n" + "=" * 50)
        print(pd.DataFrame(summary).to_string(index=False))
        print(f"\nResults saved to: {save_path}")
        print("=" * 50)
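
    # --- Sketch: the Jaccard criterion behind "path match" ------------------
    # A standalone restatement of the node-set overlap test used in evaluate();
    # the 0.6 threshold mirrors the one hard-coded above. Illustrative helper
    # only, not called by the evaluation loop.
    @staticmethod
    def _demo_path_match(traced_ids, expert_ids, thresh=0.6):
        p, e = set(traced_ids), set(expert_ids)
        union = p | e
        return bool(union) and len(p & e) / len(union) >= thresh
    # e.g. _demo_path_match([0, 1, 2], [0, 1, 3])    -> 2/4 = 0.50 -> False
    #      _demo_path_match([0, 1, 2], [0, 1, 2, 3]) -> 3/4 = 0.75 -> True
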
    def save_model(self):
        os.makedirs(config.MODEL_SAVE_DIR, exist_ok=True)  # make sure the target directory exists
        path = os.path.join(config.MODEL_SAVE_DIR, "ppo_tracing_model.pth")
        torch.save(self.model.state_dict(), path)
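
# --- Sketch: an end-to-end driver ---------------------------------------------
# A minimal sketch of the intended call order (BC warm-start, then PPO, then
# evaluation). The four inputs are produced by upstream stages of this project
# (the causal-graph dict, per-window anomaly-score matrices for train/test, and
# the threshold table); this helper assumes nothing beyond what the classes
# above already require, and is not called on import.
def _demo_training_pipeline(causal_graph, train_scores, test_scores, threshold_df):
    trainer = RLTrainer(causal_graph, train_scores, threshold_df)
    trainer.pretrain_bc()   # behavior cloning on expert links, if any were loaded
    trainer.train_ppo()     # on-policy fine-tuning with masked actions
    trainer.evaluate(test_scores)
    trainer.save_model()
    return trainer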