# causal_structure.py
# -*- coding: utf-8 -*-
"""causal_structure.py: Layer 2 - construction of the physical causal structure."""
  3. import numpy as np
  4. import pandas as pd
  5. from config import config
  6. class CausalStructureBuilder:
  7. def __init__(self, threshold_df):
  8. self.df = threshold_df
  9. self.sensor_list = self.df['ID'].tolist()
  10. self.id_to_idx = {name: i for i, name in enumerate(self.sensor_list)}
  11. self.num_sensors = len(self.sensor_list)
  12. self.col_layer = self._find_col_by_keyword(config.KEYWORD_LAYER)
  13. self.col_device = self._find_col_by_keyword(config.KEYWORD_DEVICE)
  14. def _find_col_by_keyword(self, keyword):
  15. if keyword in self.df.columns: return keyword
  16. for col in self.df.columns:
  17. if col.lower() == keyword.lower(): return col
  18. raise ValueError(f"错误: 未找到列名包含 '{keyword}' 的列")
  19. def build(self):
  20. adj_matrix = np.zeros((self.num_sensors, self.num_sensors), dtype=int)
  21. nodes_info = {}
  22. for _, row in self.df.iterrows():
  23. d_val = row[self.col_device]
  24. dev_id = str(d_val).strip() if pd.notna(d_val) and str(d_val).strip() != '' else None
  25. try: l_val = int(row[self.col_layer])
  26. except: l_val = -1
  27. nodes_info[row['ID']] = {'layer': l_val, 'device': dev_id}
  28. count_edges = 0
  29. for i, src_name in enumerate(self.sensor_list):
  30. src_node = nodes_info.get(src_name)
  31. if not src_node or src_node['layer'] == -1: continue
  32. src_l, src_d = src_node['layer'], src_node['device']
  33. for j, dst_name in enumerate(self.sensor_list):
  34. if i == j: continue
  35. dst_node = nodes_info.get(dst_name)
  36. if not dst_node or dst_node['layer'] == -1: continue
  37. dst_l, dst_d = dst_node['layer'], dst_node['device']
  38. is_layer_valid = (dst_l == src_l) or (dst_l == src_l - 1)
  39. if not is_layer_valid: continue
  40. is_dev_valid = True
  41. if (src_d is not None) and (dst_d is not None):
  42. if src_d != dst_d: is_dev_valid = False
  43. if is_dev_valid:
  44. adj_matrix[i, j] = 1
  45. count_edges += 1
  46. return {"sensor_list": self.sensor_list, "sensor_to_idx": self.id_to_idx, "adj_matrix": adj_matrix}