"""Entry point: tune GAT hyperparameters with RL, then train and test the final model."""
import logging
import os

import torch.optim as optim

from args import get_args
from data_preprocessor import DataPreprocessor
from data_trainer import DataTrainer
from gat import GAT
from rl_optimizer import RLOptimizer


def setup_logger(args) -> logging.Logger:
    """Configure and return the ``GAT-Training`` logger.

    INFO-and-above records are written both to the console and to
    ``logs/training_<args.num_files>.log``; the ``logs`` directory is
    created if it does not exist.

    ``logging.getLogger`` returns the same process-wide object for a given
    name, so handlers are attached only on the first call. Later calls return
    the already-configured logger unchanged, including its original log file,
    instead of duplicating every log line.

    Args:
        args: Parsed run configuration; only ``args.num_files`` is read here.

    Returns:
        The configured ``GAT-Training`` logger.
    """
    logger = logging.getLogger('GAT-Training')
    logger.setLevel(logging.INFO)
    if logger.handlers:
        # Already configured earlier in this process.
        return logger

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('logs', exist_ok=True)

    # Shared formatter for both handlers.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # File handler. UTF-8 is explicit because the log messages contain
    # non-ASCII (Chinese) text that the platform default encoding
    # (e.g. cp1252 on Windows) may be unable to encode.
    file_handler = logging.FileHandler(
        f'logs/training_{args.num_files}.log', encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console handler (StreamHandler defaults to stderr).
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    # Don't also pass records up to the root logger. If anything else
    # configures root handlers (e.g. logging.basicConfig), lines would
    # otherwise be printed twice.
    logger.propagate = False
    return logger


def main():
    """Run the pipeline: preprocess, RL hyperparameter search, final training, test."""
    args = get_args()
    logger = setup_logger(args)
    logger.info(f"使用设备: {args.device}")

    # Data preprocessing. preprocess() returns the three loaders plus a
    # preprocessor object that replaces the local one (presumably the same,
    # now-fitted instance — TODO confirm against DataPreprocessor).
    preprocessor = DataPreprocessor(args, logger)
    train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()

    # Build the adjacency matrix of the directed graph.
    adj = preprocessor.create_adjacency_matrix()
    logger.info(f"邻接矩阵形状: {adj.shape}")

    # Step 1: search hyperparameters with the reinforcement-learning optimizer.
    # best_hparams must provide 'hidden_dim', 'dropout', 'num_heads' and 'lr'
    # (the keys read below).
    rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
    best_hparams = rl_optimizer.optimize()

    # Step 2: train the final model with the best hyperparameters.
    logger.info("\n使用最优超参数训练最终模型...")
    final_model = GAT(
        nfeat=1,  # NOTE(review): one input feature per node is hard-coded — confirm vs. preprocessor
        nhid=best_hparams['hidden_dim'],
        noutput=args.num_targets,
        dropout=best_hparams['dropout'],
        nheads=best_hparams['num_heads'],
        alpha=0.2,
    ).to(args.device)

    # Optimizer: Adam with the tuned learning rate and configured weight decay.
    optimizer = optim.Adam(
        final_model.parameters(),
        lr=best_hparams['lr'],
        weight_decay=args.weight_decay,
    )

    # Learning-rate scheduler. It halves the LR after `patience=10` consecutive
    # scheduler.step(metric) calls without improvement of the minimized metric
    # (what metric is passed is decided by DataTrainer).
    # `verbose` is intentionally not passed: it has been deprecated since
    # PyTorch 2.2 and newer releases may reject it. Inspect
    # optimizer.param_groups[i]['lr'] to observe LR changes instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10
    )

    # Train the final model.
    trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
    trainer.train(train_loader, val_loader, adj)

    # Step 3: evaluate on the held-out test set.
    logger.info("\n在测试集上评估最终模型...")
    trainer.test(test_loader, adj)

    logger.info("所有任务完成!")


if __name__ == "__main__":
    main()