| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- import numpy as np
- from stable_baselines3.common.callbacks import BaseCallback
class UFEpisodeRecorder:
    """Record per-step decisions and outcomes, grouped into episodes.

    Steps are buffered in ``current_episode`` until a step with a truthy
    ``done`` arrives; the buffered list is then appended to ``episode_data``
    and a fresh buffer is started.
    """

    def __init__(self):
        # Completed episodes; each one is a list of step dicts.
        self.episode_data = []
        # Steps of the episode currently in progress.
        self.current_episode = []

    @staticmethod
    def _copy_if_possible(value):
        """Return ``value.copy()`` when the object provides one, else ``value`` itself."""
        copier = getattr(value, "copy", None)
        return copier() if callable(copier) else value

    def record_step(self, obs, action, reward, done, info):
        """Append one step to the current episode.

        Args:
            obs: Observation for this step. Copied when it has a ``copy``
                method (e.g. a NumPy array) so later in-place mutation by the
                caller does not alter the stored record.
            action: Action taken; copied the same way. ``None`` and plain
                Python scalars (which have no ``copy``) are stored as-is.
            reward: Reward for this step.
            done: Truthy when this step ends the episode; the buffered
                episode is then moved into ``episode_data``.
            info: Step info mapping. A shallow copy is stored, or ``{}`` when
                ``info`` is falsy (``None`` or empty).
        """
        step_data = {
            "obs": self._copy_if_possible(obs),
            "action": self._copy_if_possible(action),
            "reward": reward,
            "done": done,
            "info": info.copy() if info else {},
        }
        self.current_episode.append(step_data)
        if done:
            self.episode_data.append(self.current_episode)
            self.current_episode = []

    def get_episode_stats(self, episode_idx=-1):
        """Summarize one completed episode.

        Args:
            episode_idx: Index into ``episode_data``; defaults to the most
                recently completed episode.

        Returns:
            ``{}`` if no episode has completed yet. Otherwise a dict with:
            ``total_reward`` (sum of step rewards), ``avg_recovery`` (mean of
            ``info["recovery"]`` over steps that report it, or NaN if none
            do), ``feasible_steps`` (number of steps whose
            ``info["feasible"]`` is truthy) and ``total_steps``.

        Raises:
            IndexError: If ``episode_idx`` is out of range.
        """
        if not self.episode_data:
            return {}
        episode = self.episode_data[episode_idx]
        recoveries = [
            step["info"]["recovery"] for step in episode if "recovery" in step["info"]
        ]
        return {
            "total_reward": sum(step["reward"] for step in episode),
            # np.mean([]) emits a "Mean of empty slice" RuntimeWarning and
            # returns NaN; return NaN directly to avoid the warning.
            "avg_recovery": np.mean(recoveries) if recoveries else np.nan,
            "feasible_steps": sum(
                1 for step in episode if step["info"].get("feasible", False)
            ),
            "total_steps": len(episode),
        }
- # ==== 定义强化学习训练回调器 ====
class UFTrainingCallback(BaseCallback):
    """Training callback that mirrors each environment step into a recorder.

    Reads the rollout variables exposed through ``self.locals``
    (``new_obs``, ``actions``, ``rewards``, ``dones``, ``infos``) instead of
    relying on ``last_*`` attributes of the environment. Only sub-environment
    0 of the vectorized env is recorded. Episode boundaries are handled by
    the recorder, which closes an episode when ``done`` is truthy.

    Args:
        recorder: Object with a ``record_step(obs, action, reward, done, info)``
            method (e.g. ``UFEpisodeRecorder``).
        verbose: When truthy, print a one-line summary of every step and
            print recording errors instead of emitting warnings.
    """

    def __init__(self, recorder, verbose=0):
        super().__init__(verbose)
        self.recorder = recorder

    def _on_step(self) -> bool:
        """Record this step for sub-env 0; always return True so training continues."""
        try:
            self._record_first_env()
        except Exception as e:
            # Recording is best-effort and must never abort training, but the
            # failure should not be silently lost either.
            if self.verbose:
                print(f"[Callback Error] {e}")
            else:
                import warnings

                warnings.warn(f"[Callback Error] {e}", RuntimeWarning)
        return True

    def _record_first_env(self):
        """Extract sub-env 0's transition from ``self.locals`` and pass it to the recorder."""
        new_obs = self.locals.get("new_obs")
        if new_obs is None or len(new_obs) == 0:
            return
        actions = self.locals.get("actions")
        rewards = self.locals.get("rewards")
        dones = self.locals.get("dones")
        infos = self.locals.get("infos")

        step_action = actions[0] if actions is not None else None
        step_reward = rewards[0] if rewards is not None else 0.0
        step_done = dones[0] if dones is not None else False
        step_info = infos[0] if infos is not None else {}

        # SB3 VecEnvs auto-reset finished episodes, so on a terminal step
        # new_obs[0] is already the first observation of the NEXT episode;
        # the actual final observation is stored in info["terminal_observation"].
        step_obs = new_obs[0]
        if step_done and step_info and "terminal_observation" in step_info:
            step_obs = step_info["terminal_observation"]

        # Printing only reads optional info keys, so a missing key can no
        # longer prevent the step from being recorded.
        if self.verbose:
            self._print_step(step_action, step_reward, step_done, step_info)

        self.recorder.record_step(
            obs=step_obs,
            action=step_action,
            reward=step_reward,
            done=step_done,
            info=step_info,
        )

    def _print_step(self, step_action, step_reward, step_done, step_info):
        """Print a one-line summary of the step, tolerating missing or non-numeric info values."""

        def fmt(key):
            # Format numbers to 4 decimals; fall back to str() for None,
            # strings, arrays, or anything else that rejects ".4f".
            value = step_info.get(key)
            try:
                return f"{value:.4f}"
            except (TypeError, ValueError):
                return str(value)

        print(
            f"[Step {self.num_timesteps}] 动作={step_action}, 奖励={step_reward:.3f}, "
            f"Done={step_done}, L_s={step_info.get('L_s')}, "
            f"t_bw_s={step_info.get('t_bw_s')}, "
            f"residual_ratio = {fmt('residual_ratio')}, "
            f"res_penalty = {fmt('res_penalty')}, "
            f"recovery = {fmt('recovery')}, econ_reward = {fmt('econ_reward')}, "
            f"initial_tmp = {fmt('initial_tmp')}, "
            f"tmp_after_ceb = {fmt('tmp_after_ceb')}, "
            f"max_TMP_during_filtration = {fmt('max_TMP_during_filtration')}, "
            f"tmp_penalty = {fmt('tmp_penalty')}"
        )
|