# DQN_train.py

import os
import time
import random

import numpy as np
import torch
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback

from DQN_env import UFParams, UFSuperCycleEnv

# ==== Reinforcement learning hyperparameters ====
class DQNParams:
    """
    DQN hyperparameter definitions.

    Centralizes the training parameters for the model.
    """
    # Learning rate: step size for neural network updates
    learning_rate: float = 1e-4
    # Replay buffer capacity (in steps)
    buffer_size: int = 100000
    # Number of steps to collect before learning starts
    learning_starts: int = 10000
    # Number of samples drawn from the replay buffer per update
    batch_size: int = 32
    # Discount factor; closer to 1 weights long-term rewards more heavily
    gamma: float = 0.95
    # Train once every this many environment steps
    train_freq: int = 4
    # Target network update interval
    target_update_interval: int = 1
    # Soft-update coefficient
    tau: float = 0.005
    # Initial exploration rate (epsilon)
    exploration_initial_eps: float = 1.0
    # Fraction of training over which epsilon decays from initial to final
    exploration_fraction: float = 0.3
    # Final exploration rate (epsilon)
    exploration_final_eps: float = 0.02
    # Log remark (used to distinguish experiments)
    remark: str = "default"

class UFEpisodeRecorder:
    """Records per-step decisions and outcomes within each episode."""

    def __init__(self):
        self.episode_data = []
        self.current_episode = []

    def record_step(self, obs, action, reward, done, info):
        """Record a single step."""
        step_data = {
            "obs": obs.copy(),
            # Guard against None actions; copy arrays so later mutation is safe
            "action": None if action is None else np.asarray(action).copy(),
            "reward": reward,
            "done": done,
            "info": info.copy() if info else {},
        }
        self.current_episode.append(step_data)
        if done:
            self.episode_data.append(self.current_episode)
            self.current_episode = []

    def get_episode_stats(self, episode_idx=-1):
        """Return summary statistics for one episode."""
        if not self.episode_data:
            return {}
        episode = self.episode_data[episode_idx]
        total_reward = sum(step["reward"] for step in episode)
        recoveries = [step["info"]["recovery"] for step in episode if "recovery" in step["info"]]
        # np.mean on an empty list yields NaN, so fall back to 0.0
        avg_recovery = np.mean(recoveries) if recoveries else 0.0
        feasible_steps = sum(1 for step in episode if step["info"].get("feasible", False))
        return {
            "total_reward": total_reward,
            "avg_recovery": avg_recovery,
            "feasible_steps": feasible_steps,
            "total_steps": len(episode),
        }

# ==== Training callback ====
class UFTrainingCallback(BaseCallback):
    """
    Training callback that records every step's data to the recorder.

    1. Does not depend on the environment's internal last_* attributes.
    2. Uses the obs, actions, rewards, dones, and infos exposed by the
       rollout interface.
    3. Episode-end statistics are handled automatically by the recorder.
    """

    def __init__(self, recorder, verbose=0):
        super().__init__(verbose)
        self.recorder = recorder

    def _on_step(self) -> bool:
        try:
            new_obs = self.locals.get("new_obs")
            actions = self.locals.get("actions")
            rewards = self.locals.get("rewards")
            dones = self.locals.get("dones")
            infos = self.locals.get("infos")
            if new_obs is not None and len(new_obs) > 0:
                step_obs = new_obs[0]
                step_action = actions[0] if actions is not None else None
                step_reward = rewards[0] if rewards is not None else 0.0
                step_done = dones[0] if dones is not None else False
                step_info = infos[0] if infos is not None else {}
                # Print the current step's information
                if self.verbose:
                    print(f"[Step {self.num_timesteps}] action={step_action}, reward={step_reward:.3f}, done={step_done}")
                # Record the step
                self.recorder.record_step(
                    obs=step_obs,
                    action=step_action,
                    reward=step_reward,
                    done=step_done,
                    info=step_info,
                )
        except Exception as e:
            if self.verbose:
                print(f"[Callback Error] {e}")
        # Returning True keeps training running
        return True

class DQNTrainer:
    def __init__(self, env, params, callback=None):
        self.env = env
        self.params = params
        self.callback = callback
        self.log_dir = self._create_log_dir()
        self.model = self._create_model()

    def _create_log_dir(self):
        # Create the training log directory
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        log_name = (
            f"DQN_lr{self.params.learning_rate}_buf{self.params.buffer_size}_bs{self.params.batch_size}"
            f"_gamma{self.params.gamma}_exp{self.params.exploration_fraction}"
            f"_{self.params.remark}_{timestamp}"
        )
        log_dir = os.path.join("./uf_dqn_tensorboard", log_name)
        os.makedirs(log_dir, exist_ok=True)
        return log_dir

    def _create_model(self):
        return DQN(
            policy="MlpPolicy",
            env=self.env,
            learning_rate=self.params.learning_rate,
            buffer_size=self.params.buffer_size,
            learning_starts=self.params.learning_starts,
            batch_size=self.params.batch_size,
            gamma=self.params.gamma,
            train_freq=self.params.train_freq,
            # Read these from params instead of hardcoding, so DQNParams
            # actually controls them (the values are unchanged)
            target_update_interval=self.params.target_update_interval,
            tau=self.params.tau,
            exploration_initial_eps=self.params.exploration_initial_eps,
            exploration_fraction=self.params.exploration_fraction,
            exploration_final_eps=self.params.exploration_final_eps,
            verbose=1,
            tensorboard_log=self.log_dir,
        )

    def train(self, total_timesteps: int):
        if self.callback:
            self.model.learn(total_timesteps=total_timesteps, callback=self.callback)
        else:
            self.model.learn(total_timesteps=total_timesteps)
        print(f"Training finished; logs saved to: {self.log_dir}")

    def save(self, path=None):
        if path is None:
            path = os.path.join(self.log_dir, "dqn_model.zip")
        self.model.save(path)
        print(f"Model saved to: {path}")

    def load(self, path):
        self.model = DQN.load(path, env=self.env)
        print(f"Model loaded from {path}")

def set_global_seed(seed: int):
    """Fix all global random seeds so training runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # Covers GPU runs as well
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_uf_rl_agent(params: UFParams, total_timesteps: int = 10000, seed: int = 2025):
    set_global_seed(seed)
    recorder = UFEpisodeRecorder()
    callback = UFTrainingCallback(recorder, verbose=1)

    def make_env():
        env = UFSuperCycleEnv(params)
        env = Monitor(env)
        return env

    env = DummyVecEnv([make_env])
    dqn_params = DQNParams()
    trainer = DQNTrainer(env, dqn_params, callback=callback)
    trainer.train(total_timesteps)
    trainer.save()
    stats = callback.recorder.get_episode_stats()
    print(
        f"Training complete - total reward: {stats.get('total_reward', 0):.2f}, "
        f"average recovery: {stats.get('avg_recovery', 0):.3f}"
    )
    return trainer.model

# Training entry point
if __name__ == "__main__":
    # Initialize process parameters
    params = UFParams()
    # Train the RL agent
    print("Starting RL agent training...")
    train_uf_rl_agent(params, total_timesteps=150000)
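
# To inspect the training curves written via tensorboard_log above, a
# standard TensorBoard invocation (assuming TensorBoard is installed) is:
#   tensorboard --logdir ./uf_dqn_tensorboard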