# main.py
import logging
import os

import torch.optim as optim

from args import get_args
from data_preprocessor import DataPreprocessor
from data_trainer import DataTrainer
from gat import GAT
from rl_optimizer import RLOptimizer
  9. def setup_logger(args):
  10. """设置日志记录"""
  11. if not os.path.exists('logs'):
  12. os.makedirs('logs')
  13. logger = logging.getLogger('GAT-Training')
  14. logger.setLevel(logging.INFO)
  15. # 文件处理器
  16. file_handler = logging.FileHandler(f'logs/training_{args.num_files}.log')
  17. file_handler.setLevel(logging.INFO)
  18. # 控制台处理器
  19. console_handler = logging.StreamHandler()
  20. console_handler.setLevel(logging.INFO)
  21. # 格式化器
  22. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  23. file_handler.setFormatter(formatter)
  24. console_handler.setFormatter(formatter)
  25. logger.addHandler(file_handler)
  26. logger.addHandler(console_handler)
  27. return logger
  28. def main():
  29. # 获取参数
  30. args = get_args()
  31. logger = setup_logger(args)
  32. logger.info(f"使用设备: {args.device}")
  33. # 数据预处理
  34. preprocessor = DataPreprocessor(args, logger)
  35. train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()
  36. # 创建有向图邻接矩阵
  37. adj = preprocessor.create_adjacency_matrix()
  38. logger.info(f"邻接矩阵形状: {adj.shape}")
  39. # 步骤1: 使用强化学习优化超参数
  40. rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
  41. best_hparams = rl_optimizer.optimize()
  42. # 步骤2: 使用最优超参数训练最终模型
  43. logger.info("\n使用最优超参数训练最终模型...")
  44. final_model = GAT(
  45. nfeat=1,
  46. nhid=best_hparams['hidden_dim'],
  47. noutput=args.num_targets,
  48. dropout=best_hparams['dropout'],
  49. nheads=best_hparams['num_heads'],
  50. alpha=0.2
  51. ).to(args.device)
  52. # 配置优化器和学习率调度器
  53. optimizer = optim.Adam(
  54. final_model.parameters(),
  55. lr=best_hparams['lr'],
  56. weight_decay=args.weight_decay
  57. )
  58. # 学习率调度器
  59. scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  60. optimizer, mode='min', factor=0.5, patience=10, verbose=True
  61. )
  62. # 训练最终模型
  63. trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
  64. trained_model = trainer.train(train_loader, val_loader, adj)
  65. # 步骤3: 在测试集上评估
  66. logger.info("\n在测试集上评估最终模型...")
  67. test_results = trainer.test(test_loader, adj)
  68. logger.info("所有任务完成!")
  69. if __name__ == "__main__":
  70. main()