main.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # main.py
  2. import os
  3. import torch
  4. import numpy as np
  5. import random
  6. from gat_lstm import GAT_LSTM
  7. from data_trainer import Trainer
  8. from args import lstm_args_parser
  9. from torch.nn import MSELoss
  10. from data_preprocessor import DataPreprocessor
  11. def set_seed(seed):
  12. random.seed(seed)
  13. os.environ['PYTHONHASHSEED'] = str(seed)
  14. np.random.seed(seed)
  15. torch.manual_seed(seed)
  16. torch.cuda.manual_seed(seed)
  17. torch.backends.cudnn.deterministic = True
  18. torch.backends.cudnn.benchmark = False
  19. def main():
  20. args = lstm_args_parser()
  21. set_seed(args.random_seed)
  22. device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
  23. args.device = device
  24. print(f"当前配置: 序列长度={args.seq_len}, 特征数={args.feature_num}, 目标数={args.labels_num}")
  25. # 数据预处理
  26. data = DataPreprocessor.read_and_combine_csv_files(args)
  27. train_loader, val_loader, test_loader, _ = DataPreprocessor.load_and_process_data(args, data)
  28. # 初始化模型
  29. model = GAT_LSTM(args).to(device)
  30. # 训练器
  31. trainer = Trainer(model, args, data)
  32. criterion = MSELoss()
  33. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
  34. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)
  35. print("=== 开始训练 ===")
  36. trainer.train_full_model(train_loader, val_loader, optimizer, criterion, scheduler)
  37. trainer.save_model()
  38. print("=== 开始评估 ===")
  39. trainer.evaluate_model(test_loader, MSELoss())
  40. if __name__ == "__main__":
  41. main()