DQN_train.py

import os
import time
import random
from dataclasses import dataclass

import numpy as np
import torch
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback

from DQN_env import UFParams, UFSuperCycleEnv

# ==== Reinforcement-learning hyperparameters ====
@dataclass
class DQNParams:
    """
    DQN hyperparameter container.

    Centralizes all model-training parameters so that experiments only
    need to override the relevant fields.
    """
    # Learning rate: step size of neural-network updates
    learning_rate: float = 1e-4
    # Replay-buffer capacity (in steps)
    buffer_size: int = 10000
    # Number of steps collected before learning starts
    learning_starts: int = 200
    # Number of samples drawn from the replay buffer per update
    batch_size: int = 32
    # Discount factor; closer to 1 weights long-term reward more heavily
    gamma: float = 0.95
    # Train once every this many environment steps
    train_freq: int = 4
    # Target-network update interval (in steps)
    target_update_interval: int = 2000
    # Initial exploration rate (epsilon)
    exploration_initial_eps: float = 1.0
    # Fraction of training over which epsilon decays from initial to final
    exploration_fraction: float = 0.3
    # Final exploration rate (epsilon)
    exploration_final_eps: float = 0.02
    # Log remark (used to distinguish experiments)
    remark: str = "default"
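
# Illustrative only (not part of the original pipeline): with the dataclass
# above, per-experiment overrides become keyword arguments. The values below
# are arbitrary examples, not tuned settings, e.g.:
#   sweep_params = DQNParams(learning_rate=5e-4, batch_size=64, remark="lr_sweep")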

class UFEpisodeRecorder:
    """Records per-step decisions and outcomes within each episode."""

    def __init__(self):
        self.episode_data = []
        self.current_episode = []

    def record_step(self, obs, action, reward, done, info):
        """Record a single transition."""
        step_data = {
            "obs": obs.copy(),
            "action": None if action is None else np.copy(action),
            "reward": reward,
            "done": done,
            "info": info.copy() if info else {},
        }
        self.current_episode.append(step_data)
        if done:
            self.episode_data.append(self.current_episode)
            self.current_episode = []

    def get_episode_stats(self, episode_idx=-1):
        """Return summary statistics for one episode (default: the latest)."""
        if not self.episode_data:
            return {}
        episode = self.episode_data[episode_idx]
        total_reward = sum(step["reward"] for step in episode)
        recoveries = [step["info"]["recovery"] for step in episode if "recovery" in step["info"]]
        avg_recovery = float(np.mean(recoveries)) if recoveries else 0.0
        feasible_steps = sum(1 for step in episode if step["info"].get("feasible", False))
        return {
            "total_reward": total_reward,
            "avg_recovery": avg_recovery,
            "feasible_steps": feasible_steps,
            "total_steps": len(episode),
        }
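
# Illustrative sketch (not part of the original pipeline): exercising the
# recorder with synthetic transitions. The obs/info values below are made up
# purely to demonstrate the bookkeeping and the stats dict it produces.
def _recorder_smoke_test():
    rec = UFEpisodeRecorder()
    for t in range(3):
        rec.record_step(
            obs=np.zeros(4),
            action=np.array(0),
            reward=1.0,
            done=(t == 2),  # terminate on the third step
            info={"recovery": 0.9, "feasible": True},
        )
    # Expected: total_reward=3.0, avg_recovery≈0.9, feasible_steps=3, total_steps=3
    print(rec.get_episode_stats())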

# ==== RL training callback ====
class UFTrainingCallback(BaseCallback):
    """
    Training callback that records every step into the recorder.
    1. Does not rely on the environment's internal last_* attributes
    2. Uses the obs, actions, rewards, dones, and infos exposed by the
       vectorized-env interface
    3. Episode-end statistics are handled automatically by the recorder
    """

    def __init__(self, recorder, verbose=0):
        super().__init__(verbose)
        self.recorder = recorder

    def _on_step(self) -> bool:
        try:
            new_obs = self.locals.get("new_obs")
            actions = self.locals.get("actions")
            rewards = self.locals.get("rewards")
            dones = self.locals.get("dones")
            infos = self.locals.get("infos")
            if new_obs is not None and len(new_obs) > 0:
                step_obs = new_obs[0]
                step_action = actions[0] if actions is not None else None
                step_reward = rewards[0] if rewards is not None else 0.0
                step_done = dones[0] if dones is not None else False
                step_info = infos[0] if infos is not None else {}
                # Print the current step's information
                if self.verbose:
                    print(f"[Step {self.num_timesteps}] action={step_action}, reward={step_reward:.3f}, done={step_done}")
                # Record the transition
                self.recorder.record_step(
                    obs=step_obs,
                    action=step_action,
                    reward=step_reward,
                    done=step_done,
                    info=step_info,
                )
        except Exception as e:
            if self.verbose:
                print(f"[Callback Error] {e}")
        return True

class DQNTrainer:
    def __init__(self, env, params, callback=None):
        self.env = env
        self.params = params
        self.callback = callback
        self.log_dir = self._create_log_dir()
        self.model = self._create_model()

    def _create_log_dir(self):
        # Create a timestamped training-log directory
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        log_name = (
            f"DQN_lr{self.params.learning_rate}_buf{self.params.buffer_size}_bs{self.params.batch_size}"
            f"_gamma{self.params.gamma}_exp{self.params.exploration_fraction}"
            f"_{self.params.remark}_{timestamp}"
        )
        log_dir = os.path.join("./uf_dqn_tensorboard", log_name)
        os.makedirs(log_dir, exist_ok=True)
        return log_dir

    def _create_model(self):
        return DQN(
            policy="MlpPolicy",
            env=self.env,
            learning_rate=self.params.learning_rate,
            buffer_size=self.params.buffer_size,
            learning_starts=self.params.learning_starts,
            batch_size=self.params.batch_size,
            gamma=self.params.gamma,
            train_freq=self.params.train_freq,
            target_update_interval=self.params.target_update_interval,
            exploration_initial_eps=self.params.exploration_initial_eps,
            exploration_fraction=self.params.exploration_fraction,
            exploration_final_eps=self.params.exploration_final_eps,
            verbose=1,
            tensorboard_log=self.log_dir,
        )

    def train(self, total_timesteps: int):
        if self.callback:
            self.model.learn(total_timesteps=total_timesteps, callback=self.callback)
        else:
            self.model.learn(total_timesteps=total_timesteps)
        print(f"Training finished; logs saved to: {self.log_dir}")

    def save(self, path=None):
        if path is None:
            path = os.path.join(self.log_dir, "dqn_model.zip")
        self.model.save(path)
        print(f"Model saved to: {path}")

    def load(self, path):
        self.model = DQN.load(path, env=self.env)
        print(f"Model loaded from {path}")
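
# Illustrative sketch (not part of the original script): restoring a saved
# checkpoint for a greedy rollout. Assumes UFSuperCycleEnv follows the
# Gymnasium API (reset() -> (obs, info), step() -> 5-tuple); the function
# name and step count are hypothetical placeholders.
def run_trained_policy(model_path: str, params: UFParams, n_steps: int = 50):
    env = UFSuperCycleEnv(params)
    model = DQN.load(model_path)
    obs, _ = env.reset()
    for _ in range(n_steps):
        action, _ = model.predict(obs, deterministic=True)  # greedy action
        obs, reward, terminated, truncated, info = env.step(int(action))  # DQN actions are discrete
        if terminated or truncated:
            obs, _ = env.reset()
    env.close()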

def set_global_seed(seed: int):
    """Fix all global random seeds so training runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # only relevant when a GPU is used
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_uf_rl_agent(params: UFParams, total_timesteps: int = 10000, seed: int = 2025):
    set_global_seed(seed)
    recorder = UFEpisodeRecorder()
    callback = UFTrainingCallback(recorder, verbose=1)

    def make_env():
        env = UFSuperCycleEnv(params)
        env = Monitor(env)
        return env

    env = DummyVecEnv([make_env])
    dqn_params = DQNParams()
    trainer = DQNTrainer(env, dqn_params, callback=callback)
    trainer.train(total_timesteps)
    trainer.save()
    stats = callback.recorder.get_episode_stats()
    print(f"Training complete - total reward: {stats.get('total_reward', 0):.2f}, avg recovery: {stats.get('avg_recovery', 0):.3f}")
    return trainer.model
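
# Illustrative sketch (not part of the original script): scoring a trained
# model with Stable-Baselines3's evaluate_policy helper. The function name
# and the n_eval_episodes default are assumptions, not part of the pipeline.
def evaluate_uf_agent(model, params: UFParams, n_eval_episodes: int = 10):
    from stable_baselines3.common.evaluation import evaluate_policy
    eval_env = Monitor(UFSuperCycleEnv(params))
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=n_eval_episodes)
    print(f"Evaluation over {n_eval_episodes} episodes: {mean_reward:.2f} +/- {std_reward:.2f}")
    return mean_reward, std_reward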

# Training
if __name__ == "__main__":
    # Initialize environment parameters
    params = UFParams()
    # Train the RL agent
    print("Starting RL agent training...")
    train_uf_rl_agent(params, total_timesteps=50000)