# dqn_decider.py
  1. """
  2. UF 超滤系统 DQN 决策脚本(与当前 DQNTrainer 严格对齐)
  3. 功能定位:
  4. - 加载已训练好的 DQN 模型
  5. - 构造与训练阶段完全一致的环境
  6. - 执行单步动作推理(predict)
  7. - 输出模型建议的工程动作参数(L_s, t_bw_s)
  8. 注意:
  9. - 本脚本【不 step 环境】
  10. - 不计算 reward
  11. - 不进行 episode rollout
  12. """
  13. from pathlib import Path
  14. import numpy as np
  15. # ============================================================
  16. # 1. UF 环境与物理模型
  17. # ============================================================
  18. from env.uf_env import UFSuperCycleEnv
  19. from env.env_params import (
  20. UFRewardParams,
  21. UFActionSpec,
  22. UFStateBounds,
  23. )
  24. # ============================================================
  25. # 2. Stable-Baselines3
  26. # ============================================================
  27. from stable_baselines3 import DQN
  28. # ============================================================
  29. # 3. DQN 决策器
  30. # ============================================================
  31. class UFDQNDecider:
  32. """
  33. UF 超滤 DQN 决策器(Inference Only)
  34. 设计原则:
  35. 1. 与训练环境参数级一致
  36. 2. 决策侧不推进环境
  37. 3. 不依赖 Trainer 内部状态
  38. """
  39. def __init__(
  40. self,
  41. physics,
  42. action_spec,
  43. reward_params,
  44. state_bounds,
  45. model_path,
  46. seed: int = 0,
  47. ):
  48. """
  49. Parameters
  50. ----------
  51. model_path
  52. dqn_model.zip 的路径
  53. reset_state_pool :
  54. ResetStatePoolLoader.split() 得到的 pool(train / val 均可)
  55. seed : int
  56. 随机种子(推理阶段主要用于 env.reset)
  57. """
  58. self.action_spec = action_spec
  59. reward_params = reward_params
  60. state_bounds = state_bounds
  61. self.env = UFSuperCycleEnv(
  62. physics=physics,
  63. reward_params=reward_params,
  64. action_spec=self.action_spec,
  65. statebounds=state_bounds,
  66. real_state_pool=None,
  67. RANDOM_SEED=seed,
  68. )
  69. model_path = Path(model_path)
  70. if not model_path.exists():
  71. raise FileNotFoundError(f"DQN 模型不存在: {model_path}")
  72. self.model = DQN.load(
  73. path=str(model_path),
  74. env=self.env, # ⚠ 必须提供 env
  75. )
  76. # ========================================================
  77. # 对外决策接口
  78. # ========================================================
  79. def decide(self, state: np.ndarray | None = None) -> dict:
  80. """
  81. 单步决策(不 step 环境)
  82. Parameters
  83. ----------
  84. state : np.ndarray | None
  85. - None:env.reset() 从 reset_state_pool 抽样状态
  86. - 非 None:使用外部系统提供的状态
  87. Returns
  88. -------
  89. dict
  90. {
  91. "action_id": int,
  92. "L_s": float,
  93. "t_bw_s": float,
  94. }
  95. """
  96. # ----------------------------------------------------
  97. # 4.1 获取观测状态
  98. # ----------------------------------------------------
  99. if state is None:
  100. obs = self.env.reset()
  101. else:
  102. obs = self.env.get_obs(state) # 获取归一化状态作为策略网络输入
  103. # ----------------------------------------------------
  104. # 4.2 DQN 推理(确定性)
  105. # ----------------------------------------------------
  106. action, _ = self.model.predict(obs, deterministic=True)
  107. action_id = int(action)
  108. # ----------------------------------------------------
  109. # 4.3 动作解码(工程语义)
  110. # ----------------------------------------------------
  111. L_s, t_bw_s = self.env.get_action_values(action_id)
  112. return {
  113. "action_id": action_id,
  114. "L_s": L_s,
  115. "t_bw_s": t_bw_s,
  116. }
# ============================================================
# 5. Example invocation (for debugging only)
# ============================================================
if __name__ == "__main__":
    from data_to_rl import ResetStatePoolLoader

    # --------------------------------------------------------
    # Model path (as written by Trainer.save())
    # --------------------------------------------------------
    MODEL_PATH = Path(
        "models/uf-rl/model_result/uf_dqn_tensorboard/xxx/dqn_model.zip"
    )
    # --------------------------------------------------------
    # Reset state pool
    # --------------------------------------------------------
    RESET_STATE_CSV = Path(
        "datasets/rl_ready/output/reset_state_pool.csv"
    )
    loader = ResetStatePoolLoader(
        csv_path=RESET_STATE_CSV,
        train_ratio=0.8,
        shuffle=False,
        random_state=2025,
    )
    # split() presumably yields (train_pool, val_pool) — verify against
    # ResetStatePoolLoader; only the validation pool is used here.
    _, val_pool = loader.split()
    # --------------------------------------------------------
    # Initialize the decider
    # --------------------------------------------------------
    # NOTE(review): this call does not match UFDQNDecider.__init__ as
    # defined above — __init__ requires physics/action_spec/reward_params/
    # state_bounds and declares no `reset_state_pool` parameter, so this
    # raises TypeError as written. Confirm the intended constructor
    # signature before using this demo.
    decider = UFDQNDecider(
        model_path=MODEL_PATH,
        reset_state_pool=val_pool,
    )
    # --------------------------------------------------------
    # Run one decision
    # --------------------------------------------------------
    decision = decider.decide()
    print("===== DQN 决策结果 =====")
    print(f"Action ID : {decision['action_id']}")
    print(f"L_s : {decision['L_s']} s")
    print(f"t_bw_s : {decision['t_bw_s']} s")