config.py 4.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # -*- coding: utf-8 -*-
  2. """config.py: 纯相对路径动态配置加载器"""
  3. import os
  4. import yaml
  5. class Config:
  6. def __init__(self):
  7. self._config_data = {}
  8. self.PLANT_NAME = ""
  9. self.PLANT_DIR = ""
  10. def load(self, plant_name: str):
  11. """传入水厂名称 (如 'longting'),自动挂载该水厂所有相对路径"""
  12. self.PLANT_NAME = plant_name
  13. self.PLANT_DIR = f"./{plant_name}"
  14. yaml_path = f"{self.PLANT_DIR}/config.yaml"
  15. if not os.path.exists(yaml_path):
  16. raise FileNotFoundError(f"找不到配置文件: {yaml_path}")
  17. with open(yaml_path, 'r', encoding='utf-8') as f:
  18. self._config_data = yaml.safe_load(f)
  19. self._parse_config()
  20. self._init_directories()
  21. def _parse_config(self):
  22. files = self._config_data.get('files', {})
  23. # 1. 目录路径
  24. self.DATASET_SENSOR_DIR = f"{self.PLANT_DIR}/datasets"
  25. self.RESULT_SAVE_DIR = f"{self.PLANT_DIR}/results"
  26. self.MODEL_SAVE_DIR = self.PLANT_DIR # 模型保存在水厂根目录
  27. # 2. 完整文件相对路径
  28. self.THRESHOLD_FILENAME = f"{self.PLANT_DIR}/{files.get('threshold_filename', 'sensor_threshold.xlsx')}"
  29. self.ABNORMAL_LINK_FILENAME = f"{self.PLANT_DIR}/{files.get('abnormal_link_filename', 'abnormal_link.xlsx')}"
  30. self.MODEL_FILE_PATH = f"{self.PLANT_DIR}/{files.get('model_filename', 'ppo_tracing_model.pth')}"
  31. self.TEST_RESULT_FILENAME = files.get('test_result_filename', 'Final_Test_Report.xlsx') # 这个留给 pd.ExcelWriter 处理
  32. self.SENSOR_FILE_PREFIX = files.get('sensor_file_prefix', 'data_process_')
  33. # 3. 传感器与关键字
  34. sensors = self._config_data.get('sensors', {})
  35. self.KEYWORD_LAYER = sensors.get('keyword_layer', 'One_layer')
  36. self.KEYWORD_DEVICE = sensors.get('keyword_device', 'Device')
  37. self.TRIGGER_SENSORS = sensors.get('trigger_sensors', [])
  38. # 4. 数据处理参数
  39. data = self._config_data.get('data_processing', {})
  40. self.SENSOR_FILE_NUM_RANGE = tuple(data.get('sensor_file_num_range', (1, 10)))
  41. self.ORIGINAL_SAMPLE_INTERVAL = data.get('original_sample_interval', 4)
  42. self.TARGET_SAMPLE_INTERVAL = data.get('target_sample_interval', 20)
  43. self.WINDOW_DURATION_MIN = data.get('window_duration_min', 40)
  44. # 衍生变量
  45. self.POINTS_PER_WINDOW = int((self.WINDOW_DURATION_MIN * 60) / self.TARGET_SAMPLE_INTERVAL)
  46. self.WINDOW_STEP = self.POINTS_PER_WINDOW // 2
  47. self.VALID_DATA_RATIO = data.get('valid_data_ratio', 0.6)
  48. self.WINDOW_ANOMALY_THRESHOLD = data.get('window_anomaly_threshold', 0.2)
  49. self.TRAIN_TEST_SPLIT = data.get('train_test_split', 0.8)
  50. self.TRIGGER_SCORE_THRESH = data.get('trigger_score_thresh', 0.5)
  51. self.ABSOLUTE_SCORE_WEIGHT = data.get('absolute_score_weight', 0.6)
  52. self.DYNAMIC_SCORE_WEIGHT = data.get('dynamic_score_weight', 0.4)
  53. self.MAD_HISTORY_WINDOW = data.get('mad_history_window', 360)
  54. self.MAD_THRESHOLD = data.get('mad_threshold', 3.0)
  55. # 5. 强化学习参数
  56. rl = self._config_data.get('rl_params', {})
  57. self.MIN_PATH_LENGTH = rl.get('min_path_length', 3)
  58. self.MAX_PATH_LENGTH = rl.get('max_path_length', 6)
  59. self.EMBEDDING_DIM = rl.get('embedding_dim', 64)
  60. self.HIDDEN_DIM = rl.get('hidden_dim', 256)
  61. self.PPO_LR = float(rl.get('ppo_lr', 3e-4))
  62. self.PPO_GAMMA = rl.get('ppo_gamma', 0.90)
  63. self.PPO_EPS_CLIP = rl.get('ppo_eps_clip', 0.2)
  64. self.PPO_K_EPOCHS = rl.get('ppo_k_epochs', 10)
  65. self.PPO_BATCH_SIZE = rl.get('ppo_batch_size', 64)
  66. self.BC_EPOCHS = rl.get('bc_epochs', 20000)
  67. self.RL_EPISODES = rl.get('rl_episodes', 20)
  68. def _init_directories(self):
  69. """确保当前水厂的数据和结果目录存在"""
  70. os.makedirs(self.DATASET_SENSOR_DIR, exist_ok=True)
  71. os.makedirs(self.RESULT_SAVE_DIR, exist_ok=True)
  72. # 实例化全局单例
  73. config = Config()