  1. # main.py
  2. import os
  3. import torch
  4. import numpy as np
  5. import random
  6. from gat_lstm import GAT_LSTM
  7. from data_trainer import Trainer
  8. from args import lstm_args_parser
  9. from torch.nn import MSELoss
  10. from data_preprocessor import DataPreprocessor
  11. def set_seed(seed):
  12. random.seed(seed)
  13. os.environ['PYTHONHASHSEED'] = str(seed)
  14. np.random.seed(seed)
  15. torch.manual_seed(seed)
  16. torch.cuda.manual_seed(seed)
  17. torch.cuda.manual_seed_all(seed)
  18. torch.backends.cudnn.deterministic = True
  19. torch.backends.cudnn.benchmark = False
  20. def main():
  21. args = lstm_args_parser()
  22. set_seed(args.random_seed)
  23. device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
  24. args.device = device # 将device存入args,方便后续使用
  25. # 数据预处理
  26. data = DataPreprocessor.read_and_combine_csv_files(args)
  27. train_loader, val_loader, test_loader, _ = DataPreprocessor.load_and_process_data(args, data)
  28. # 初始化包含16个子模型的整体模型
  29. model = GAT_LSTM(args).to(device)
  30. # 初始化训练器和MSE损失函数
  31. trainer = Trainer(model, args, data)
  32. criterion = MSELoss()
  33. # 优化器:优化所有子模型的参数(联合训练)
  34. optimizer = torch.optim.Adam(
  35. model.parameters(), # 整体模型参数
  36. lr=args.lr
  37. )
  38. scheduler = torch.optim.lr_scheduler.StepLR(
  39. optimizer,
  40. step_size=args.scheduler_step_size,
  41. gamma=args.scheduler_gamma
  42. )
  43. # 整体训练大模型(包含所有16个子模型)
  44. print("=== 开始训练包含16个子模型的整体模型 ===")
  45. trainer.train_full_model(
  46. train_loader,
  47. val_loader,
  48. optimizer,
  49. criterion,
  50. scheduler
  51. )
  52. # 保存包含所有16个子模型参数的整体模型
  53. trainer.save_model()
  54. print("\n=== 模型训练完成,已保存整体模型 ===")
  55. # 评估模型
  56. trainer.evaluate_model(test_loader, MSELoss())
  57. if __name__ == "__main__":
  58. main()