# config.py
import os

import yaml
class Config:
    """Per-plant settings read from ``./<plant_name>/config.yaml``.

    A freshly constructed instance only has empty ``PLANT_NAME`` and
    ``PLANT_DIR``; every other attribute is created by :meth:`load`.
    """

    def __init__(self):
        self.PLANT_NAME = ""
        self.PLANT_DIR = ""

    def load(self, plant_name: str):
        """Read the plant's YAML file and populate this instance.

        Args:
            plant_name: Name of the plant; the config is looked up at
                ``./<plant_name>/config.yaml``.

        Raises:
            FileNotFoundError: If the YAML file does not exist.
            ValueError: If the YAML document, or one of its sections, is
                neither a mapping nor empty/null.

        Side effects:
            Creates the dataset directory (``DATA_DIR``) if it is missing.
        """
        self.PLANT_NAME = plant_name
        self.PLANT_DIR = f"./{plant_name}"
        yaml_path = f"{self.PLANT_DIR}/config.yaml"
        if not os.path.exists(yaml_path):
            raise FileNotFoundError(f"找不到配置文件: {yaml_path}")
        with open(yaml_path, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)

        self._apply_settings(cfg)
        os.makedirs(self.DATA_DIR, exist_ok=True)

    @staticmethod
    def _section(cfg, key):
        """Return the mapping under *key*; ``{}`` when missing or null."""
        section = cfg.get(key)
        if section is None:
            # Covers both an absent key and a section written with no body.
            return {}
        if not isinstance(section, dict):
            raise ValueError(
                f"config section '{key}' must be a mapping, "
                f"got {type(section).__name__}"
            )
        return section

    def _apply_settings(self, cfg):
        """Populate attributes from the parsed YAML document (no I/O).

        ``cfg`` may be ``None`` (an empty YAML file); every setting then
        falls back to its default. Relies on ``self.PLANT_DIR`` being set.
        """
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, dict):
            raise ValueError(
                "config.yaml must contain a mapping at the top level, "
                f"got {type(cfg).__name__}"
            )

        # 1. Paths, all resolved relative to the plant directory.
        files = self._section(cfg, 'files')
        self.DATA_DIR = f"{self.PLANT_DIR}/{files.get('dataset_dir', 'datasets')}"
        self.FILE_PATTERN = files.get('file_pattern', 'data_process_{}.csv')
        self.MODEL_PATH = f"{self.PLANT_DIR}/{files.get('model_filename', 'model.pth')}"
        self.SCALER_PATH = f"{self.PLANT_DIR}/{files.get('scaler_filename', 'scaler.pkl')}"
        self.OUTPUT_CSV_PATH = f"{self.PLANT_DIR}/{files.get('output_csv_filename', 'predictions.csv')}"
        self.EDGE_INDEX_PATH = f"{self.PLANT_DIR}/{files.get('edge_index_filename', 'edge_index.pt')}"

        # 2. Data split. The date bounds have no default and stay None
        #    when they are not configured.
        split = self._section(cfg, 'data_split')
        self.START_FILES = split.get('start_files', 1)
        self.END_FILES = split.get('end_files', 10)
        self.TRAIN_START_DATE = split.get('train_start_date')
        self.TRAIN_END_DATE = split.get('train_end_date')
        self.VAL_START_DATE = split.get('val_start_date')
        self.VAL_END_DATE = split.get('val_end_date')
        self.TEST_START_DATE = split.get('test_start_date')
        self.TEST_END_DATE = split.get('test_end_date')

        # 3. Model hyper-parameters.
        mp = self._section(cfg, 'model_params')
        self.SEQ_LEN = mp.get('seq_len', 10)
        self.OUTPUT_SIZE = mp.get('output_size', 5)
        self.STEP_SIZE = mp.get('step_size', 5)
        self.RESOLUTION = mp.get('resolution', 60)
        self.HIDDEN_SIZE = mp.get('hidden_size', 64)
        self.NUM_LAYERS = mp.get('num_layers', 1)
        self.DROPOUT = mp.get('dropout', 0.0)

        # 4. Sensor columns; the feature dimensions are derived from them.
        sensors = self._section(cfg, 'sensors')
        # `or []` also covers a key that is present but null.
        self.REQUIRED_COLUMNS = sensors.get('required_columns') or []
        self.TARGET_COLUMNS = sensors.get('target_columns') or []

        self.LABELS_NUM = len(self.TARGET_COLUMNS)
        # Business features (required columns without the index column) plus
        # the 4 manually injected time-encoding dimensions. Clamped so an
        # empty column list cannot eat into the 4 time-encoding features.
        self.FEATURE_NUM = max(len(self.REQUIRED_COLUMNS) - 1, 0) + 4

        # 5. Training parameters.
        tp = self._section(cfg, 'training_params')
        self.EPOCHS = tp.get('epochs', 200)
        self.LR = tp.get('lr', 0.01)
        self.BATCH_SIZE = tp.get('batch_size', 512)
        self.SCHEDULER_STEP_SIZE = tp.get('scheduler_step_size', 100)
        self.SCHEDULER_GAMMA = tp.get('scheduler_gamma', 0.9)
        self.PATIENCE = tp.get('patience', 200)
        # float() also accepts a string such as "1e-10".
        # NOTE(review): presumably needed because the YAML loader returns
        # exponent literals without a decimal point as strings - confirm.
        self.MIN_DELTA = float(tp.get('min_delta', 1e-10))
        self.DEVICE_ID = tp.get('device', 1)
        self.RANDOM_SEED = tp.get('random_seed', 1314)
# Shared module-level instance. Only PLANT_NAME / PLANT_DIR exist on it until
# ``config.load(<plant_name>)`` has been called.
config = Config()
|