|
|
@@ -1,88 +1,88 @@
|
|
|
-import torch.optim as optim
|
|
|
-from args import get_args
|
|
|
-from data_preprocessor import DataPreprocessor
|
|
|
-from gat import GAT
|
|
|
-from data_trainer import DataTrainer
|
|
|
-from rl_optimizer import RLOptimizer
|
|
|
-import logging
|
|
|
-import os
|
|
|
-
|
|
|
-def setup_logger(args):
|
|
|
- """设置日志记录"""
|
|
|
- if not os.path.exists('logs'):
|
|
|
- os.makedirs('logs')
|
|
|
-
|
|
|
- logger = logging.getLogger('GAT-Training')
|
|
|
- logger.setLevel(logging.INFO)
|
|
|
-
|
|
|
- # 文件处理器
|
|
|
- file_handler = logging.FileHandler(f'logs/training_{args.num_files}.log')
|
|
|
- file_handler.setLevel(logging.INFO)
|
|
|
-
|
|
|
- # 控制台处理器
|
|
|
- console_handler = logging.StreamHandler()
|
|
|
- console_handler.setLevel(logging.INFO)
|
|
|
-
|
|
|
- # 格式化器
|
|
|
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
- file_handler.setFormatter(formatter)
|
|
|
- console_handler.setFormatter(formatter)
|
|
|
-
|
|
|
- logger.addHandler(file_handler)
|
|
|
- logger.addHandler(console_handler)
|
|
|
-
|
|
|
- return logger
|
|
|
-
|
|
|
-def main():
|
|
|
- # 获取参数
|
|
|
- args = get_args()
|
|
|
- logger = setup_logger(args)
|
|
|
- logger.info(f"使用设备: {args.device}")
|
|
|
-
|
|
|
- # 数据预处理
|
|
|
- preprocessor = DataPreprocessor(args, logger)
|
|
|
- train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()
|
|
|
-
|
|
|
- # 创建有向图邻接矩阵
|
|
|
- adj = preprocessor.create_adjacency_matrix()
|
|
|
- logger.info(f"邻接矩阵形状: {adj.shape}")
|
|
|
-
|
|
|
- # 步骤1: 使用强化学习优化超参数
|
|
|
- rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
|
|
|
- best_hparams = rl_optimizer.optimize()
|
|
|
-
|
|
|
- # 步骤2: 使用最优超参数训练最终模型
|
|
|
- logger.info("\n使用最优超参数训练最终模型...")
|
|
|
- final_model = GAT(
|
|
|
- nfeat=1,
|
|
|
- nhid=best_hparams['hidden_dim'],
|
|
|
- noutput=args.num_targets,
|
|
|
- dropout=best_hparams['dropout'],
|
|
|
- nheads=best_hparams['num_heads'],
|
|
|
- alpha=0.2
|
|
|
- ).to(args.device)
|
|
|
-
|
|
|
- # 配置优化器和学习率调度器
|
|
|
- optimizer = optim.Adam(
|
|
|
- final_model.parameters(),
|
|
|
- lr=best_hparams['lr'],
|
|
|
- weight_decay=args.weight_decay
|
|
|
- )
|
|
|
-
|
|
|
- # 学习率调度器
|
|
|
- scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
|
|
- optimizer, mode='min', factor=0.5, patience=10, verbose=True
|
|
|
- )
|
|
|
-
|
|
|
- # 训练最终模型
|
|
|
- trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
|
|
|
- trained_model = trainer.train(train_loader, val_loader, adj)
|
|
|
-
|
|
|
- # 步骤3: 在测试集上评估
|
|
|
- logger.info("\n在测试集上评估最终模型...")
|
|
|
- test_results = trainer.test(test_loader, adj)
|
|
|
-
|
|
|
- logger.info("所有任务完成!")
|
|
|
-
|
|
|
-if __name__ == "__main__":
|
|
|
+import torch.optim as optim
|
|
|
+from args import get_args
|
|
|
+from data_preprocessor import DataPreprocessor
|
|
|
+from gat import GAT
|
|
|
+from data_trainer import DataTrainer
|
|
|
+from rl_optimizer import RLOptimizer
|
|
|
+import logging
|
|
|
+import os
|
|
|
+
|
|
|
def setup_logger(args):
    """Configure and return the 'GAT-Training' logger.

    Emits INFO-level records both to the console and to
    ``logs/training_{args.num_files}.log`` (the ``logs`` directory is
    created on demand).

    Args:
        args: Parsed CLI arguments; only ``args.num_files`` is read here,
            to name the per-run log file.

    Returns:
        logging.Logger: The configured logger instance.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('logs', exist_ok=True)

    logger = logging.getLogger('GAT-Training')
    logger.setLevel(logging.INFO)

    # getLogger returns the same object on repeated calls; re-adding
    # handlers would duplicate every log line, so bail out if already set up.
    if logger.handlers:
        return logger

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # File handler: persists this run's log under logs/.
    file_handler = logging.FileHandler(f'logs/training_{args.num_files}.log')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Console handler: mirrors the same records to the terminal.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger
|
|
|
+
|
|
|
def main():
    """End-to-end pipeline: preprocess data, tune hyper-parameters via RL,
    train the final GAT model, and evaluate it on the test set."""
    # Parse CLI arguments and configure logging.
    args = get_args()
    logger = setup_logger(args)
    logger.info(f"使用设备: {args.device}")

    # Data preprocessing. Note preprocess() returns the (possibly updated)
    # preprocessor alongside the three data loaders, so we rebind it.
    preprocessor = DataPreprocessor(args, logger)
    train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()

    # Build the directed-graph adjacency matrix shared by every model below.
    adj = preprocessor.create_adjacency_matrix()
    logger.info(f"邻接矩阵形状: {adj.shape}")

    # Step 1: hyper-parameter search via reinforcement learning.
    rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
    best_hparams = rl_optimizer.optimize()

    # Step 2: train the final model with the best hyper-parameters found.
    logger.info("\n使用最优超参数训练最终模型...")
    final_model = GAT(
        nfeat=1,  # one scalar input feature per node
        nhid=best_hparams['hidden_dim'],
        noutput=args.num_targets,
        dropout=best_hparams['dropout'],
        nheads=best_hparams['num_heads'],
        alpha=0.2,  # presumably the LeakyReLU negative slope — confirm in gat.GAT
    ).to(args.device)

    # Optimizer for the final model.
    optimizer = optim.Adam(
        final_model.parameters(),
        lr=best_hparams['lr'],
        weight_decay=args.weight_decay,
    )

    # Halve the LR after 10 epochs without validation-loss improvement.
    # The deprecated `verbose` kwarg was dropped (removed in recent PyTorch).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10
    )

    # Train the final model.
    trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
    trainer.train(train_loader, val_loader, adj)

    # Step 3: evaluate on the held-out test set.
    logger.info("\n在测试集上评估最终模型...")
    trainer.test(test_loader, adj)

    logger.info("所有任务完成!")
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
main()
|