import numpy as np
from stable_baselines3.common.callbacks import BaseCallback


class UFEpisodeRecorder:
    """Record the per-step decisions and outcomes of each episode.

    Steps are buffered in ``current_episode`` until a step with
    ``done=True`` arrives; the buffered list is then appended to
    ``episode_data`` and a new buffer is started.
    """

    def __init__(self):
        # Completed episodes; each entry is a list of step dicts.
        self.episode_data = []
        # Steps of the episode that is still in progress.
        self.current_episode = []

    @staticmethod
    def _copy_if_possible(value):
        """Return ``value.copy()`` when the object offers it, else ``value``."""
        copier = getattr(value, "copy", None)
        return copier() if callable(copier) else value

    def record_step(self, obs, action, reward, done, info):
        """Append one step to the current episode.

        Args:
            obs: Observation associated with the step. Copied when it has a
                ``copy`` method so later in-place changes by the caller do
                not alter the stored record.
            action: Action taken. May be ``None`` or a plain scalar; it is
                copied only when it has a ``copy`` method.
            reward: Scalar reward of the step.
            done: Truthy when the step ends the episode; the episode buffer
                is then moved into ``episode_data``.
            info: Info dict of the step. A shallow copy is stored; a falsy
                value is stored as an empty dict.
        """
        step_data = {
            "obs": self._copy_if_possible(obs),
            "action": self._copy_if_possible(action),
            "reward": reward,
            "done": done,
            "info": info.copy() if info else {},
        }
        self.current_episode.append(step_data)

        if done:
            self.episode_data.append(self.current_episode)
            self.current_episode = []

    def get_episode_stats(self, episode_idx=-1):
        """Return summary statistics for one completed episode.

        Args:
            episode_idx: Index into ``episode_data``; defaults to the most
                recently completed episode.

        Returns:
            An empty dict when no episode has completed yet, otherwise a
            dict with ``total_reward``, ``avg_recovery``, ``feasible_steps``
            and ``total_steps``. ``avg_recovery`` is NaN when no step of the
            episode reported a ``recovery`` value.

        Raises:
            IndexError: If ``episode_idx`` is out of range.
        """
        if not self.episode_data:
            return {}

        episode = self.episode_data[episode_idx]
        total_reward = sum(step["reward"] for step in episode)

        # Only steps that actually report a recovery take part in the mean.
        recoveries = [
            step["info"]["recovery"]
            for step in episode
            if "recovery" in step["info"]
        ]
        # np.mean([]) would return NaN *and* emit a RuntimeWarning; return
        # the same NaN explicitly so the "no data" case stays quiet.
        avg_recovery = np.mean(recoveries) if recoveries else np.nan

        feasible_steps = sum(
            1 for step in episode if step["info"].get("feasible", False)
        )

        return {
            "total_reward": total_reward,
            "avg_recovery": avg_recovery,
            "feasible_steps": feasible_steps,
            "total_steps": len(episode),
        }


# ==== Reinforcement-learning training callback ====
class UFTrainingCallback(BaseCallback):
    """Training callback that forwards every step to a recorder.

    1. Does not rely on ``last_*`` attributes inside the environment.
    2. Uses the ``new_obs``, ``actions``, ``rewards``, ``dones`` and
       ``infos`` values found in ``self.locals``.
    3. Episode boundaries are handled by the recorder when ``done`` is set.

    Only index 0 of each batched value is recorded, i.e. the first
    environment of a vectorized env.

    NOTE(review): the stored observation is ``new_obs`` (the observation
    returned by the step). With an auto-resetting VecEnv this is presumably
    the first observation of the next episode on ``done`` steps -- confirm
    against the training setup.

    Any exception raised while handling a step is caught so that the
    callback never interrupts training; such failures are counted in
    ``failed_steps`` and printed when ``verbose`` is truthy.
    """

    def __init__(self, recorder, verbose=0):
        super().__init__(verbose)
        self.recorder = recorder
        # Number of steps that could not be recorded because of an error.
        self.failed_steps = 0

    def _print_step(self, step_action, step_reward, step_done, step_info):
        """Print the diagnostics of the current step.

        Missing info keys are shown as ``nan`` instead of raising.
        """
        nan = float("nan")
        L_s = step_info.get("L_s", nan)
        t_bw_s = step_info.get("t_bw_s", nan)
        initial_tmp = step_info.get("initial_tmp", nan)
        tmp_after_ceb = step_info.get("tmp_after_ceb", nan)
        max_TMP_during_filtration = step_info.get(
            "max_TMP_during_filtration", nan
        )
        tmp_penalty = step_info.get("tmp_penalty", nan)
        residual_ratio = step_info.get("residual_ratio", nan)
        res_penalty = step_info.get("res_penalty", nan)
        econ_reward = step_info.get("econ_reward", nan)
        recovery = step_info.get("recovery", nan)

        print(
            f"[Step {self.num_timesteps}] 动作={step_action}, 奖励={step_reward:.3f}, Done={step_done}, L_s={L_s}, t_bw_s={t_bw_s},"
            f"residual_ratio = {residual_ratio:.4f}, res_penalty = {res_penalty:.4f},"
            f"recovery = {recovery:.4f},econ_reward = {econ_reward:.4f},"
            f"initial_tmp = {initial_tmp:.4f}, tmp_after_ceb = {tmp_after_ceb:.4f}, max_TMP_during_filtration ={max_TMP_during_filtration:.4f}, tmp_penalty = {tmp_penalty:.4f}"
        )

    def _on_step(self) -> bool:
        """Record the current step; always return True to keep training."""
        try:
            new_obs = self.locals.get("new_obs")
            # Nothing to record when the algorithm exposes no observation.
            if new_obs is None or len(new_obs) == 0:
                return True

            actions = self.locals.get("actions")
            rewards = self.locals.get("rewards")
            dones = self.locals.get("dones")
            infos = self.locals.get("infos")

            step_obs = new_obs[0]
            step_action = actions[0] if actions is not None else None
            step_reward = rewards[0] if rewards is not None else 0.0
            step_done = dones[0] if dones is not None else False
            step_info = infos[0] if infos is not None else {}
            if step_info is None:
                step_info = {}

            # Record first, so that a formatting problem in the verbose
            # print below can never cost us the data of this step.
            self.recorder.record_step(
                obs=step_obs,
                action=step_action,
                reward=step_reward,
                done=step_done,
                info=step_info,
            )

            # Print the information of the current step.
            if self.verbose:
                self._print_step(step_action, step_reward, step_done, step_info)

        except Exception as e:
            # Deliberate best effort: a recording failure must not abort
            # training. Keep it observable through the counter.
            self.failed_steps += 1
            if self.verbose:
                print(f"[Callback Error] {e}")

        return True