main.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # main.py
  2. import os
  3. import torch
  4. import numpy as np
  5. import random
  6. import argparse
  7. from torch.nn import MSELoss
  8. from config import config
  9. def set_seed(seed):
  10. random.seed(seed)
  11. os.environ['PYTHONHASHSEED'] = str(seed)
  12. np.random.seed(seed)
  13. torch.manual_seed(seed)
  14. torch.cuda.manual_seed(seed)
  15. torch.backends.cudnn.deterministic = True
  16. torch.backends.cudnn.benchmark = False
  17. def main():
  18. parser = argparse.ArgumentParser(description="水厂预测模型训练")
  19. parser.add_argument('-p', '--plant', type=str, required=True, help="水厂名称,例如: lankao")
  20. args = parser.parse_args()
  21. # 加载对应水厂的配置
  22. config.load(args.plant)
  23. # 延迟导入,确保 config 已加载
  24. from gat_lstm import GAT_LSTM
  25. from data_trainer import Trainer
  26. from data_preprocessor import DataPreprocessor
  27. set_seed(config.RANDOM_SEED)
  28. device = torch.device(f"cuda:{config.DEVICE_ID}" if torch.cuda.is_available() else "cpu")
  29. print(f"[*] 工作空间: {args.plant} | 序列={config.SEQ_LEN}, 特征={config.FEATURE_NUM}, 目标={config.LABELS_NUM}")
  30. data = DataPreprocessor.read_and_combine_csv_files()
  31. train_loader, val_loader, test_loader, _ = DataPreprocessor.load_and_process_data(data)
  32. model = GAT_LSTM().to(device)
  33. if os.path.exists(config.EDGE_INDEX_PATH):
  34. model.set_edge_index(torch.load(config.EDGE_INDEX_PATH, map_location=device, weights_only=True))
  35. print("已加载 edge_index.pt")
  36. trainer = Trainer(model, data)
  37. optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)
  38. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.SCHEDULER_STEP_SIZE, gamma=config.SCHEDULER_GAMMA)
  39. print("=== 开始训练 ===")
  40. trainer.train_full_model(train_loader, val_loader, optimizer, MSELoss(), scheduler)
  41. trainer.save_model()
  42. print("=== 开始评估 ===")
  43. trainer.evaluate_model(test_loader)
  44. if __name__ == "__main__":
  45. main()