| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # -*- coding: utf-8 -*-
- """config.py: 纯相对路径动态配置加载器"""
- import os
- import yaml
class Config:
    """Dynamic configuration loader that works purely with relative paths.

    Every path is built relative to the current working directory as
    ``./<plant_name>/...``. Apart from ``PLANT_NAME``, ``PLANT_DIR`` and
    ``_config_data`` (set in ``__init__``), the upper-case settings only
    exist after :meth:`load` has been called.
    """

    def __init__(self):
        self._config_data = {}
        self.PLANT_NAME = ""
        self.PLANT_DIR = ""

    def load(self, plant_name: str):
        """Load ``./<plant_name>/config.yaml`` and derive all of that plant's settings.

        Args:
            plant_name: Plant directory name, e.g. ``'longting'``.

        Raises:
            FileNotFoundError: ``./<plant_name>/config.yaml`` does not exist.
            ValueError: the YAML top level (or one of its sections) is not a
                mapping, or ``target_sample_interval`` is not positive.
        """
        plant_dir = f"./{plant_name}"
        yaml_path = f"{plant_dir}/config.yaml"
        if not os.path.exists(yaml_path):
            raise FileNotFoundError(f"找不到配置文件: {yaml_path}")

        with open(yaml_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)

        # An empty YAML document parses to None: treat it as "use every default".
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(f"配置文件顶层必须是映射 (mapping): {yaml_path}")

        # Commit only after the file was found and parsed, so a missing or
        # unparsable file leaves the previously loaded plant's state intact.
        self.PLANT_NAME = plant_name
        self.PLANT_DIR = plant_dir
        self._config_data = data

        self._parse_config()
        self._init_directories()

    def _section(self, name):
        """Return top-level section ``name`` as a dict; ``{}`` when absent or left empty."""
        # A key written as `files:` with nothing under it parses to None.
        section = self._config_data.get(name) or {}
        if not isinstance(section, dict):
            raise ValueError(f"配置段 '{name}' 必须是映射 (mapping)")
        return section

    def _parse_config(self):
        """Populate every upper-case setting from ``self._config_data``.

        Raises:
            ValueError: a section is not a mapping, or
                ``target_sample_interval`` is not positive.
        """
        self._parse_paths(self._section('files'))
        self._parse_sensors(self._section('sensors'))
        self._parse_data_processing(self._section('data_processing'))
        self._parse_rl_params(self._section('rl_params'))

    def _parse_paths(self, files):
        """Set directory paths and file-name settings from the ``files`` section."""
        # 1. Directory paths
        self.DATASET_SENSOR_DIR = f"{self.PLANT_DIR}/datasets"
        self.RESULT_SAVE_DIR = f"{self.PLANT_DIR}/results"
        self.MODEL_SAVE_DIR = self.PLANT_DIR  # models are stored in the plant's root directory

        # 2. Full relative file paths, rooted at PLANT_DIR
        self.THRESHOLD_FILENAME = f"{self.PLANT_DIR}/{files.get('threshold_filename', 'sensor_threshold.xlsx')}"
        self.ABNORMAL_LINK_FILENAME = f"{self.PLANT_DIR}/{files.get('abnormal_link_filename', 'abnormal_link.xlsx')}"
        self.MODEL_FILE_PATH = f"{self.PLANT_DIR}/{files.get('model_filename', 'ppo_tracing_model.pth')}"
        # Deliberately a bare file name (no PLANT_DIR prefix); per the original
        # note, pd.ExcelWriter handling is left to the caller.
        self.TEST_RESULT_FILENAME = files.get('test_result_filename', 'Final_Test_Report.xlsx')
        self.SENSOR_FILE_PREFIX = files.get('sensor_file_prefix', 'data_process_')

    def _parse_sensors(self, sensors):
        """Set sensor keyword settings from the ``sensors`` section."""
        self.KEYWORD_LAYER = sensors.get('keyword_layer', 'One_layer')
        self.KEYWORD_DEVICE = sensors.get('keyword_device', 'Device')
        # `trigger_sensors:` left empty in YAML parses to None; normalise to [].
        self.TRIGGER_SENSORS = sensors.get('trigger_sensors') or []

    def _parse_data_processing(self, data):
        """Set data-processing parameters (and derived window sizes) from ``data_processing``."""
        self.SENSOR_FILE_NUM_RANGE = tuple(data.get('sensor_file_num_range', (1, 10)))
        self.ORIGINAL_SAMPLE_INTERVAL = data.get('original_sample_interval', 4)
        self.TARGET_SAMPLE_INTERVAL = data.get('target_sample_interval', 20)
        self.WINDOW_DURATION_MIN = data.get('window_duration_min', 40)

        # Guard the division below: 0 would raise ZeroDivisionError and a
        # negative interval would give a negative window size.
        if self.TARGET_SAMPLE_INTERVAL <= 0:
            raise ValueError(
                f"target_sample_interval 必须为正数, 当前为 {self.TARGET_SAMPLE_INTERVAL}"
            )

        # Derived values. The `* 60` converts the window duration from minutes,
        # so TARGET_SAMPLE_INTERVAL is presumably in seconds (TODO confirm).
        self.POINTS_PER_WINDOW = int((self.WINDOW_DURATION_MIN * 60) / self.TARGET_SAMPLE_INTERVAL)
        # Half-window step; presumably gives 50% overlap between windows.
        self.WINDOW_STEP = self.POINTS_PER_WINDOW // 2

        # float(): PyYAML (YAML 1.1) loads exponent forms without a dot, such
        # as `1e-3`, as strings. This mirrors the original handling of ppo_lr.
        self.VALID_DATA_RATIO = float(data.get('valid_data_ratio', 0.6))
        self.WINDOW_ANOMALY_THRESHOLD = float(data.get('window_anomaly_threshold', 0.2))
        self.TRAIN_TEST_SPLIT = float(data.get('train_test_split', 0.8))
        self.TRIGGER_SCORE_THRESH = float(data.get('trigger_score_thresh', 0.5))
        self.ABSOLUTE_SCORE_WEIGHT = float(data.get('absolute_score_weight', 0.6))
        self.DYNAMIC_SCORE_WEIGHT = float(data.get('dynamic_score_weight', 0.4))
        self.MAD_HISTORY_WINDOW = data.get('mad_history_window', 360)
        self.MAD_THRESHOLD = float(data.get('mad_threshold', 3.0))

    def _parse_rl_params(self, rl):
        """Set reinforcement-learning (PPO / behaviour-cloning) parameters from ``rl_params``."""
        self.MIN_PATH_LENGTH = rl.get('min_path_length', 3)
        self.MAX_PATH_LENGTH = rl.get('max_path_length', 6)
        self.EMBEDDING_DIM = rl.get('embedding_dim', 64)
        self.HIDDEN_DIM = rl.get('hidden_dim', 256)
        # float(): `3e-4` written in YAML loads as a string under PyYAML.
        self.PPO_LR = float(rl.get('ppo_lr', 3e-4))
        self.PPO_GAMMA = float(rl.get('ppo_gamma', 0.90))
        self.PPO_EPS_CLIP = float(rl.get('ppo_eps_clip', 0.2))
        self.PPO_K_EPOCHS = rl.get('ppo_k_epochs', 10)
        self.PPO_BATCH_SIZE = rl.get('ppo_batch_size', 64)
        self.BC_EPOCHS = rl.get('bc_epochs', 20000)
        self.RL_EPISODES = rl.get('rl_episodes', 20)

    def _init_directories(self):
        """Create the current plant's dataset and result directories if missing."""
        os.makedirs(self.DATASET_SENSOR_DIR, exist_ok=True)
        os.makedirs(self.RESULT_SAVE_DIR, exist_ok=True)
# Module-wide shared instance. Until config.load(plant_name) is called it
# only holds PLANT_NAME, PLANT_DIR and _config_data.
config: Config = Config()
|