config.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # config.py
  2. import os
  3. import yaml
  4. class Config:
  5. def __init__(self):
  6. self.PLANT_NAME = ""
  7. self.PLANT_DIR = ""
  8. def load(self, plant_name: str):
  9. self.PLANT_NAME = plant_name
  10. self.PLANT_DIR = f"./{plant_name}"
  11. yaml_path = f"{self.PLANT_DIR}/config.yaml"
  12. if not os.path.exists(yaml_path):
  13. raise FileNotFoundError(f"找不到配置文件: {yaml_path}")
  14. with open(yaml_path, 'r', encoding='utf-8') as f:
  15. cfg = yaml.safe_load(f)
  16. # 1. 路径挂载
  17. files = cfg.get('files', {})
  18. self.DATA_DIR = f"{self.PLANT_DIR}/{files.get('dataset_dir', 'datasets')}"
  19. self.FILE_PATTERN = files.get('file_pattern', 'data_process_{}.csv')
  20. self.MODEL_PATH = f"{self.PLANT_DIR}/{files.get('model_filename', 'model.pth')}"
  21. self.SCALER_PATH = f"{self.PLANT_DIR}/{files.get('scaler_filename', 'scaler.pkl')}"
  22. self.OUTPUT_CSV_PATH = f"{self.PLANT_DIR}/{files.get('output_csv_filename', 'predictions.csv')}"
  23. self.EDGE_INDEX_PATH = f"{self.PLANT_DIR}/{files.get('edge_index_filename', 'edge_index.pt')}"
  24. # 2. 数据划分
  25. split = cfg.get('data_split', {})
  26. self.START_FILES = split.get('start_files', 1)
  27. self.END_FILES = split.get('end_files', 10)
  28. self.TRAIN_START_DATE = split.get('train_start_date')
  29. self.TRAIN_END_DATE = split.get('train_end_date')
  30. self.VAL_START_DATE = split.get('val_start_date')
  31. self.VAL_END_DATE = split.get('val_end_date')
  32. self.TEST_START_DATE = split.get('test_start_date')
  33. self.TEST_END_DATE = split.get('test_end_date')
  34. # 3. 模型与超参数
  35. mp = cfg.get('model_params', {})
  36. self.SEQ_LEN = mp.get('seq_len', 10)
  37. self.OUTPUT_SIZE = mp.get('output_size', 5)
  38. self.STEP_SIZE = mp.get('step_size', 5)
  39. self.RESOLUTION = mp.get('resolution', 60)
  40. self.HIDDEN_SIZE = mp.get('hidden_size', 64)
  41. self.NUM_LAYERS = mp.get('num_layers', 1)
  42. self.DROPOUT = mp.get('dropout', 0.0)
  43. # 4. 传感器特征推导 (核心:自动计算特征维度)
  44. sensors = cfg.get('sensors', {})
  45. self.REQUIRED_COLUMNS = sensors.get('required_columns', [])
  46. self.TARGET_COLUMNS = sensors.get('target_columns', [])
  47. self.LABELS_NUM = len(self.TARGET_COLUMNS)
  48. # 业务特征(不含index) + 4维手动注入的时间编码
  49. self.FEATURE_NUM = (len(self.REQUIRED_COLUMNS) - 1) + 4
  50. # 5. 训练参数
  51. tp = cfg.get('training_params', {})
  52. self.EPOCHS = tp.get('epochs', 200)
  53. self.LR = tp.get('lr', 0.01)
  54. self.BATCH_SIZE = tp.get('batch_size', 512)
  55. self.SCHEDULER_STEP_SIZE = tp.get('scheduler_step_size', 100)
  56. self.SCHEDULER_GAMMA = tp.get('scheduler_gamma', 0.9)
  57. self.PATIENCE = tp.get('patience', 200)
  58. self.MIN_DELTA = float(tp.get('min_delta', 1e-10))
  59. self.DEVICE_ID = tp.get('device', 1)
  60. self.RANDOM_SEED = tp.get('random_seed', 1314)
  61. os.makedirs(self.DATA_DIR, exist_ok=True)
  62. config = Config()