- """
- UF 超滤系统 DQN 决策脚本(与当前 DQNTrainer 严格对齐)
- 功能定位:
- - 加载已训练好的 DQN 模型
- - 构造与训练阶段完全一致的环境
- - 执行单步动作推理(predict)
- - 输出模型建议的工程动作参数(L_s, t_bw_s)
- 注意:
- - 本脚本【不 step 环境】
- - 不计算 reward
- - 不进行 episode rollout
- """
from pathlib import Path

import numpy as np

# ============================================================
# 1. UF environment and physics model
# ============================================================
from env.uf_env import UFSuperCycleEnv
from env.env_params import (
    UFRewardParams,
    UFActionSpec,
    UFStateBounds,
)

# ============================================================
# 2. Stable-Baselines3
# ============================================================
from stable_baselines3 import DQN
# ============================================================
# 3. DQN decider
# ============================================================
class UFDQNDecider:
    """
    UF ultrafiltration DQN decider (inference only).

    Design principles:
    1. Parameter-level consistency with the training environment
    2. The decision side never advances the environment
    3. No dependence on Trainer internal state
    """

    def __init__(
        self,
        physics,
        action_spec,
        reward_params,
        state_bounds,
        model_path,
        reset_state_pool=None,
        seed: int = 0,
    ):
        """
        Parameters
        ----------
        model_path
            Path to dqn_model.zip
        reset_state_pool
            Pool produced by ResetStatePoolLoader.split() (train or val both work)
        seed : int
            Random seed (at inference time, mainly used by env.reset)
        """
        self.action_spec = action_spec
        self.env = UFSuperCycleEnv(
            physics=physics,
            reward_params=reward_params,
            action_spec=self.action_spec,
            statebounds=state_bounds,
            real_state_pool=reset_state_pool,
            RANDOM_SEED=seed,
        )
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"DQN model not found: {model_path}")
        self.model = DQN.load(
            path=str(model_path),
            env=self.env,  # NOTE: the env must be provided
        )
- def decide(self, state: np.ndarray | None = None) -> dict:
- """
- 单步决策(不 step 环境)
- Parameters
- ----------
- state : np.ndarray | None
- - None:env.reset() 从 reset_state_pool 抽样状态
- - 非 None:使用外部系统提供的状态
- Returns
- -------
- dict
- {
- "action_id": int,
- "L_s": float,
- "t_bw_s": float,
- }
- """
- # ----------------------------------------------------
- # 4.1 获取观测状态
- # ----------------------------------------------------
- if state is None:
- obs = self.env.reset()
- else:
- obs = self.env.get_obs(state) # 获取归一化状态作为策略网络输入
- # ----------------------------------------------------
- # 4.2 DQN 推理(确定性)
- # ----------------------------------------------------
- action, _ = self.model.predict(obs, deterministic=True)
- action_id = int(action)
- # ----------------------------------------------------
- # 4.3 动作解码(工程语义)
- # ----------------------------------------------------
- L_s, t_bw_s = self.env.get_action_values(action_id)
- return {
- "action_id": action_id,
- "L_s": L_s,
- "t_bw_s": t_bw_s,
- }
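
# ------------------------------------------------------------
# Hedged sketch (not part of the original script): one common way a flat
# discrete action id maps to an (L_s, t_bw_s) pair is a row-major
# enumeration of the Cartesian product of two candidate grids. The real
# mapping lives in UFActionSpec / UFSuperCycleEnv.get_action_values; the
# grids used below are hypothetical examples only.
# ------------------------------------------------------------
def _decode_action_sketch(action_id: int, L_s_grid, t_bw_grid):
    """Decode a flat action index into (L_s, t_bw_s), row-major order."""
    i, j = divmod(action_id, len(t_bw_grid))
    return float(L_s_grid[i]), float(t_bw_grid[j])

# Example: 3 filtration lengths x 2 backwash durations -> 6 actions;
# action id 4 decodes to the third L_s and the first t_bw_s.
# _decode_action_sketch(4, [1800.0, 2400.0, 3000.0], [30.0, 60.0]) -> (3000.0, 30.0)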
# ============================================================
# 5. Example invocation (for debugging)
# ============================================================
if __name__ == "__main__":
    from data_to_rl import ResetStatePoolLoader

    # --------------------------------------------------------
    # Model path (from Trainer.save())
    # --------------------------------------------------------
    MODEL_PATH = Path(
        "models/uf-rl/model_result/uf_dqn_tensorboard/xxx/dqn_model.zip"
    )
    # --------------------------------------------------------
    # Reset state pool
    # --------------------------------------------------------
    RESET_STATE_CSV = Path(
        "datasets/rl_ready/output/reset_state_pool.csv"
    )
    loader = ResetStatePoolLoader(
        csv_path=RESET_STATE_CSV,
        train_ratio=0.8,
        shuffle=False,
        random_state=2025,
    )
    _, val_pool = loader.split()
    # --------------------------------------------------------
    # Initialize the decider
    # --------------------------------------------------------
    # NOTE: physics, action_spec, reward_params and state_bounds are required
    # and must match the configuration used during training; the default-
    # constructed specs below are placeholders for this debug example.
    decider = UFDQNDecider(
        physics=...,  # fill in the project-specific physics model
        action_spec=UFActionSpec(),
        reward_params=UFRewardParams(),
        state_bounds=UFStateBounds(),
        model_path=MODEL_PATH,
        reset_state_pool=val_pool,
    )
    # --------------------------------------------------------
    # Run one decision
    # --------------------------------------------------------
    decision = decider.decide()
    print("===== DQN decision result =====")
    print(f"Action ID : {decision['action_id']}")
    print(f"L_s       : {decision['L_s']} s")
    print(f"t_bw_s    : {decision['t_bw_s']} s")
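
# ------------------------------------------------------------
# Hedged sketch (hypothetical helper, not in the original script): tally
# which action ids the policy picks over a batch of externally supplied
# states, e.g. to sanity-check that the learned policy has not collapsed
# onto a single action. `decider` is any object exposing
# decide(state) -> {"action_id": ...}, such as UFDQNDecider above.
# ------------------------------------------------------------
def action_id_histogram(decider, states):
    """Count how often each action_id is chosen across `states`."""
    from collections import Counter
    return Counter(decider.decide(s)["action_id"] for s in states)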