DQN_train.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. """
  2. DQN 强化学习训练模块
  3. ======================
  4. 本模块实现基于 Stable-Baselines3 的 DQN 强化学习训练流程,包括:
  5. 1. DQNParams: DQN超参数配置类
  6. 2. UFEpisodeRecorder: Episode数据记录器
  7. 3. UFTrainingCallback: 训练回调器
  8. 4. DQNTrainer: DQN训练器封装
  9. 5. train_uf_rl_agent: 主训练函数
  10. DQN算法简介:
  11. - Deep Q-Network(深度Q网络)
  12. - 基于价值的强化学习算法
  13. - 使用经验回放和目标网络稳定训练
  14. - 适用于离散动作空间
  15. 训练流程:
  16. 1. 初始化环境和DQN智能体
  17. 2. 收集经验(exploration)
  18. 3. 从经验池采样训练(exploitation)
  19. 4. 周期性更新目标网络
  20. 5. 记录训练指标到TensorBoard
  21. """
  22. import os
  23. import time
  24. import random
  25. import numpy as np
  26. import torch
  27. from stable_baselines3 import DQN
  28. from stable_baselines3.common.monitor import Monitor
  29. from stable_baselines3.common.vec_env import DummyVecEnv
  30. from stable_baselines3.common.callbacks import BaseCallback
  31. from DQN_env import UFParams, UFSuperCycleEnv
  32. # ==================== DQN超参数配置类 ====================
  33. class DQNParams:
  34. """
  35. DQN 超参数配置类
  36. 功能:统一管理DQN算法的所有超参数
  37. 超参数说明:
  38. - learning_rate: 神经网络学习率,控制梯度下降的步长
  39. - buffer_size: 经验回放缓冲区大小,存储历史经验
  40. - learning_starts: 开始训练前先收集的经验数量(warm-up)
  41. - batch_size: 每次训练采样的batch大小
  42. - gamma: 折扣因子,权衡即时奖励和长期奖励
  43. - train_freq: 训练频率,每隔多少步训练一次
  44. - target_update_interval: 目标网络更新频率
  45. - tau: 软更新系数(soft update)
  46. - exploration_*: ε-贪心策略的探索率参数
  47. """
  48. # ========== 神经网络参数 ==========
  49. learning_rate: float = 1e-4
  50. # 学习率,控制神经网络权重更新的步长
  51. # 典型范围:1e-5 ~ 1e-3
  52. # 过大:训练不稳定;过小:收敛慢
  53. # ========== 经验回放参数 ==========
  54. buffer_size: int = 100000
  55. # 经验回放缓冲区大小(可存储的transition数量)
  56. # 作用:打破样本间的时间相关性,提高训练稳定性
  57. # 建议:至少存储几个完整episode的经验
  58. learning_starts: int = 10000
  59. # 开始训练前先收集的步数(预填充缓冲区)
  60. # 作用:确保缓冲区有足够的多样性样本再开始训练
  61. # 建议:设为buffer_size的10%-20%
  62. batch_size: int = 32
  63. # 每次训练从缓冲区采样的样本数量
  64. # 典型值:32, 64, 128, 256
  65. # 过大:显存占用高,训练慢;过小:梯度估计不准确
  66. # ========== 强化学习参数 ==========
  67. gamma: float = 0.95
  68. # 折扣因子(discount factor),γ ∈ [0, 1]
  69. # 作用:权衡即时奖励和长期奖励
  70. # γ=0:只考虑当前奖励(短视)
  71. # γ=1:完全考虑未来奖励(长视)
  72. # 通常设为0.9-0.99
  73. train_freq: int = 4
  74. # 训练频率:每收集多少步执行一次训练
  75. # 作用:平衡数据收集和网络更新
  76. # 典型值:1(每步训练)或4-16(批量训练)
  77. # ========== 目标网络参数 ==========
  78. target_update_interval: int = 1
  79. # 目标网络更新间隔(硬更新)
  80. # 作用:目标网络每隔多少次训练更新一次
  81. # 注:使用软更新(tau)时此参数通常设为1
  82. tau: float = 0.005
  83. # 软更新系数(soft update)
  84. # θ_target = τ×θ + (1-τ)×θ_target
  85. # τ=1:硬更新(完全复制)
  86. # τ<<1:软更新(平滑过渡,更稳定)
  87. # 典型值:0.001 - 0.01
  88. # ========== 探索策略参数(ε-greedy) ==========
  89. exploration_initial_eps: float = 1.0
  90. # 初始探索率 ε_0
  91. # ε=1:完全随机探索
  92. # ε=0:完全利用已学知识
  93. exploration_fraction: float = 0.3
  94. # 探索率衰减比例
  95. # 表示训练总步数的前30%进行ε衰减
  96. # 例:总共10万步,前3万步ε从1.0衰减到0.02
  97. exploration_final_eps: float = 0.02
  98. # 最终探索率 ε_final
  99. # 衰减结束后保持此值(保留小概率探索)
  100. # 典型值:0.01 - 0.05
  101. # ========== 日志参数 ==========
  102. remark: str = "default"
  103. # 实验备注,用于区分不同训练实验
  104. # 会自动添加到TensorBoard日志目录名中
  105. # ==================== Episode数据记录器 ====================
  106. class UFEpisodeRecorder:
  107. """
  108. Episode数据记录器
  109. 功能:
  110. - 记录训练过程中每个episode的详细数据
  111. - 存储每步的状态、动作、奖励、info等信息
  112. - 计算episode级别的统计指标
  113. 用途:
  114. - 训练监控:实时查看智能体表现
  115. - 调试分析:定位问题episode
  116. - 数据分析:评估策略改进效果
  117. """
  118. def __init__(self):
  119. """初始化记录器"""
  120. self.episode_data = [] # 存储所有完成的episode数据
  121. self.current_episode = [] # 当前正在进行的episode数据
  122. def record_step(self, obs, action, reward, done, info):
  123. """
  124. 记录单步交互数据
  125. 参数:
  126. obs: 当前状态观测
  127. action: 执行的动作
  128. reward: 获得的奖励
  129. done: 是否结束
  130. info: 额外信息字典
  131. """
  132. # 构建单步数据字典
  133. step_data = {
  134. "obs": obs.copy(), # 状态(深拷贝避免引用问题)
  135. "action": action.copy(), # 动作
  136. "reward": reward, # 奖励
  137. "done": done, # 是否终止
  138. "info": info.copy() if info else {} # 环境信息
  139. }
  140. # 添加到当前episode
  141. self.current_episode.append(step_data)
  142. # 如果episode结束,保存并重置
  143. if done:
  144. self.episode_data.append(self.current_episode)
  145. self.current_episode = []
  146. def get_episode_stats(self, episode_idx=-1):
  147. """
  148. 获取指定episode的统计信息
  149. 参数:
  150. episode_idx (int): episode索引,默认-1(最后一个)
  151. 返回:
  152. dict: 包含以下统计指标的字典
  153. - total_reward: 总奖励
  154. - avg_recovery: 平均回收率
  155. - feasible_steps: 可行步数
  156. - total_steps: 总步数
  157. """
  158. if not self.episode_data:
  159. return {}
  160. episode = self.episode_data[episode_idx]
  161. # 计算总奖励
  162. total_reward = sum(step["reward"] for step in episode)
  163. # 计算平均回收率(从info中提取)
  164. recovery_values = [
  165. step["info"].get("recovery", 0)
  166. for step in episode
  167. if "recovery" in step["info"]
  168. ]
  169. avg_recovery = np.mean(recovery_values) if recovery_values else 0.0
  170. # 计算可行步数(成功的超级周期数)
  171. feasible_steps = sum(
  172. 1 for step in episode
  173. if step["info"].get("feasible", False)
  174. )
  175. return {
  176. "total_reward": total_reward,
  177. "avg_recovery": avg_recovery,
  178. "feasible_steps": feasible_steps,
  179. "total_steps": len(episode)
  180. }
  181. # ==================== 训练回调器 ====================
  182. class UFTrainingCallback(BaseCallback):
  183. """
  184. 自定义训练回调器
  185. 功能:
  186. - 在每个训练步骤调用,记录数据到recorder
  187. - 兼容Stable-Baselines3的回调机制
  188. - 不依赖环境内部属性,使用标准接口获取数据
  189. 回调时机:
  190. - _on_step(): 每执行一步环境交互后调用
  191. 设计特点:
  192. 1. 从self.locals获取当前步的数据(SB3提供的接口)
  193. 2. 处理向量化环境(DummyVecEnv)的数据格式
  194. 3. 自动检测episode结束并触发记录
  195. """
  196. def __init__(self, recorder, verbose=0):
  197. """
  198. 初始化回调器
  199. 参数:
  200. recorder (UFEpisodeRecorder): 数据记录器实例
  201. verbose (int): 日志详细程度,0=关闭,1=打印每步信息
  202. """
  203. super(UFTrainingCallback, self).__init__(verbose)
  204. self.recorder = recorder
  205. def _on_step(self) -> bool:
  206. """
  207. 每步回调函数(Stable-Baselines3标准接口)
  208. 返回:
  209. bool: True表示继续训练,False表示提前终止
  210. """
  211. try:
  212. # 从SB3的self.locals获取当前步数据
  213. new_obs = self.locals.get("new_obs") # 新状态
  214. actions = self.locals.get("actions") # 执行的动作
  215. rewards = self.locals.get("rewards") # 获得的奖励
  216. dones = self.locals.get("dones") # 是否结束
  217. infos = self.locals.get("infos") # 环境信息
  218. # 处理向量化环境(取第一个环境的数据)
  219. if len(new_obs) > 0:
  220. step_obs = new_obs[0]
  221. step_action = actions[0] if actions is not None else None
  222. step_reward = rewards[0] if rewards is not None else 0.0
  223. step_done = dones[0] if dones is not None else False
  224. step_info = infos[0] if infos is not None else {}
  225. # 可选:打印当前步信息(用于调试)
  226. if self.verbose:
  227. print(f"[Step {self.num_timesteps}] "
  228. f"动作={step_action}, "
  229. f"奖励={step_reward:.3f}, "
  230. f"Done={step_done}")
  231. # 记录数据到recorder
  232. self.recorder.record_step(
  233. obs=step_obs,
  234. action=step_action,
  235. reward=step_reward,
  236. done=step_done,
  237. info=step_info,
  238. )
  239. except Exception as e:
  240. # 异常处理:避免回调错误中断训练
  241. if self.verbose:
  242. print(f"[Callback Error] {e}")
  243. # 返回True继续训练
  244. return True
  245. # ==================== DQN训练器封装类 ====================
  246. class DQNTrainer:
  247. def __init__(self, env, params, callback=None):
  248. """
  249. 初始化训练器
  250. 参数:
  251. env: Gymnasium环境
  252. params: DQN 超参数配置
  253. callback: 可选训练回调
  254. """
  255. self.env = env
  256. self.params = params
  257. self.callback = callback
  258. self.log_dir = self._create_log_dir() # 创建日志目录
  259. self.model = self._create_model() # 创建 DQN 模型
  260. def _create_log_dir(self):
  261. """
  262. 创建 TensorBoard 日志目录,保证 Windows 下路径安全
  263. 返回:
  264. str: 可用的日志目录路径
  265. """
  266. timestamp = time.strftime("%Y%m%d-%H%M%S")
  267. # 用整数代替浮点数,避免路径中包含小数点
  268. lr_int = int(self.params.learning_rate * 1e4)
  269. gamma_int = int(self.params.gamma * 100)
  270. exp_int = int(self.params.exploration_fraction * 100)
  271. # 生成目录名
  272. log_name = f"DQN_lr{lr_int}_buf{self.params.buffer_size}_bs{self.params.batch_size}_gamma{gamma_int}_exp{exp_int}_{self.params.remark}_{timestamp}"
  273. # 使用短路径,避免 Windows 路径过长
  274. base_dir = r"E:\Greentech\models\uf-rl\uf_dqn_tensorboard"
  275. os.makedirs(base_dir, exist_ok=True)
  276. log_dir = os.path.join(base_dir, log_name)
  277. # 尝试创建目录,防止偶发锁或占用
  278. attempt = 0
  279. while attempt < 5:
  280. try:
  281. os.makedirs(log_dir, exist_ok=True)
  282. if not os.path.isdir(log_dir):
  283. raise RuntimeError(f"{log_dir} 已存在但不是目录!")
  284. break
  285. except Exception as e:
  286. attempt += 1
  287. time.sleep(0.1)
  288. log_dir += f"_{attempt}"
  289. else:
  290. raise RuntimeError(f"无法创建日志目录: {log_dir}")
  291. return log_dir
  292. def _create_model(self):
  293. """
  294. 创建 Stable-Baselines3 DQN 模型
  295. """
  296. model = DQN(
  297. policy="MlpPolicy",
  298. env=self.env,
  299. learning_rate=self.params.learning_rate,
  300. buffer_size=self.params.buffer_size,
  301. learning_starts=self.params.learning_starts,
  302. batch_size=self.params.batch_size,
  303. gamma=self.params.gamma,
  304. train_freq=self.params.train_freq,
  305. target_update_interval=1,
  306. tau=0.005,
  307. exploration_initial_eps=self.params.exploration_initial_eps,
  308. exploration_fraction=self.params.exploration_fraction,
  309. exploration_final_eps=self.params.exploration_final_eps,
  310. verbose=1,
  311. tensorboard_log=self.log_dir
  312. )
  313. return model
  314. def train(self, total_timesteps: int):
  315. """
  316. 执行训练
  317. 参数:
  318. total_timesteps (int): 总训练步数
  319. 注:对于超滤环境,每步代表一个超级周期(约2-3天)
  320. 150000步 ≈ 10000个episode ≈ 10000个超级周期 ≈ 约54年
  321. """
  322. if self.callback:
  323. # 使用回调器训练
  324. self.model.learn(total_timesteps=total_timesteps, callback=self.callback)
  325. else:
  326. # 不使用回调器训练
  327. self.model.learn(total_timesteps=total_timesteps)
  328. print(f"✅ 模型训练完成!")
  329. print(f"📊 日志保存在:{self.log_dir}")
  330. print(f"💡 使用以下命令查看TensorBoard:")
  331. print(f" tensorboard --logdir={self.log_dir}")
  332. def save(self, path=None):
  333. """
  334. 保存模型
  335. 参数:
  336. path (str, optional): 保存路径,默认保存到日志目录下的dqn_model.zip
  337. """
  338. if path is None:
  339. path = os.path.join(self.log_dir, "dqn_model.zip")
  340. self.model.save(path)
  341. print(f"💾 模型已保存到:{path}")
  342. def load(self, path):
  343. """
  344. 加载模型
  345. 参数:
  346. path (str): 模型文件路径(.zip文件)
  347. """
  348. self.model = DQN.load(path, env=self.env)
  349. print(f"📥 模型已从 {path} 加载")
  350. # ==================== 辅助函数:随机种子设置 ====================
  351. def set_global_seed(seed: int):
  352. """
  353. 固定全局随机种子,保证训练可复现
  354. 参数:
  355. seed (int): 随机种子
  356. 作用:
  357. - 固定Python、NumPy、PyTorch的随机数生成器
  358. - 确保相同种子产生相同的训练结果
  359. - 便于实验对比和问题复现
  360. 注意:
  361. - 即使固定种子,多线程/多进程仍可能产生微小差异
  362. - GPU运算的非确定性也可能影响复现性
  363. """
  364. random.seed(seed) # Python随机数
  365. np.random.seed(seed) # NumPy随机数
  366. torch.manual_seed(seed) # PyTorch CPU随机数
  367. torch.cuda.manual_seed_all(seed) # PyTorch GPU随机数
  368. # 设置PyTorch为确定性模式(可能影响性能)
  369. torch.backends.cudnn.deterministic = True
  370. torch.backends.cudnn.benchmark = False
  371. # ==================== 主训练函数 ====================
  372. def train_uf_rl_agent(params: UFParams, total_timesteps: int = 10000, seed: int = 2025):
  373. """
  374. 超滤强化学习智能体训练主函数
  375. 参数:
  376. params (UFParams): 超滤环境参数
  377. total_timesteps (int): 总训练步数,默认10000
  378. seed (int): 随机种子,默认2025
  379. 返回:
  380. DQN: 训练好的DQN模型
  381. 训练流程:
  382. 1. 固定随机种子(确保可复现)
  383. 2. 创建记录器和回调器
  384. 3. 创建并包装环境(Monitor + DummyVecEnv)
  385. 4. 初始化DQN训练器
  386. 5. 执行训练
  387. 6. 保存模型
  388. 7. 输出统计信息
  389. """
  390. # 步骤1:固定随机种子
  391. set_global_seed(seed)
  392. print(f"🎲 随机种子已设置为: {seed}")
  393. # 步骤2:创建数据记录器和回调器
  394. recorder = UFEpisodeRecorder()
  395. callback = UFTrainingCallback(recorder, verbose=1)
  396. # 步骤3:创建环境(使用闭包和向量化)
  397. def make_env():
  398. """环境工厂函数"""
  399. env = UFSuperCycleEnv(params) # 创建超滤环境
  400. env = Monitor(env) # 包装Monitor(记录episode统计)
  401. return env
  402. # 向量化环境(即使只有一个环境,也需要向量化以兼容SB3)
  403. env = DummyVecEnv([make_env])
  404. # 步骤4:创建DQN训练器
  405. dqn_params = DQNParams()
  406. trainer = DQNTrainer(env, dqn_params, callback=callback)
  407. # 步骤5:执行训练
  408. trainer.train(total_timesteps)
  409. # 步骤6:保存模型
  410. trainer.save()
  411. # 步骤7:输出最终统计信息
  412. stats = callback.recorder.get_episode_stats()
  413. print("\n" + "="*60)
  414. print("📈 训练统计")
  415. print("="*60)
  416. print(f"总奖励: {stats.get('total_reward', 0):.2f}")
  417. print(f"平均回收率: {stats.get('avg_recovery', 0):.3f}")
  418. print(f"可行步数: {stats.get('feasible_steps', 0)}")
  419. print(f"总步数: {stats.get('total_steps', 0)}")
  420. print("="*60)
  421. return trainer.model
  422. # ==================== 主程序入口 ====================
  423. if __name__ == "__main__":
  424. """
  425. 训练脚本入口
  426. 使用方法:
  427. python fixed_DQN_train.py
  428. 训练参数:
  429. - total_timesteps=150000: 总训练步数
  430. - 约10000个episode(每个episode最多15步)
  431. - 约需训练数小时至数天(取决于硬件)
  432. """
  433. print("="*60)
  434. print("🚀 开始训练超滤强化学习智能体")
  435. print("="*60)
  436. # 初始化超滤参数
  437. params = UFParams()
  438. # 执行训练
  439. train_uf_rl_agent(params, total_timesteps=300000)
  440. print("\n🎉 训练流程全部完成!")