dqn_trainer.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import os
  2. import time
  3. from stable_baselines3 import DQN
  4. class DQNTrainer:
  5. """
  6. DQN训练器封装类
  7. 功能:
  8. - 创建训练模型
  9. - 训练智能体
  10. - 保存与加载模型
  11. - 在测试集环境上评估策略
  12. """
  13. def __init__(self, env, params, callback=None):
  14. """
  15. 初始化训练器
  16. 参数:
  17. env: Gym环境实例(向量化环境)
  18. params: DQNParams超参数对象
  19. callback: 可选训练回调
  20. """
  21. self.env = env
  22. self.params = params
  23. self.callback = callback
  24. self.log_dir = self._create_log_dir() # 创建TensorBoard日志目录
  25. self.model = self._create_model() # 创建DQN模型
  26. # ------------------- 私有方法 -------------------
  27. def _create_log_dir(self):
  28. """
  29. 创建TensorBoard日志目录
  30. """
  31. timestamp = time.strftime("%Y%m%d-%H%M%S")
  32. lr_int = int(self.params.learning_rate * 1e4)
  33. gamma_int = int(self.params.gamma * 100)
  34. exp_int = int(self.params.exploration_fraction * 100)
  35. log_name = f"DQN_lr{lr_int}_buf{self.params.buffer_size}_bs{self.params.batch_size}_gamma{gamma_int}_exp{exp_int}_{self.params.remark}_{timestamp}"
  36. base_dir = os.path.join(os.getcwd(), "uf_dqn_tensorboard")
  37. os.makedirs(base_dir, exist_ok=True)
  38. log_dir = os.path.join(base_dir, log_name)
  39. os.makedirs(log_dir, exist_ok=True)
  40. return log_dir
  41. def _create_model(self):
  42. """
  43. 创建Stable-Baselines3 DQN模型
  44. """
  45. model = DQN(
  46. policy="MlpPolicy",
  47. env=self.env,
  48. learning_rate=self.params.learning_rate,
  49. buffer_size=self.params.buffer_size,
  50. learning_starts=self.params.learning_starts,
  51. batch_size=self.params.batch_size,
  52. gamma=self.params.gamma,
  53. train_freq=self.params.train_freq,
  54. target_update_interval=self.params.target_update_interval,
  55. tau=self.params.tau,
  56. exploration_initial_eps=self.params.exploration_initial_eps,
  57. exploration_fraction=self.params.exploration_fraction,
  58. exploration_final_eps=self.params.exploration_final_eps,
  59. verbose=1,
  60. tensorboard_log=self.log_dir
  61. )
  62. return model
  63. # ------------------- 公共方法 -------------------
  64. def train(self, total_timesteps: int):
  65. """
  66. 执行训练
  67. 参数:
  68. total_timesteps: 总训练步数
  69. """
  70. if self.callback:
  71. self.model.learn(total_timesteps=total_timesteps, callback=self.callback)
  72. else:
  73. self.model.learn(total_timesteps=total_timesteps)
  74. print(f"✅ 模型训练完成!")
  75. print(f"📊 日志保存在:{self.log_dir}")
  76. print(f"💡 使用以下命令查看TensorBoard:")
  77. print(f" tensorboard --logdir={self.log_dir}")
  78. def save(self, path=None):
  79. """
  80. 保存模型
  81. 参数:
  82. path: 可选路径,默认保存到日志目录下 dqn_model.zip
  83. """
  84. if path is None:
  85. path = os.path.join(self.log_dir, "dqn_model.zip")
  86. self.model.save(path)
  87. print(f"💾 模型已保存到:{path}")
  88. def load(self, path):
  89. """
  90. 加载模型
  91. 参数:
  92. path: 模型文件路径
  93. """
  94. self.model = DQN.load(path, env=self.env)
  95. print(f"📥 模型已从 {path} 加载")
  96. def evaluate(self, test_env, n_episodes=10, deterministic=True):
  97. """
  98. 在测试环境上评估模型
  99. 参数:
  100. test_env: 测试用Gym环境(非向量化)
  101. n_episodes: 测试episode数量
  102. deterministic: 是否使用确定性策略
  103. 返回:
  104. list[dict]: 每个episode的统计信息,包括总奖励、步数、TMP序列等
  105. """
  106. results = []
  107. for ep in range(n_episodes):
  108. obs = test_env.reset()
  109. done = False
  110. total_reward = 0
  111. steps = 0
  112. tmp_after_ceb_list = []
  113. while not done:
  114. action, _ = self.model.predict(obs, deterministic=deterministic)
  115. obs, reward, done, info = test_env.step(action)
  116. total_reward += reward
  117. steps += 1
  118. # 可根据需要记录TMP、回收率等
  119. tmp_after_ceb_list.append(info.get("tmp_after_ceb", 0))
  120. ep_result = {
  121. "episode": ep,
  122. "total_reward": total_reward,
  123. "steps": steps,
  124. "tmp_after_ceb_sequence": tmp_after_ceb_list
  125. }
  126. results.append(ep_result)
  127. return results