dqn_trainer.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import os
  2. import time
  3. from stable_baselines3 import DQN
  4. class DQNTrainer:
  5. """
  6. DQN训练器封装类
  7. 功能:
  8. - 创建训练模型
  9. - 训练智能体
  10. - 保存与加载模型
  11. - 在测试集环境上评估策略
  12. """
  13. def __init__(self, env, params, callback=None,PROJECT_ROOT=None,DIR_NAME=None):
  14. """
  15. 初始化训练器
  16. 参数:
  17. env: Gym环境实例(向量化环境)
  18. params: DQNParams超参数对象
  19. callback: 可选训练回调
  20. """
  21. self.env = env
  22. self.params = params
  23. self.callback = callback
  24. self.PROJECT_ROOT = PROJECT_ROOT
  25. self.dir_name = DIR_NAME
  26. self.log_dir = self._create_log_dir() # 创建TensorBoard日志目录
  27. self.model = self._create_model() # 创建DQN模型
  28. # ------------------- 私有方法 -------------------
  29. def _create_log_dir(self):
  30. """
  31. 创建 TensorBoard 日志目录(固定在 PROJECT_ROOT/model_result/uf_dqn_tensorboard 下)
  32. """
  33. import os
  34. import time
  35. # 1️⃣ 时间戳,用于区分每次训练
  36. timestamp = time.strftime("%Y%m%d-%H%M%S")
  37. # 2️⃣ 将浮点参数转成整数便于命名
  38. lr_int = int(self.params.learning_rate * 1e4)
  39. gamma_int = int(self.params.gamma * 100)
  40. exp_int = int(self.params.exploration_fraction * 100)
  41. # 3️⃣ 构建日志目录名称
  42. log_name = (
  43. f"DQN_lr{lr_int}_buf{self.params.buffer_size}_bs{self.params.batch_size}"
  44. f"_gamma{gamma_int}_exp{exp_int}_{self.params.remark}_{timestamp}"
  45. )
  46. # 4️⃣ 固定日志存放位置:PROJECT_ROOT/model_result/uf_dqn_tensorboard
  47. # 假设在 run_dqn_train.py 中定义了 PROJECT_ROOT = "models/uf-rl"
  48. base_dir = os.path.join(self.PROJECT_ROOT, "model_result", "uf_dqn_tensorboard",self.dir_name)
  49. os.makedirs(base_dir, exist_ok=True)
  50. # 5️⃣ 完整日志目录路径
  51. log_dir = os.path.join(base_dir, log_name)
  52. os.makedirs(log_dir, exist_ok=True)
  53. return log_dir
  54. def _create_model(self):
  55. """
  56. 创建Stable-Baselines3 DQN模型
  57. """
  58. model = DQN(
  59. policy="MlpPolicy",
  60. env=self.env,
  61. learning_rate=self.params.learning_rate,
  62. buffer_size=self.params.buffer_size,
  63. learning_starts=self.params.learning_starts,
  64. batch_size=self.params.batch_size,
  65. gamma=self.params.gamma,
  66. train_freq=self.params.train_freq,
  67. target_update_interval=self.params.target_update_interval,
  68. tau=self.params.tau,
  69. exploration_initial_eps=self.params.exploration_initial_eps,
  70. exploration_fraction=self.params.exploration_fraction,
  71. exploration_final_eps=self.params.exploration_final_eps,
  72. verbose=1,
  73. tensorboard_log=self.log_dir
  74. )
  75. return model
  76. # ------------------- 公共方法 -------------------
  77. def train(self, total_timesteps: int):
  78. """
  79. 执行训练
  80. 参数:
  81. total_timesteps: 总训练步数
  82. """
  83. if self.callback:
  84. self.model.learn(total_timesteps=total_timesteps, callback=self.callback)
  85. else:
  86. self.model.learn(total_timesteps=total_timesteps)
  87. print(f"✅ 模型训练完成!")
  88. print(f"📊 日志保存在:{self.log_dir}")
  89. print(f"💡 使用以下命令查看TensorBoard:")
  90. print(f" tensorboard --logdir={self.log_dir}")
  91. def save(self, path=None):
  92. """
  93. 保存模型
  94. 参数:
  95. path: 可选路径,默认保存到日志目录下 dqn_model.zip
  96. """
  97. if path is None:
  98. path = os.path.join(self.log_dir, "dqn_model.zip")
  99. self.model.save(path)
  100. print(f"💾 模型已保存到:{path}")
  101. def load(self, path):
  102. """
  103. 加载模型
  104. 参数:
  105. path: 模型文件路径
  106. """
  107. self.model = DQN.load(path, env=self.env)
  108. print(f"📥 模型已从 {path} 加载")
  109. def evaluate(self, test_env, n_episodes=10, deterministic=True):
  110. """
  111. 在测试环境上评估模型
  112. 参数:
  113. test_env: 测试用Gym环境(非向量化)
  114. n_episodes: 测试episode数量
  115. deterministic: 是否使用确定性策略
  116. 返回:
  117. list[dict]: 每个episode的统计信息,包括总奖励、步数、TMP序列等
  118. """
  119. results = []
  120. for ep in range(n_episodes):
  121. obs = test_env.reset()
  122. done = False
  123. total_reward = 0
  124. steps = 0
  125. tmp_after_ceb_list = []
  126. while not done:
  127. action, _ = self.model.predict(obs, deterministic=deterministic)
  128. obs, reward, done, info = test_env.step(action)
  129. total_reward += reward
  130. steps += 1
  131. # 可根据需要记录TMP、回收率等
  132. tmp_after_ceb_list.append(info.get("tmp_after_ceb", 0))
  133. ep_result = {
  134. "episode": ep,
  135. "total_reward": total_reward,
  136. "steps": steps,
  137. "tmp_after_ceb_sequence": tmp_after_ceb_list
  138. }
  139. results.append(ep_result)
  140. return results