- import os
- import time
- from stable_baselines3 import DQN
class DQNTrainer:
    """
    Wrapper around a Stable-Baselines3 DQN agent.

    Responsibilities:
    - build the DQN model from a hyper-parameter object
    - train the agent (optionally with a callback)
    - save / load model checkpoints
    - evaluate the learned policy on a held-out test environment
    """

    def __init__(self, env, params, callback=None, PROJECT_ROOT=None):
        """
        Initialize the trainer.

        Parameters:
            env: Gym environment instance (vectorized) used for training.
            params: hyper-parameter object (DQNParams) providing learning_rate,
                buffer_size, batch_size, gamma, exploration settings, remark, etc.
            callback: optional callback forwarded to model.learn().
            PROJECT_ROOT: base directory under which TensorBoard logs are stored.
        """
        self.env = env
        self.params = params
        self.callback = callback
        self.PROJECT_ROOT = PROJECT_ROOT
        self.log_dir = self._create_log_dir()  # TensorBoard log directory for this run
        self.model = self._create_model()      # underlying SB3 DQN model

    # ------------------- private helpers -------------------
    def _create_log_dir(self):
        """
        Create the TensorBoard log directory, fixed under
        PROJECT_ROOT/model_result/uf_dqn_tensorboard/anzhen48h.

        Returns:
            str: path of the per-run log directory (created if missing).
        """
        # Timestamp distinguishes individual training runs.
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        # Encode float hyper-parameters as integers for a compact directory name.
        lr_int = int(self.params.learning_rate * 1e4)
        gamma_int = int(self.params.gamma * 100)
        exp_int = int(self.params.exploration_fraction * 100)
        log_name = (
            f"DQN_lr{lr_int}_buf{self.params.buffer_size}_bs{self.params.batch_size}"
            f"_gamma{gamma_int}_exp{exp_int}_{self.params.remark}_{timestamp}"
        )
        # Fixed log location; PROJECT_ROOT is supplied by the launcher
        # (e.g. run_dqn_train.py defines PROJECT_ROOT = "models/uf-rl").
        base_dir = os.path.join(
            self.PROJECT_ROOT, "model_result", "uf_dqn_tensorboard", "anzhen48h"
        )
        os.makedirs(base_dir, exist_ok=True)
        log_dir = os.path.join(base_dir, log_name)
        os.makedirs(log_dir, exist_ok=True)
        return log_dir

    def _create_model(self):
        """
        Build the Stable-Baselines3 DQN model from self.params.

        Returns:
            DQN: a freshly constructed (untrained) model bound to self.env.
        """
        model = DQN(
            policy="MlpPolicy",
            env=self.env,
            learning_rate=self.params.learning_rate,
            buffer_size=self.params.buffer_size,
            learning_starts=self.params.learning_starts,
            batch_size=self.params.batch_size,
            gamma=self.params.gamma,
            train_freq=self.params.train_freq,
            target_update_interval=self.params.target_update_interval,
            tau=self.params.tau,
            exploration_initial_eps=self.params.exploration_initial_eps,
            exploration_fraction=self.params.exploration_fraction,
            exploration_final_eps=self.params.exploration_final_eps,
            verbose=1,
            tensorboard_log=self.log_dir,
        )
        return model

    # ------------------- public API -------------------
    def train(self, total_timesteps: int):
        """
        Run training.

        Parameters:
            total_timesteps: total number of environment steps to train for.
        """
        if self.callback:
            self.model.learn(total_timesteps=total_timesteps, callback=self.callback)
        else:
            self.model.learn(total_timesteps=total_timesteps)
        print("✅ 模型训练完成!")
        print(f"📊 日志保存在:{self.log_dir}")
        print("💡 使用以下命令查看TensorBoard:")
        print(f" tensorboard --logdir={self.log_dir}")

    def save(self, path=None):
        """
        Save the model.

        Parameters:
            path: optional target path; defaults to dqn_model.zip in the log dir.
        """
        if path is None:
            path = os.path.join(self.log_dir, "dqn_model.zip")
        self.model.save(path)
        print(f"💾 模型已保存到:{path}")

    def load(self, path):
        """
        Load a model from disk, rebinding it to the training environment.

        Parameters:
            path: path of the saved model file.
        """
        self.model = DQN.load(path, env=self.env)
        print(f"📥 模型已从 {path} 加载")

    def evaluate(self, test_env, n_episodes=10, deterministic=True):
        """
        Evaluate the model on a test environment.

        Parameters:
            test_env: test Gym environment (non-vectorized).
            n_episodes: number of evaluation episodes.
            deterministic: whether to use the deterministic (greedy) policy.

        Returns:
            list[dict]: per-episode statistics (total reward, step count,
            and the per-step "tmp_after_ceb" sequence taken from info).

        NOTE(review): this loop assumes the classic Gym API where reset()
        returns obs only and step() returns a 4-tuple (obs, reward, done,
        info). Under Gymnasium's 5-tuple API it would break — confirm which
        API test_env implements.
        """
        results = []
        for ep in range(n_episodes):
            obs = test_env.reset()
            done = False
            total_reward = 0
            steps = 0
            tmp_after_ceb_list = []
            while not done:
                action, _ = self.model.predict(obs, deterministic=deterministic)
                obs, reward, done, info = test_env.step(action)
                total_reward += reward
                steps += 1
                # Record TMP after CEB; defaults to 0 when the env omits it.
                tmp_after_ceb_list.append(info.get("tmp_after_ceb", 0))
            ep_result = {
                "episode": ep,
                "total_reward": total_reward,
                "steps": steps,
                "tmp_after_ceb_sequence": tmp_after_ceb_list,
            }
            results.append(ep_result)
        return results