| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import torch.optim as optim
- from args import get_args
- from data_preprocessor import DataPreprocessor
- from gat import GAT
- from data_trainer import DataTrainer
- from rl_optimizer import RLOptimizer
- import logging
- import os
def setup_logger(args):
    """Create (or reuse) the ``GAT-Training`` logger with file and console output.

    The log file is ``logs/training_<args.num_files>.log``; the ``logs``
    directory is created when missing. Calling this function repeatedly is
    safe: handlers already attached to the logger are not added again.

    Args:
        args: Object exposing a ``num_files`` attribute, which is embedded
            in the log file name.

    Returns:
        logging.Logger: The logger named ``GAT-Training``, set to INFO level.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('logs', exist_ok=True)

    logger = logging.getLogger('GAT-Training')
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # FileHandler stores its target as an absolute path in ``baseFilename``,
    # so compare against the absolute form of the requested path.
    log_path = os.path.abspath(f'logs/training_{args.num_files}.log')

    # logging.getLogger returns the same object on every call, so blindly
    # adding handlers would duplicate each log line on repeated calls.
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not has_file_handler:
        # Explicit UTF-8 keeps non-ASCII log messages intact regardless of
        # the platform's default locale encoding.
        file_handler = logging.FileHandler(log_path, encoding='utf-8')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so exclude it when looking for
    # an existing console handler.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console_handler:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    return logger
def main():
    """Run the full pipeline: preprocess, tune hyperparameters, train, test.

    Steps, in order:
      1. Parse arguments and set up logging.
      2. Preprocess data and build the adjacency matrix.
      3. Search for hyperparameters with ``RLOptimizer``.
      4. Train a final ``GAT`` model with the hyperparameters found.
      5. Run the trained model on the test loader.
    """
    # Arguments and logging.
    args = get_args()
    logger = setup_logger(args)
    logger.info(f"使用设备: {args.device}")

    # Data preprocessing; preprocess() hands back the loaders together
    # with a preprocessor object that is used from here on.
    train_loader, val_loader, test_loader, preprocessor = DataPreprocessor(
        args, logger
    ).preprocess()

    # Adjacency matrix built by the preprocessor.
    adj = preprocessor.create_adjacency_matrix()
    logger.info(f"邻接矩阵形状: {adj.shape}")

    # Step 1: hyperparameter search.
    searcher = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
    best_hparams = searcher.optimize()

    # Step 2: build and train the final model with the best hyperparameters.
    logger.info("\n使用最优超参数训练最终模型...")
    gat_kwargs = {
        'nfeat': 1,
        'nhid': best_hparams['hidden_dim'],
        'noutput': args.num_targets,
        'dropout': best_hparams['dropout'],
        'nheads': best_hparams['num_heads'],
        'alpha': 0.2,
    }
    final_model = GAT(**gat_kwargs).to(args.device)

    # Optimizer over the final model's parameters.
    optimizer = optim.Adam(
        final_model.parameters(),
        lr=best_hparams['lr'],
        weight_decay=args.weight_decay,
    )

    # Learning-rate scheduler wrapped around the optimizer.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True,
    )

    # Train the final model.
    trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
    trained_model = trainer.train(train_loader, val_loader, adj)

    # Step 3: run on the test loader.
    logger.info("\n在测试集上评估最终模型...")
    test_results = trainer.test(test_loader, adj)

    logger.info("所有任务完成!")


if __name__ == "__main__":
    main()
|