main.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # main.py
  2. import os
  3. import torch
  4. import numpy as np
  5. import random
  6. from gat_lstm import GAT_LSTM
  7. from data_trainer import Trainer
  8. from args import lstm_args_parser
  9. from torch.nn import MSELoss
  10. from data_preprocessor import DataPreprocessor
  11. def set_seed(seed):
  12. random.seed(seed)
  13. os.environ['PYTHONHASHSEED'] = str(seed)
  14. np.random.seed(seed)
  15. torch.manual_seed(seed)
  16. torch.cuda.manual_seed(seed)
  17. torch.backends.cudnn.deterministic = True
  18. torch.backends.cudnn.benchmark = False
  19. def main():
  20. args = lstm_args_parser()
  21. set_seed(args.random_seed)
  22. device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
  23. args.device = device
  24. print(f"当前配置: 序列长度={args.seq_len}, 特征数={args.feature_num}, 目标数={args.labels_num}")
  25. # 数据预处理
  26. data = DataPreprocessor.read_and_combine_csv_files(args)
  27. train_loader, val_loader, test_loader, _ = DataPreprocessor.load_and_process_data(args, data)
  28. # 初始化模型
  29. model = GAT_LSTM(args).to(device)
  30. # 加载 edge_index.pt
  31. if os.path.exists('edge_index.pt'):
  32. edge_index = torch.load('edge_index.pt', map_location=device, weights_only=True)
  33. model.set_edge_index(edge_index)
  34. print("已加载 edge_index.pt")
  35. else:
  36. print("未找到 edge_index.pt")
  37. # 训练器
  38. trainer = Trainer(model, args, data)
  39. criterion = MSELoss()
  40. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
  41. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)
  42. print("=== 开始训练 ===")
  43. trainer.train_full_model(train_loader, val_loader, optimizer, criterion, scheduler)
  44. trainer.save_model()
  45. print("=== 开始评估 ===")
  46. trainer.evaluate_model(test_loader, MSELoss())
  47. if __name__ == "__main__":
  48. main()