Ver código fonte

主运行文件

zhanghao 4 meses atrás
pai
commit
998767661d
1 arquivos alterados com 87 adições e 87 exclusões
  1. 87 87
      models/causal-inference/main.py

+ 87 - 87
models/causal-inference/main.py

@@ -1,88 +1,88 @@
-import torch.optim as optim
-from args import get_args
-from data_preprocessor import DataPreprocessor
-from gat import GAT
-from data_trainer import DataTrainer
-from rl_optimizer import RLOptimizer
-import logging
-import os
-
-def setup_logger(args):
-    """设置日志记录"""
-    if not os.path.exists('logs'):
-        os.makedirs('logs')
-    
-    logger = logging.getLogger('GAT-Training')
-    logger.setLevel(logging.INFO)
-    
-    # 文件处理器
-    file_handler = logging.FileHandler(f'logs/training_{args.num_files}.log')
-    file_handler.setLevel(logging.INFO)
-    
-    # 控制台处理器
-    console_handler = logging.StreamHandler()
-    console_handler.setLevel(logging.INFO)
-    
-    # 格式化器
-    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-    file_handler.setFormatter(formatter)
-    console_handler.setFormatter(formatter)
-    
-    logger.addHandler(file_handler)
-    logger.addHandler(console_handler)
-    
-    return logger
-
-def main():
-    # 获取参数
-    args = get_args()
-    logger = setup_logger(args)
-    logger.info(f"使用设备: {args.device}")
-    
-    # 数据预处理
-    preprocessor = DataPreprocessor(args, logger)
-    train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()
-    
-    # 创建有向图邻接矩阵
-    adj = preprocessor.create_adjacency_matrix()
-    logger.info(f"邻接矩阵形状: {adj.shape}")
-    
-    # 步骤1: 使用强化学习优化超参数
-    rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
-    best_hparams = rl_optimizer.optimize()
-    
-    # 步骤2: 使用最优超参数训练最终模型
-    logger.info("\n使用最优超参数训练最终模型...")
-    final_model = GAT(
-        nfeat=1,
-        nhid=best_hparams['hidden_dim'],
-        noutput=args.num_targets,
-        dropout=best_hparams['dropout'],
-        nheads=best_hparams['num_heads'],
-        alpha=0.2
-    ).to(args.device)
-    
-    # 配置优化器和学习率调度器
-    optimizer = optim.Adam(
-        final_model.parameters(),
-        lr=best_hparams['lr'],
-        weight_decay=args.weight_decay
-    )
-    
-    # 学习率调度器
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, mode='min', factor=0.5, patience=10, verbose=True
-    )
-    
-    # 训练最终模型
-    trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
-    trained_model = trainer.train(train_loader, val_loader, adj)
-    
-    # 步骤3: 在测试集上评估
-    logger.info("\n在测试集上评估最终模型...")
-    test_results = trainer.test(test_loader, adj)
-    
-    logger.info("所有任务完成!")
-
-if __name__ == "__main__":
+import torch.optim as optim
+from args import get_args
+from data_preprocessor import DataPreprocessor
+from gat import GAT
+from data_trainer import DataTrainer
+from rl_optimizer import RLOptimizer
+import logging
+import os
+
+def setup_logger(args):
+    """设置日志记录"""
+    if not os.path.exists('logs'):
+        os.makedirs('logs')
+    
+    logger = logging.getLogger('GAT-Training')
+    logger.setLevel(logging.INFO)
+    
+    # 文件处理器
+    file_handler = logging.FileHandler(f'logs/training_{args.num_files}.log')
+    file_handler.setLevel(logging.INFO)
+    
+    # 控制台处理器
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(logging.INFO)
+    
+    # 格式化器
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    file_handler.setFormatter(formatter)
+    console_handler.setFormatter(formatter)
+    
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+    
+    return logger
+
+def main():
+    # 获取参数
+    args = get_args()
+    logger = setup_logger(args)
+    logger.info(f"使用设备: {args.device}")
+    
+    # 数据预处理
+    preprocessor = DataPreprocessor(args, logger)
+    train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()
+    
+    # 创建有向图邻接矩阵
+    adj = preprocessor.create_adjacency_matrix()
+    logger.info(f"邻接矩阵形状: {adj.shape}")
+    
+    # 步骤1: 使用强化学习优化超参数
+    rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
+    best_hparams = rl_optimizer.optimize()
+    
+    # 步骤2: 使用最优超参数训练最终模型
+    logger.info("\n使用最优超参数训练最终模型...")
+    final_model = GAT(
+        nfeat=1,
+        nhid=best_hparams['hidden_dim'],
+        noutput=args.num_targets,
+        dropout=best_hparams['dropout'],
+        nheads=best_hparams['num_heads'],
+        alpha=0.2
+    ).to(args.device)
+    
+    # 配置优化器和学习率调度器
+    optimizer = optim.Adam(
+        final_model.parameters(),
+        lr=best_hparams['lr'],
+        weight_decay=args.weight_decay
+    )
+    
+    # 学习率调度器
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode='min', factor=0.5, patience=10, verbose=True
+    )
+    
+    # 训练最终模型
+    trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
+    trained_model = trainer.train(train_loader, val_loader, adj)
+    
+    # 步骤3: 在测试集上评估
+    logger.info("\n在测试集上评估最终模型...")
+    test_results = trainer.test(test_loader, adj)
+    
+    logger.info("所有任务完成!")
+
+if __name__ == "__main__":
     main()