# -*- coding: utf-8 -*-
"""config.py: dynamic configuration loader built purely on relative paths."""

import os

import yaml


class Config:
    """Per-plant configuration holder populated from ``./<plant>/config.yaml``.

    All settings are exposed as upper-case instance attributes. They only
    exist after :meth:`load` has been called successfully; before that, only
    ``PLANT_NAME`` and ``PLANT_DIR`` are defined (as empty strings).
    """

    def __init__(self) -> None:
        self._config_data = {}
        self.PLANT_NAME = ""
        self.PLANT_DIR = ""

    def load(self, plant_name: str) -> None:
        """Load the configuration of one plant and derive all of its paths.

        Args:
            plant_name: Name of the water plant (e.g. ``'longting'``). It is
                used as a directory name relative to the current working
                directory (``./<plant_name>``).

        Raises:
            FileNotFoundError: If ``./<plant_name>/config.yaml`` does not exist.
        """
        self.PLANT_NAME = plant_name
        self.PLANT_DIR = f"./{plant_name}"
        yaml_path = f"{self.PLANT_DIR}/config.yaml"

        if not os.path.exists(yaml_path):
            raise FileNotFoundError(f"找不到配置文件: {yaml_path}")

        with open(yaml_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty (or comment-only) document;
            # fall back to an empty mapping so every setting uses its default
            # instead of crashing on ``None.get``.
            self._config_data = yaml.safe_load(f) or {}

        self._parse_config()
        self._init_directories()

    def _section(self, name: str) -> dict:
        """Return the top-level section ``name``, or ``{}`` if missing or null.

        A key that is present but has no body (``files:``) is parsed as
        ``None``; ``dict.get(name, {})`` would hand that ``None`` back, so it
        is normalised here.
        """
        return self._config_data.get(name) or {}

    def _parse_config(self) -> None:
        """Translate the loaded YAML mapping into upper-case attributes.

        Every setting falls back to a built-in default when its key (or its
        whole section) is absent from the YAML file.
        """
        files = self._section('files')

        # 1. Directory paths
        self.DATASET_SENSOR_DIR = f"{self.PLANT_DIR}/datasets"
        self.RESULT_SAVE_DIR = f"{self.PLANT_DIR}/results"
        self.MODEL_SAVE_DIR = self.PLANT_DIR  # models are saved in the plant root directory

        # 2. Full relative file paths
        self.THRESHOLD_FILENAME = (
            f"{self.PLANT_DIR}/{files.get('threshold_filename', 'sensor_threshold.xlsx')}"
        )
        self.ABNORMAL_LINK_FILENAME = (
            f"{self.PLANT_DIR}/{files.get('abnormal_link_filename', 'abnormal_link.xlsx')}"
        )
        self.MODEL_FILE_PATH = (
            f"{self.PLANT_DIR}/{files.get('model_filename', 'ppo_tracing_model.pth')}"
        )
        # Kept as a bare file name (not joined with PLANT_DIR); per the
        # original note it is left for pd.ExcelWriter to handle.
        self.TEST_RESULT_FILENAME = files.get('test_result_filename', 'Final_Test_Report.xlsx')
        self.SENSOR_FILE_PREFIX = files.get('sensor_file_prefix', 'data_process_')

        # 3. Sensors and keywords
        sensors = self._section('sensors')
        self.KEYWORD_LAYER = sensors.get('keyword_layer', 'One_layer')
        self.KEYWORD_DEVICE = sensors.get('keyword_device', 'Device')
        self.TRIGGER_SENSORS = sensors.get('trigger_sensors', [])

        # 4. Data-processing parameters
        data = self._section('data_processing')
        self.SENSOR_FILE_NUM_RANGE = tuple(data.get('sensor_file_num_range', (1, 10)))
        self.ORIGINAL_SAMPLE_INTERVAL = data.get('original_sample_interval', 4)
        self.TARGET_SAMPLE_INTERVAL = data.get('target_sample_interval', 20)
        self.WINDOW_DURATION_MIN = data.get('window_duration_min', 40)

        # Derived values: the window duration is converted from minutes to
        # seconds and divided by the target interval (presumably in seconds —
        # confirm against the data pipeline). Consecutive windows are stepped
        # by half a window.
        # NOTE(review): a target_sample_interval of 0 raises ZeroDivisionError here.
        self.POINTS_PER_WINDOW = int((self.WINDOW_DURATION_MIN * 60) / self.TARGET_SAMPLE_INTERVAL)
        self.WINDOW_STEP = self.POINTS_PER_WINDOW // 2

        self.VALID_DATA_RATIO = data.get('valid_data_ratio', 0.6)
        self.WINDOW_ANOMALY_THRESHOLD = data.get('window_anomaly_threshold', 0.2)
        self.TRAIN_TEST_SPLIT = data.get('train_test_split', 0.8)
        self.TRIGGER_SCORE_THRESH = data.get('trigger_score_thresh', 0.5)
        self.ABSOLUTE_SCORE_WEIGHT = data.get('absolute_score_weight', 0.6)
        self.DYNAMIC_SCORE_WEIGHT = data.get('dynamic_score_weight', 0.4)
        self.MAD_HISTORY_WINDOW = data.get('mad_history_window', 360)
        self.MAD_THRESHOLD = data.get('mad_threshold', 3.0)

        # 5. Reinforcement-learning parameters
        rl = self._section('rl_params')
        self.MIN_PATH_LENGTH = rl.get('min_path_length', 3)
        self.MAX_PATH_LENGTH = rl.get('max_path_length', 6)
        self.EMBEDDING_DIM = rl.get('embedding_dim', 64)
        self.HIDDEN_DIM = rl.get('hidden_dim', 256)
        # Explicit cast: presumably because PyYAML loads exponent notation
        # without a decimal point (e.g. ``3e-4``) as a string.
        self.PPO_LR = float(rl.get('ppo_lr', 3e-4))
        self.PPO_GAMMA = rl.get('ppo_gamma', 0.90)
        self.PPO_EPS_CLIP = rl.get('ppo_eps_clip', 0.2)
        self.PPO_K_EPOCHS = rl.get('ppo_k_epochs', 10)
        self.PPO_BATCH_SIZE = rl.get('ppo_batch_size', 64)
        self.BC_EPOCHS = rl.get('bc_epochs', 20000)
        self.RL_EPISODES = rl.get('rl_episodes', 20)

    def _init_directories(self) -> None:
        """Ensure the current plant's dataset and result directories exist."""
        os.makedirs(self.DATASET_SENSOR_DIR, exist_ok=True)
        os.makedirs(self.RESULT_SAVE_DIR, exist_ok=True)


# Instantiate the module-level singleton
config = Config()