dqn_decider.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """
  2. UF 超滤系统 DQN 决策脚本(与当前 DQNTrainer 严格对齐)
  3. 功能定位:
  4. - 加载已训练好的 DQN 模型
  5. - 构造与训练阶段完全一致的环境
  6. - 执行单步动作推理(predict)
  7. - 输出模型建议的工程动作参数(L_s, t_bw_s)
  8. 注意:
  9. - 本脚本【不 step 环境】
  10. - 不计算 reward
  11. - 不进行 episode rollout
  12. """
  13. from pathlib import Path
  14. import numpy as np
  15. # ============================================================
  16. # 1. UF 环境与物理模型
  17. # ============================================================
  18. from env.uf_env import UFSuperCycleEnv
  19. from env.env_params import (
  20. UFRewardParams,
  21. UFActionSpec,
  22. UFStateBounds,
  23. )
  24. # ============================================================
  25. # 2. Stable-Baselines3
  26. # ============================================================
  27. from stable_baselines3 import DQN
  28. # ============================================================
  29. # 3. DQN 决策器
  30. # ============================================================
  31. class UFDQNDecider:
  32. """
  33. UF 超滤 DQN 决策器(Inference Only)
  34. 设计原则:
  35. 1. 与训练环境参数级一致
  36. 2. 决策侧不推进环境
  37. 3. 不依赖 Trainer 内部状态
  38. """
  39. def __init__(
  40. self,
  41. physics,
  42. model_path,
  43. seed: int = 0,
  44. ):
  45. """
  46. Parameters
  47. ----------
  48. model_path
  49. dqn_model.zip 的路径
  50. reset_state_pool :
  51. ResetStatePoolLoader.split() 得到的 pool(train / val 均可)
  52. seed : int
  53. 随机种子(推理阶段主要用于 env.reset)
  54. """
  55. self.action_spec = UFActionSpec()
  56. reward_params = UFRewardParams()
  57. state_bounds = UFStateBounds()
  58. self.env = UFSuperCycleEnv(
  59. physics=physics,
  60. reward_params=reward_params,
  61. action_spec=self.action_spec,
  62. statebounds=state_bounds,
  63. real_state_pool=None,
  64. RANDOM_SEED=seed,
  65. )
  66. model_path = Path(model_path)
  67. if not model_path.exists():
  68. raise FileNotFoundError(f"DQN 模型不存在: {model_path}")
  69. self.model = DQN.load(
  70. path=str(model_path),
  71. env=self.env, # ⚠ 必须提供 env
  72. )
  73. # ========================================================
  74. # 对外决策接口
  75. # ========================================================
  76. def decide(self, state: np.ndarray | None = None) -> dict:
  77. """
  78. 单步决策(不 step 环境)
  79. Parameters
  80. ----------
  81. state : np.ndarray | None
  82. - None:env.reset() 从 reset_state_pool 抽样状态
  83. - 非 None:使用外部系统提供的状态
  84. Returns
  85. -------
  86. dict
  87. {
  88. "action_id": int,
  89. "L_s": float,
  90. "t_bw_s": float,
  91. }
  92. """
  93. # ----------------------------------------------------
  94. # 4.1 获取观测状态
  95. # ----------------------------------------------------
  96. if state is None:
  97. obs = self.env.reset()
  98. else:
  99. obs = self.env.get_obs(state) # 获取归一化状态作为策略网络输入
  100. # ----------------------------------------------------
  101. # 4.2 DQN 推理(确定性)
  102. # ----------------------------------------------------
  103. action, _ = self.model.predict(obs, deterministic=True)
  104. action_id = int(action)
  105. # ----------------------------------------------------
  106. # 4.3 动作解码(工程语义)
  107. # ----------------------------------------------------
  108. L_s, t_bw_s = self.env.get_action_values(action_id)
  109. return {
  110. "action_id": action_id,
  111. "L_s": L_s,
  112. "t_bw_s": t_bw_s,
  113. }
  114. # ============================================================
  115. # 5. 示例调用(调试用)
  116. # ============================================================
if __name__ == "__main__":
    # Debug-only example: load a reset-state pool, build a decider, run one
    # decision and print the result.
    from data_to_rl import ResetStatePoolLoader
    # --------------------------------------------------------
    # Model path (produced by Trainer.save()); "xxx" looks like a placeholder
    # for the run directory -- replace before running.
    # --------------------------------------------------------
    MODEL_PATH = Path(
        "models/uf-rl/model_result/uf_dqn_tensorboard/xxx/dqn_model.zip"
    )
    # --------------------------------------------------------
    # Reset state pool
    # --------------------------------------------------------
    RESET_STATE_CSV = Path(
        "datasets/rl_ready/output/reset_state_pool.csv"
    )
    loader = ResetStatePoolLoader(
        csv_path=RESET_STATE_CSV,
        train_ratio=0.8,
        shuffle=False,
        random_state=2025,
    )
    # Only the second element of the split is used (named as the validation
    # pool here).
    _, val_pool = loader.split()
    # --------------------------------------------------------
    # Initialise the decider
    # --------------------------------------------------------
    # NOTE(review): verify this call against UFDQNDecider.__init__ -- the
    # constructor takes a required `physics` argument that is not supplied
    # here, so the call fails with TypeError until a physics model is passed.
    # Also confirm that `reset_state_pool` is an accepted keyword.
    decider = UFDQNDecider(
        model_path=MODEL_PATH,
        reset_state_pool=val_pool,
    )
    # --------------------------------------------------------
    # Run a single decision (state=None, so the env is reset to get an
    # observation)
    # --------------------------------------------------------
    decision = decider.decide()
    print("===== DQN 决策结果 =====")
    print(f"Action ID : {decision['action_id']}")
    print(f"L_s : {decision['L_s']} s")
    print(f"t_bw_s : {decision['t_bw_s']} s")