Explorar o código

debug: 修复了简单调用模型时动作空间、奖励参数和状态边界等配置未传入决策器的错误

junc_WHU hai 1 mes
pai
achega
35d9643ba4

+ 6 - 3
models/uf-rl/rl_model/DQN/uf_decide/dqn_decider.py

@@ -48,6 +48,9 @@ class UFDQNDecider:
     def __init__(
         self,
         physics,
+        action_spec,
+        reward_params,
+        state_bounds,
         model_path,
         seed: int = 0,
     ):
@@ -62,9 +65,9 @@ class UFDQNDecider:
             随机种子(推理阶段主要用于 env.reset)
         """
 
-        self.action_spec = UFActionSpec()
-        reward_params = UFRewardParams()
-        state_bounds = UFStateBounds()
+        self.action_spec = action_spec
+        reward_params = reward_params
+        state_bounds = state_bounds
 
         self.env = UFSuperCycleEnv(
             physics=physics,

+ 18 - 8
models/uf-rl/rl_model/DQN/uf_decide/run_dqn_decide.py

@@ -19,7 +19,7 @@ from pathlib import Path
 # ============================================================
 CURRENT_DIR = Path(__file__).resolve().parent
 
-PROJECT_ROOT = CURRENT_DIR.parents[2]     # uf_train  # uf-rl
+UF_RL_ROOT = CURRENT_DIR.parents[2]     # uf_train  # uf-rl
 
 # ========== 参数 / 物理 ==========
 from env.uf_resistance_models_load import load_resistance_models
@@ -45,7 +45,7 @@ def build_physics(IS_TIMES, phys_params):
     )
     return physics
 
-def generate_plc_instructions(current_L_s, current_t_bw_s, model_prev_L_s, model_prev_t_bw_s, model_L_s, model_t_bw_s):
+def generate_plc_instructions(action_spec,current_L_s, current_t_bw_s, model_prev_L_s, model_prev_t_bw_s, model_L_s, model_t_bw_s):
     """
     根据工厂当前值、模型上一轮决策值和模型当前轮决策值,生成PLC指令。
 
@@ -54,7 +54,7 @@ def generate_plc_instructions(current_L_s, current_t_bw_s, model_prev_L_s, model
        如果工厂当前值也为None,则返回None并提示错误。
     """
 
-    action_spec = UFActionSpec()
+    action_spec = action_spec
     adjustment_threshold = 1.0
 
     # 处理None值情况
@@ -191,7 +191,10 @@ def calc_uf_cycle_metrics(current_state, max_tmp_during_filtration, min_tmp_duri
 def run_dqn_decide(
     model_path: Path,
     physics,
-    # -------- 工厂当前值 --------
+    action_spec,
+    reward_params,
+    state_bounds,
+    # -------- 工厂当前值 --------
     current_state: UFState
 ):
     """
@@ -201,6 +204,9 @@ def run_dqn_decide(
     # 构造决策器
     decider = UFDQNDecider(
         physics=physics,
+        action_spec=action_spec,
+        reward_params=reward_params,
+        state_bounds=state_bounds,
         model_path=model_path,
         seed=0,
     )
@@ -219,19 +225,20 @@ def run_dqn_decide(
 # ==============================
 if __name__ == "__main__":
 
-    MODEL_PATH = PROJECT_ROOT / "xishan" / "48h_dqn_model.zip"
-    ENV_CONFIG_PATH = PROJECT_ROOT / "xishan" / "env_config.yaml"
+    MODEL_PATH = UF_RL_ROOT / "longting" / "48h_dqn_model.zip"
+    ENV_CONFIG_PATH = UF_RL_ROOT / "longting" / "env_config.yaml"
     TMP0 = 0.019  # 原始 TMP0
     q_UF = 300 # 进水流量
     temp = 20.0 #进水温度
     IS_TIMES = False # 新增指定变量,表示CEB间隔为时间控制/次数控制,T表示48次bw一次CEB,F表示48h一次CEB
 
     current_state = UFState(TMP=TMP0, q_UF=q_UF, temp=temp)
+
     config_loader = EnvConfigLoader(ENV_CONFIG_PATH)
     config_loader.validate_config()
     config_loader.print_config_summary()
     (
-        uf_state_default,  # UFState默认值(可用于reset)
+        uf_state_default,  # UFState默认值
         phys_params,  # UFPhysicsParams
         action_spec,  # UFActionSpec
         reward_params,  # UFRewardParams
@@ -243,6 +250,9 @@ if __name__ == "__main__":
     action_id, model_L_s, model_t_bw_s = run_dqn_decide(
         model_path=MODEL_PATH,
         physics=physics,
+        action_spec=action_spec,
+        reward_params=reward_params,
+        state_bounds=state_bounds,
         current_state=current_state,
     ) # 环境实例化,模型加载等功能放在UFDQNDecider类中
 
@@ -250,7 +260,7 @@ if __name__ == "__main__":
     current_t_bw_s = 40
     model_prev_L_s = 4040
     model_prev_t_bw_s = 60
-    L_s, t_bw_s = generate_plc_instructions(current_L_s, current_t_bw_s, model_prev_L_s, model_prev_t_bw_s, model_L_s,
+    L_s, t_bw_s = generate_plc_instructions(action_spec, current_L_s, current_t_bw_s, model_prev_L_s, model_prev_t_bw_s, model_L_s,
                                             model_t_bw_s)  # 获取模型下发指令
 
     L_s = 4100