# env_visual.py
  1. import numpy as np
  2. from stable_baselines3.common.callbacks import BaseCallback
  3. class UFEpisodeRecorder:
  4. """记录episode中的决策和结果"""
  5. def __init__(self):
  6. self.episode_data = []
  7. self.current_episode = []
  8. def record_step(self, obs, action, reward, done, info):
  9. """记录单步信息"""
  10. step_data = {
  11. "obs": obs.copy(),
  12. "action": action.copy(),
  13. "reward": reward,
  14. "done": done,
  15. "info": info.copy() if info else {}
  16. }
  17. self.current_episode.append(step_data)
  18. if done:
  19. self.episode_data.append(self.current_episode)
  20. self.current_episode = []
  21. def get_episode_stats(self, episode_idx=-1):
  22. """获取episode统计信息"""
  23. if not self.episode_data:
  24. return {}
  25. episode = self.episode_data[episode_idx]
  26. total_reward = sum(step["reward"] for step in episode)
  27. avg_recovery = np.mean([step["info"].get("recovery", 0) for step in episode if "recovery" in step["info"]])
  28. feasible_steps = sum(1 for step in episode if step["info"].get("feasible", False))
  29. return {
  30. "total_reward": total_reward,
  31. "avg_recovery": avg_recovery,
  32. "feasible_steps": feasible_steps,
  33. "total_steps": len(episode)
  34. }
  35. # ==== 定义强化学习训练回调器 ====
  36. class UFTrainingCallback(BaseCallback):
  37. """
  38. 强化学习训练回调,用于记录每一步的数据到 recorder。
  39. 1. 不依赖环境内部 last_* 属性
  40. 2. 使用环境接口提供的 obs、actions、rewards、dones、infos
  41. 3. 自动处理 episode 结束时的统计
  42. """
  43. def __init__(self, recorder, verbose=0):
  44. super(UFTrainingCallback, self).__init__(verbose)
  45. self.recorder = recorder
  46. def _on_step(self) -> bool:
  47. try:
  48. new_obs = self.locals.get("new_obs")
  49. actions = self.locals.get("actions")
  50. rewards = self.locals.get("rewards")
  51. dones = self.locals.get("dones")
  52. infos = self.locals.get("infos")
  53. if len(new_obs) > 0:
  54. step_obs = new_obs[0]
  55. step_action = actions[0] if actions is not None else None
  56. step_reward = rewards[0] if rewards is not None else 0.0
  57. step_done = dones[0] if dones is not None else False
  58. step_info = infos[0] if infos is not None else {}
  59. L_s = step_info["L_s"]
  60. t_bw_s = step_info["t_bw_s"]
  61. initial_tmp = step_info["initial_tmp"]
  62. tmp_after_ceb = step_info["tmp_after_ceb"]
  63. max_TMP_during_filtration = step_info["max_TMP_during_filtration"]
  64. tmp_penalty = step_info["tmp_penalty"]
  65. residual_ratio =step_info["residual_ratio"]
  66. res_penalty = step_info["res_penalty"]
  67. econ_reward = step_info["econ_reward"]
  68. recovery = step_info["recovery"]
  69. # 打印当前 step 的信息
  70. if self.verbose:
  71. print(f"[Step {self.num_timesteps}] 动作={step_action}, 奖励={step_reward:.3f}, Done={step_done}, L_s={L_s}, t_bw_s={t_bw_s},"
  72. f"residual_ratio = {residual_ratio:.4f}, res_penalty = {res_penalty:.4f},"
  73. f",recovery = {recovery:.4f},econ_reward = {econ_reward :.4f}"
  74. f"initial_tmp = {initial_tmp:.4f}, tmp_after_ceb = {tmp_after_ceb:.4f}, max_TMP_during_filtration ={max_TMP_during_filtration:.4f}, tmp_penalty = {tmp_penalty:.4f}")
  75. # 记录数据
  76. self.recorder.record_step(
  77. obs=step_obs,
  78. action=step_action,
  79. reward=step_reward,
  80. done=step_done,
  81. info=step_info,
  82. )
  83. except Exception as e:
  84. if self.verbose:
  85. print(f"[Callback Error] {e}")
  86. return True