# config.py
import os

import yaml


class Config:
    """Per-plant configuration loaded from ``./<plant_name>/config.yaml``.

    Construction only sets ``PLANT_NAME`` and ``PLANT_DIR`` (both empty); every
    other attribute is created by :meth:`load`, so ``load`` must be called
    before the remaining settings are read.
    """

    def __init__(self):
        self.PLANT_NAME = ""
        self.PLANT_DIR = ""

    @staticmethod
    def _section(cfg: dict, key: str) -> dict:
        """Return the sub-mapping ``cfg[key]``, or ``{}`` if it is missing or empty.

        A YAML key written with no value (e.g. ``files:``) loads as ``None``;
        ``cfg.get(key, {})`` would hand back that ``None`` and the following
        ``.get`` calls would crash, hence the ``or {}`` fallback.
        """
        return cfg.get(key) or {}

    def load(self, plant_name: str) -> None:
        """Read ``./<plant_name>/config.yaml`` and populate this object's settings.

        Missing keys fall back to the defaults written below. The dataset
        directory (``DATA_DIR``) is created if it does not exist yet.

        Args:
            plant_name: Name of the plant; also the name of its directory
                relative to the current working directory.

        Raises:
            FileNotFoundError: If ``config.yaml`` does not exist in the plant
                directory.
            ValueError: If the YAML document is not a mapping at the top level.
        """
        self.PLANT_NAME = plant_name
        self.PLANT_DIR = f"./{plant_name}"
        yaml_path = f"{self.PLANT_DIR}/config.yaml"
        # isfile (not exists) so a directory named config.yaml is reported here
        # instead of failing later inside open().
        if not os.path.isfile(yaml_path):
            raise FileNotFoundError(f"找不到配置文件: {yaml_path}")

        with open(yaml_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file; treat that as "all defaults".
            cfg = yaml.safe_load(f) or {}
        if not isinstance(cfg, dict):
            raise ValueError(
                f"Config file {yaml_path} must contain a YAML mapping at the top level"
            )

        # 1. File paths (all relative to the plant directory).
        files = self._section(cfg, 'files')
        self.DATA_DIR = f"{self.PLANT_DIR}/{files.get('dataset_dir', 'datasets')}"
        self.FILE_PATTERN = files.get('file_pattern', 'data_process_{}.csv')
        self.MODEL_PATH = f"{self.PLANT_DIR}/{files.get('model_filename', 'model.pth')}"
        self.SCALER_PATH = f"{self.PLANT_DIR}/{files.get('scaler_filename', 'scaler.pkl')}"
        self.OUTPUT_CSV_PATH = f"{self.PLANT_DIR}/{files.get('output_csv_filename', 'predictions.csv')}"
        self.EDGE_INDEX_PATH = f"{self.PLANT_DIR}/{files.get('edge_index_filename', 'edge_index.pt')}"

        # 2. Data split.
        split = self._section(cfg, 'data_split')
        self.START_FILES = split.get('start_files', 1)
        self.END_FILES = split.get('end_files', 10)
        # NOTE(review): PyYAML loads unquoted ISO dates as datetime.date and
        # quoted ones as str; the values below are passed through unchanged
        # (None when absent) — confirm what downstream code expects.
        self.TRAIN_START_DATE = split.get('train_start_date')
        self.TRAIN_END_DATE = split.get('train_end_date')
        self.VAL_START_DATE = split.get('val_start_date')
        self.VAL_END_DATE = split.get('val_end_date')
        self.TEST_START_DATE = split.get('test_start_date')
        self.TEST_END_DATE = split.get('test_end_date')

        # 3. Model and hyperparameters.
        mp = self._section(cfg, 'model_params')
        self.SEQ_LEN = mp.get('seq_len', 10)
        self.OUTPUT_SIZE = mp.get('output_size', 5)
        self.STEP_SIZE = mp.get('step_size', 5)
        self.RESOLUTION = mp.get('resolution', 60)
        self.HIDDEN_SIZE = mp.get('hidden_size', 64)
        self.NUM_LAYERS = mp.get('num_layers', 1)
        # float(): YAML 1.1 reads dot-less exponents such as 1e-1 as strings.
        self.DROPOUT = float(mp.get('dropout', 0.0))

        # 4. Sensor features (feature dimension is derived automatically).
        sensors = self._section(cfg, 'sensors')
        self.REQUIRED_COLUMNS = sensors.get('required_columns') or []
        self.TARGET_COLUMNS = sensors.get('target_columns') or []
        self.LABELS_NUM = len(self.TARGET_COLUMNS)
        # Business features (required columns minus the index column) plus
        # 4 manually injected time-encoding features. max() keeps the business
        # count from going negative when no columns are configured.
        self.FEATURE_NUM = max(len(self.REQUIRED_COLUMNS) - 1, 0) + 4

        # 5. Training parameters.
        tp = self._section(cfg, 'training_params')
        self.EPOCHS = tp.get('epochs', 200)
        # float() on the real-valued settings: YAML 1.1 parses e.g. `lr: 1e-3`
        # as the string "1e-3", not a number.
        self.LR = float(tp.get('lr', 0.01))
        self.BATCH_SIZE = tp.get('batch_size', 512)
        self.SCHEDULER_STEP_SIZE = tp.get('scheduler_step_size', 100)
        self.SCHEDULER_GAMMA = float(tp.get('scheduler_gamma', 0.9))
        self.PATIENCE = tp.get('patience', 200)
        self.MIN_DELTA = float(tp.get('min_delta', 1e-10))
        self.DEVICE_ID = tp.get('device', 1)
        self.RANDOM_SEED = tp.get('random_seed', 1314)

        os.makedirs(self.DATA_DIR, exist_ok=True)


# Shared module-level instance; call config.load(<plant_name>) before use.
config = Config()