dqn_trainer.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import os
  2. import time
  3. from stable_baselines3 import DQN
  4. class DQNTrainer:
  5. """
  6. DQN训练器封装类
  7. 功能:
  8. - 创建训练模型
  9. - 训练智能体
  10. - 保存与加载模型
  11. - 在测试集环境上评估策略
  12. """
  13. def __init__(self, env, params, callback=None,PROJECT_ROOT=None):
  14. """
  15. 初始化训练器
  16. 参数:
  17. env: Gym环境实例(向量化环境)
  18. params: DQNParams超参数对象
  19. callback: 可选训练回调
  20. """
  21. self.env = env
  22. self.params = params
  23. self.callback = callback
  24. self.PROJECT_ROOT = PROJECT_ROOT
  25. self.log_dir = self._create_log_dir() # 创建TensorBoard日志目录
  26. self.model = self._create_model() # 创建DQN模型
  27. # ------------------- 私有方法 -------------------
  28. def _create_log_dir(self):
  29. """
  30. 创建 TensorBoard 日志目录(固定在 PROJECT_ROOT/model_result/uf_dqn_tensorboard 下)
  31. """
  32. import os
  33. import time
  34. # 1️⃣ 时间戳,用于区分每次训练
  35. timestamp = time.strftime("%Y%m%d-%H%M%S")
  36. # 2️⃣ 将浮点参数转成整数便于命名
  37. lr_int = int(self.params.learning_rate * 1e4)
  38. gamma_int = int(self.params.gamma * 100)
  39. exp_int = int(self.params.exploration_fraction * 100)
  40. # 3️⃣ 构建日志目录名称
  41. log_name = (
  42. f"DQN_lr{lr_int}_buf{self.params.buffer_size}_bs{self.params.batch_size}"
  43. f"_gamma{gamma_int}_exp{exp_int}_{self.params.remark}_{timestamp}"
  44. )
  45. # 4️⃣ 固定日志存放位置:PROJECT_ROOT/model_result/uf_dqn_tensorboard
  46. # 假设在 run_dqn_train.py 中定义了 PROJECT_ROOT = "models/uf-rl"
  47. base_dir = os.path.join(self.PROJECT_ROOT, "model_result", "uf_dqn_tensorboard","anzhen48h")
  48. os.makedirs(base_dir, exist_ok=True)
  49. # 5️⃣ 完整日志目录路径
  50. log_dir = os.path.join(base_dir, log_name)
  51. os.makedirs(log_dir, exist_ok=True)
  52. return log_dir
  53. def _create_model(self):
  54. """
  55. 创建Stable-Baselines3 DQN模型
  56. """
  57. model = DQN(
  58. policy="MlpPolicy",
  59. env=self.env,
  60. learning_rate=self.params.learning_rate,
  61. buffer_size=self.params.buffer_size,
  62. learning_starts=self.params.learning_starts,
  63. batch_size=self.params.batch_size,
  64. gamma=self.params.gamma,
  65. train_freq=self.params.train_freq,
  66. target_update_interval=self.params.target_update_interval,
  67. tau=self.params.tau,
  68. exploration_initial_eps=self.params.exploration_initial_eps,
  69. exploration_fraction=self.params.exploration_fraction,
  70. exploration_final_eps=self.params.exploration_final_eps,
  71. verbose=1,
  72. tensorboard_log=self.log_dir
  73. )
  74. return model
  75. # ------------------- 公共方法 -------------------
  76. def train(self, total_timesteps: int):
  77. """
  78. 执行训练
  79. 参数:
  80. total_timesteps: 总训练步数
  81. """
  82. if self.callback:
  83. self.model.learn(total_timesteps=total_timesteps, callback=self.callback)
  84. else:
  85. self.model.learn(total_timesteps=total_timesteps)
  86. print(f"✅ 模型训练完成!")
  87. print(f"📊 日志保存在:{self.log_dir}")
  88. print(f"💡 使用以下命令查看TensorBoard:")
  89. print(f" tensorboard --logdir={self.log_dir}")
  90. def save(self, path=None):
  91. """
  92. 保存模型
  93. 参数:
  94. path: 可选路径,默认保存到日志目录下 dqn_model.zip
  95. """
  96. if path is None:
  97. path = os.path.join(self.log_dir, "dqn_model.zip")
  98. self.model.save(path)
  99. print(f"💾 模型已保存到:{path}")
  100. def load(self, path):
  101. """
  102. 加载模型
  103. 参数:
  104. path: 模型文件路径
  105. """
  106. self.model = DQN.load(path, env=self.env)
  107. print(f"📥 模型已从 {path} 加载")
  108. def evaluate(self, test_env, n_episodes=10, deterministic=True):
  109. """
  110. 在测试环境上评估模型
  111. 参数:
  112. test_env: 测试用Gym环境(非向量化)
  113. n_episodes: 测试episode数量
  114. deterministic: 是否使用确定性策略
  115. 返回:
  116. list[dict]: 每个episode的统计信息,包括总奖励、步数、TMP序列等
  117. """
  118. results = []
  119. for ep in range(n_episodes):
  120. obs = test_env.reset()
  121. done = False
  122. total_reward = 0
  123. steps = 0
  124. tmp_after_ceb_list = []
  125. while not done:
  126. action, _ = self.model.predict(obs, deterministic=deterministic)
  127. obs, reward, done, info = test_env.step(action)
  128. total_reward += reward
  129. steps += 1
  130. # 可根据需要记录TMP、回收率等
  131. tmp_after_ceb_list.append(info.get("tmp_after_ceb", 0))
  132. ep_result = {
  133. "episode": ep,
  134. "total_reward": total_reward,
  135. "steps": steps,
  136. "tmp_after_ceb_sequence": tmp_after_ceb_list
  137. }
  138. results.append(ep_result)
  139. return results