main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. """
  2. 因果推理模型主程序(Causal Inference Main Program)
  3. 本程序实现了基于强化学习优化的图注意力网络训练流程,用于工业时间序列预测。
  4. 整个系统分为三个核心阶段:
  5. 1. 数据预处理阶段: 数据加载、清洗、降噪、归一化、图构建
  6. 2. RL超参数优化阶段: 使用PPO算法自动搜索最优超参数
  7. 3. 最终训练评估阶段: 使用最优参数训练模型并在测试集上评估
  8. 核心特点:
  9. - 自动化超参数优化: 无需手动调参,RL智能体自动寻找最优配置
  10. - 有向图注意力: 建模特征间的因果关系,支持非对称影响
  11. - 小波降噪预处理: 提升数据质量,增强模型精度
  12. - 完善的监控机制: 日志记录、早停、学习率调度、模型保存
  13. 技术栈:
  14. - PyTorch: 深度学习框架
  15. - Stable-Baselines3: 强化学习库(PPO算法)
  16. - PyWavelets: 小波变换库
  17. - Scikit-learn: 数据预处理
  18. 工作流程:
  19. main() → 数据预处理 → RL优化超参数 → 训练最终模型 → 测试评估
  20. """
  21. import torch.optim as optim
  22. from args import get_args
  23. from data_preprocessor import DataPreprocessor
  24. from gat import GAT
  25. from data_trainer import DataTrainer
  26. from rl_optimizer import RLOptimizer
  27. import logging
  28. import os
  29. def setup_logger(args):
  30. """
  31. 配置日志系统
  32. 功能:
  33. 创建并配置日志记录器,同时输出到控制台和文件。
  34. 日志文件以训练数据文件数量命名,便于区分不同实验。
  35. 参数:
  36. args: 命令行参数对象
  37. - args.num_files: 数据文件数量,用于日志文件命名
  38. 返回:
  39. logging.Logger: 配置好的日志记录器
  40. 日志级别:
  41. INFO: 记录关键步骤和指标信息
  42. 输出位置:
  43. - 控制台: 实时查看训练进度
  44. - 文件: logs/training_{num_files}.log,便于事后分析
  45. 日志格式:
  46. 时间戳 - 记录器名称 - 日志级别 - 消息内容
  47. 示例: 2025-01-10 10:30:45 - GAT-Training - INFO - 开始训练
  48. 技术要点:
  49. - 自动创建logs目录
  50. - 文件和控制台使用相同的格式化器
  51. - 避免重复添加处理器
  52. """
  53. # 创建日志目录(如果不存在)
  54. if not os.path.exists('logs'):
  55. os.makedirs('logs')
  56. # 创建日志记录器
  57. logger = logging.getLogger('GAT-Training')
  58. logger.setLevel(logging.INFO)
  59. # 文件处理器: 将日志写入文件
  60. file_handler = logging.FileHandler(f'logs/training_{args.num_files}.log')
  61. file_handler.setLevel(logging.INFO)
  62. # 控制台处理器: 将日志输出到终端
  63. console_handler = logging.StreamHandler()
  64. console_handler.setLevel(logging.INFO)
  65. # 格式化器: 定义日志消息的格式
  66. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  67. file_handler.setFormatter(formatter)
  68. console_handler.setFormatter(formatter)
  69. # 添加处理器到记录器
  70. logger.addHandler(file_handler) # 添加文件处理器
  71. logger.addHandler(console_handler) # 添加控制台处理器
  72. return logger
  73. def main():
  74. """
  75. 主程序入口
  76. 功能:
  77. 协调整个训练流程,包括数据预处理、RL优化、模型训练和测试评估。
  78. 这是整个系统的控制中心,按顺序执行各个阶段的任务。
  79. 执行流程:
  80. 第一阶段: 数据预处理
  81. 1. 加载50个CSV数据文件
  82. 2. 时间特征分解(年月日时分秒)
  83. 3. 小波降噪(db4小波,1层分解)
  84. 4. 数据归一化(StandardScaler)
  85. 5. 划分训练集/验证集/测试集(70%/10%/20%)
  86. 6. 构建有向图邻接矩阵(相关性阈值0.3)
  87. 第二阶段: RL超参数优化
  88. 1. 创建GATEnv强化学习环境
  89. 2. 使用PPO算法训练5000时间步
  90. 3. 搜索最优超参数(lr, hidden_dim, num_heads, dropout)
  91. 4. 快速评估策略(1-2个batch)加速收敛
  92. 5. 选择奖励最高的超参数组合
  93. 第三阶段: 最终模型训练
  94. 1. 使用最优超参数创建GAT模型
  95. 2. 配置Adam优化器和学习率调度器
  96. 3. 训练最多100轮,早停耐心20轮
  97. 4. 保存最佳模型和最终模型
  98. 5. 生成训练曲线图
  99. 第四阶段: 测试评估
  100. 1. 加载最佳模型
  101. 2. 在测试集上评估性能
  102. 3. 计算归一化和原始尺度的MSE/MAE/RMSE
  103. 4. 生成预测对比图
  104. 输出文件:
  105. 日志文件:
  106. - logs/training_{num_files}.log
  107. 归一化器:
  108. - scalers/features_scaler.joblib
  109. - scalers/targets_scaler.joblib
  110. 模型文件:
  111. - models/best_model.pth (验证损失最低的模型)
  112. - models/final_model.pth (训练完成后的最终模型)
  113. - gat_ppo_agent (RL优化器模型)
  114. 可视化图表:
  115. - plots/loss_curve.png (训练/验证损失曲线)
  116. - plots/mae_curve.png (训练/验证MAE曲线)
  117. - plots/prediction_examples.png (预测vs真实值对比)
  118. 关键技术:
  119. 1. RL自动调参: 避免手动网格搜索,智能寻优
  120. 2. 有向图建模: 捕捉特征间的因果关系
  121. 3. 小波降噪: 提升数据质量
  122. 4. 早停机制: 防止过拟合
  123. 5. 学习率调度: 自适应调整学习率
  124. 性能优化:
  125. - GPU加速: 自动检测并使用CUDA
  126. - 梯度裁剪: 防止梯度爆炸
  127. - Dropout正则化: 防止过拟合
  128. - ReduceLROnPlateau: 验证损失停滞时降低学习率
  129. 使用示例:
  130. >>> python main.py
  131. # 使用默认参数训练
  132. >>> python main.py --num_files 30 --epochs 50
  133. # 自定义参数训练
  134. """
  135. # ========== 阶段0: 初始化配置 ==========
  136. # 获取命令行参数(或使用默认值)
  137. args = get_args()
  138. # 配置日志系统
  139. logger = setup_logger(args)
  140. logger.info(f"使用设备: {args.device}")
  141. logger.info("=" * 80)
  142. logger.info("因果推理模型训练系统启动")
  143. logger.info("=" * 80)
  144. # ========== 阶段1: 数据预处理 ==========
  145. logger.info("\n" + "=" * 80)
  146. logger.info("阶段1: 数据预处理")
  147. logger.info("=" * 80)
  148. # 创建数据预处理器
  149. preprocessor = DataPreprocessor(args, logger)
  150. # 执行完整的预处理流程
  151. # 返回: train_loader(训练数据加载器), val_loader(验证数据加载器),
  152. # test_loader(测试数据加载器), preprocessor(预处理器对象)
  153. train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()
  154. logger.info("数据预处理完成!")
  155. # 创建有向图邻接矩阵
  156. # 基于特征相关性构建图结构,相关性>0.3的特征对之间建立有向边
  157. adj = preprocessor.create_adjacency_matrix()
  158. logger.info(f"邻接矩阵形状: {adj.shape}")
  159. logger.info(f"边的数量: {int(adj.sum())}")
  160. # ========== 阶段2: RL超参数优化 ==========
  161. logger.info("\n" + "=" * 80)
  162. logger.info("阶段2: 强化学习超参数优化")
  163. logger.info("=" * 80)
  164. logger.info("使用PPO算法搜索最优超参数...")
  165. # 创建RL优化器
  166. # 在环境中评估不同的超参数组合,找到使验证损失最小的配置
  167. rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
  168. # 执行优化,返回最优超参数字典
  169. # best_hparams包含: lr(学习率), hidden_dim(隐藏层维度),
  170. # num_heads(注意力头数), dropout(dropout率)
  171. best_hparams = rl_optimizer.optimize()
  172. logger.info(f"最优超参数: {best_hparams}")
  173. # ========== 阶段3: 使用最优超参数训练最终模型 ==========
  174. logger.info("\n" + "=" * 80)
  175. logger.info("阶段3: 训练最终模型")
  176. logger.info("=" * 80)
  177. logger.info("使用RL优化得到的最优超参数...")
  178. # 创建GAT模型,使用最优超参数
  179. final_model = GAT(
  180. nfeat=1, # 输入特征维度(每个节点1维)
  181. nhid=best_hparams['hidden_dim'], # 隐藏层维度(RL优化得到)
  182. noutput=args.num_targets, # 输出维度(47个目标变量)
  183. dropout=best_hparams['dropout'], # Dropout率(RL优化得到)
  184. nheads=best_hparams['num_heads'], # 注意力头数(RL优化得到)
  185. alpha=0.2 # LeakyReLU斜率(固定值)
  186. ).to(args.device) # 移动到GPU(如果可用)
  187. logger.info(f"模型结构: nfeat=1, nhid={best_hparams['hidden_dim']}, "
  188. f"noutput={args.num_targets}, dropout={best_hparams['dropout']}, "
  189. f"nheads={best_hparams['num_heads']}")
  190. # 配置优化器
  191. # Adam优化器: 自适应学习率,使用RL优化得到的学习率
  192. optimizer = optim.Adam(
  193. final_model.parameters(),
  194. lr=best_hparams['lr'], # 学习率(RL优化得到)
  195. weight_decay=args.weight_decay # L2正则化系数
  196. )
  197. logger.info(f"优化器: Adam(lr={best_hparams['lr']}, weight_decay={args.weight_decay})")
  198. # 配置学习率调度器
  199. # ReduceLROnPlateau: 当验证损失停滞时,将学习率降低一半
  200. scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  201. optimizer,
  202. mode='min', # 监控指标越小越好(损失函数)
  203. factor=0.5, # 降低因子(新lr = 旧lr * 0.5)
  204. patience=10, # 容忍10轮无改善
  205. verbose=True # 打印学习率变化信息
  206. )
  207. logger.info("学习率调度器: ReduceLROnPlateau(factor=0.5, patience=10)")
  208. # 创建训练器
  209. # 负责模型训练、验证、测试和可视化
  210. trainer = DataTrainer(
  211. model=final_model,
  212. args=args,
  213. preprocessor=preprocessor,
  214. optimizer=optimizer,
  215. scheduler=scheduler,
  216. logger=logger
  217. )
  218. # 执行训练
  219. # 训练最多100轮,使用早停机制(耐心20轮)
  220. # 自动保存最佳模型(验证损失最低)和最终模型
  221. logger.info("开始训练循环...")
  222. trained_model = trainer.train(train_loader, val_loader, adj)
  223. logger.info("模型训练完成!")
  224. # ========== 阶段4: 在测试集上评估 ==========
  225. logger.info("\n" + "=" * 80)
  226. logger.info("阶段4: 测试集评估")
  227. logger.info("=" * 80)
  228. logger.info("在测试集上评估最终模型性能...")
  229. # 测试模型性能
  230. # 返回归一化和原始尺度的MSE/MAE/RMSE指标
  231. test_results = trainer.test(test_loader, adj)
  232. # 打印最终结果摘要
  233. logger.info("\n" + "=" * 80)
  234. logger.info("训练完成总结")
  235. logger.info("=" * 80)
  236. logger.info(f"最优超参数: {best_hparams}")
  237. logger.info(f"测试集性能(归一化):")
  238. logger.info(f" - MSE: {test_results['normalized_mse']:.6f}")
  239. logger.info(f" - MAE: {test_results['normalized_mae']:.6f}")
  240. logger.info(f" - RMSE: {test_results['normalized_rmse']:.6f}")
  241. logger.info(f"测试集性能(原始尺度):")
  242. logger.info(f" - MSE: {test_results['original_mse']:.6f}")
  243. logger.info(f" - MAE: {test_results['original_mae']:.6f}")
  244. logger.info(f" - RMSE: {test_results['original_rmse']:.6f}")
  245. logger.info("=" * 80)
  246. logger.info("所有任务完成!")
  247. logger.info("=" * 80)
  248. if __name__ == "__main__":
  249. """
  250. 程序入口点
  251. 直接运行此文件时执行main()函数。
  252. 支持命令行参数自定义配置,详见args.py。
  253. 运行方式:
  254. python main.py # 使用默认参数
  255. python main.py --epochs 50 # 自定义训练轮数
  256. python main.py --num_files 30 # 自定义数据文件数量
  257. """
  258. main()