"""DQN ultrafiltration RL training & testing entry script (engineering-optimized)."""

import random
import sys
from pathlib import Path

import numpy as np
import torch

# ============================================================
# 1. Paths
# ============================================================
CURRENT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_DIR.parents[2]  # uf-rl repo root

# Make root-relative packages (data_to_rl, env, rl_model) importable no matter
# where the script is launched from.
# NOTE(review): the original computed PROJECT_ROOT but never put it on
# sys.path, so the imports below presumably relied on an externally set
# PYTHONPATH — confirm this insert matches how the project is actually run.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# ---------- Data ----------
from data_to_rl.data_splitter import ResetStatePoolLoader

# ---------- Resistance models ----------
from env.uf_resistance_models_load import load_resistance_models
from env.uf_physics import UFPhysicsModel

# ---------- RL environment ----------
from env.env_params import (UFActionSpec, UFRewardParams, UFStateBounds)
from env.env_config_loader import EnvConfigLoader, create_env_params_from_yaml
from env.uf_env import UFSuperCycleEnv
from env.env_visual import UFEpisodeRecorder, UFTrainingCallback

from rl_model.DQN.dqn_model.dqn_config_loader import DQNConfigLoader
from rl_model.DQN.uf_train.dqn_trainer import DQNTrainer

# ---------- SB3 ----------
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv


# ============================================================
# Global random seeding
# ============================================================
def set_global_seed(seed: int) -> None:
    """Seed every RNG used by training: python, numpy, torch (CPU + CUDA).

    Also forces cuDNN into deterministic mode (and disables its autotuner)
    so repeated runs with the same seed reproduce the same results, at some
    speed cost.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[Seed] Global random seed = {seed}")


# ============================================================
# 4. Reset state pool loading & split
def load_reset_state_pools(csv_path=None, train_ratio: float = 0.8, seed=None):
    """Load the reset-state CSV and split it into train / validation pools.

    Args:
        csv_path: path to the reset-state CSV; defaults to the module-level
            ``RESET_STATE_CSV`` (defined in the ``__main__`` section).
        train_ratio: fraction of rows assigned to the training pool
            (previously hard-coded to 0.8; default preserves that behavior).
        seed: shuffle seed; defaults to the module-level ``RANDOM_SEED``.

    Returns:
        ``(train_pool, val_pool)`` as produced by ``ResetStatePoolLoader.split()``.
    """
    loader = ResetStatePoolLoader(
        csv_path=RESET_STATE_CSV if csv_path is None else csv_path,
        train_ratio=train_ratio,
        shuffle=True,
        random_state=RANDOM_SEED if seed is None else seed,
    )
    train_pool, val_pool = loader.split()
    print("[Data] Reset state pool loaded")
    print(f" Train pool size: {len(train_pool)}")
    print(f" Val pool size: {len(val_pool)}")
    return train_pool, val_pool


# ============================================================
# 5. Environment factory
# ============================================================
def make_env(
    physics: UFPhysicsModel,
    reward_params: UFRewardParams,
    action_spec: UFActionSpec,
    statebounds: UFStateBounds,
    reset_state_pool,
    seed: int,
):
    """Return a zero-argument thunk that builds one monitored UF environment.

    The thunk signature is what SB3's ``DummyVecEnv`` expects. The action and
    observation spaces are seeded inside the thunk so space sampling stays
    reproducible per worker.
    """
    def _init():
        env = UFSuperCycleEnv(
            physics=physics,
            reward_params=reward_params,
            action_spec=action_spec,
            statebounds=statebounds,
            real_state_pool=reset_state_pool,
            RANDOM_SEED=seed,
        )
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return Monitor(env)

    return _init


# ============================================================
# 6. Main pipeline
# ============================================================

# Validation rollout settings: TMP normalization range used to de-normalize
# the first observation component for plotting, and the per-episode step cap.
VAL_TMP0_MIN = 0.01
VAL_TMP0_MAX = 0.08
VAL_MAX_STEPS = 10


def _run_validation(trainer, val_env, n_episodes):
    """Greedy (deterministic) rollout of the trained policy on ``val_env``.

    Args:
        trainer: object exposing ``model.predict(obs, deterministic=...)``.
        val_env: SB3 VecEnv (arrays of length n_envs from reset/step).
        n_episodes: number of validation episodes to run.

    Returns:
        ``(rewards, tmp_series, action_series)`` — per-episode reward array,
        plus TMP / action traces recorded for the FIRST episode only.
    """
    print("\n[Eval] Start validation rollout")
    rewards = []
    vis_tmp_series = []     # first-episode de-normalized TMP trace
    vis_action_series = []  # first-episode action trace

    for ep_idx in range(n_episodes):
        obs = val_env.reset()
        episode_reward = 0.0
        for _ in range(VAL_MAX_STEPS):
            if ep_idx == 0:
                # De-normalize the first observation component back to TMP.
                # NOTE(review): VecEnv obs has shape (n_envs, obs_dim), so
                # obs[0] is env 0's WHOLE observation vector. This matches the
                # original code and is only a scalar if obs_dim == 1 —
                # confirm, otherwise this should be obs[0][0].
                tmp0 = obs[0] * (VAL_TMP0_MAX - VAL_TMP0_MIN) + VAL_TMP0_MIN
                vis_tmp_series.append(tmp0)

            action, _ = trainer.model.predict(obs, deterministic=True)
            if ep_idx == 0:
                vis_action_series.append(action[0])

            obs, reward, done, _ = val_env.step(action)
            episode_reward += reward[0]
            # VecEnv returns arrays; index env 0 explicitly instead of
            # relying on truthiness of a one-element ndarray.
            if done[0]:
                break
        rewards.append(episode_reward)

    return (
        np.asarray(rewards),
        np.asarray(vis_tmp_series),
        np.asarray(vis_action_series),
    )


def _save_validation_rewards(trainer, rewards):
    """Persist per-episode validation rewards under the trainer's log dir."""
    save_path = Path(trainer.log_dir) / "val_rewards.npy"
    np.save(save_path, rewards)
    print(f"[Eval] Saved to {save_path}")
    # Guard: mean() of an empty array emits a warning and prints NaN.
    if rewards.size:
        print(f"[Eval] Mean reward = {rewards.mean():.3f}")


def _plot_first_episode(vis_tmp_series, vis_action_series):
    """Plot the TMP evolution and action output of the first validation episode."""
    # Local import: keeps headless training runs free of a hard
    # matplotlib dependency until plotting is actually needed.
    import matplotlib.pyplot as plt

    steps = np.arange(len(vis_tmp_series))

    # ---------- TMP curve ----------
    plt.figure()
    plt.plot(steps, vis_tmp_series, marker="o")
    plt.axhline(VAL_TMP0_MAX, linestyle="--", label="TMP Upper Limit")
    plt.xlabel("Step")
    plt.ylabel("TMP (MPa)")
    plt.title("Validation Episode TMP Evolution")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ---------- Action curve ----------
    plt.figure()
    plt.plot(steps, vis_action_series, marker="o")
    plt.xlabel("Step")
    plt.ylabel("Action")
    plt.title("Validation Episode Action Output")
    plt.grid(True)
    plt.show()


def main():
    """End-to-end pipeline: config → data → envs → DQN training → validation."""
    # ---------- Environment config ----------
    config_loader = EnvConfigLoader(ENV_CONFIG_PATH)
    config_loader.validate_config()
    config_loader.print_config_summary()

    (
        uf_state_default,  # UFState defaults (usable for reset)
        phys_params,       # UFPhysicsParams
        action_spec,       # UFActionSpec
        reward_params,     # UFRewardParams
        state_bounds,      # UFStateBounds
    ) = create_env_params_from_yaml(ENV_CONFIG_PATH)

    # ---------- Seed ----------
    set_global_seed(RANDOM_SEED)

    # ---------- Reset states ----------
    train_pool, val_pool = load_reset_state_pools()

    # ---------- Resistance models ----------
    res_fp, res_bw = load_resistance_models(phys_params)

    # ---------- Physics ----------
    physics_model = UFPhysicsModel(
        phys_params=phys_params,
        state_bounds=state_bounds,
        resistance_model_fp=res_fp,
        resistance_model_bw=res_bw,
        IS_TIMES=IS_TIMES,
    )

    # ---------- Environments (single-worker VecEnvs) ----------
    train_env = DummyVecEnv([
        make_env(physics_model, reward_params, action_spec,
                 state_bounds, train_pool, RANDOM_SEED)
    ])
    val_env = DummyVecEnv([
        make_env(physics_model, reward_params, action_spec,
                 state_bounds, val_pool, RANDOM_SEED)
    ])

    # ---------- Callback ----------
    recorder = UFEpisodeRecorder()
    callback = UFTrainingCallback(recorder, verbose=1)

    # ---------- DQN config & trainer ----------
    dqn_loader = DQNConfigLoader(MODEL_CONFIG_PATH)
    dqn_loader.validate_config()
    dqn_params = dqn_loader.load_params()
    dqn_loader.print_config_summary()

    trainer = DQNTrainer(
        env=train_env,
        params=dqn_params,
        callback=callback,
        PROJECT_ROOT=PROJECT_ROOT,
        DIR_NAME=DIR_NAME,
    )

    # ---------- Training ----------
    print("\n Start training")
    trainer.train(total_timesteps=TOTAL_TIMESTEPS)
    trainer.save()

    # ---------- Validation + persistence + plots ----------
    rewards, vis_tmp, vis_act = _run_validation(trainer, val_env, len(val_pool))
    _save_validation_rewards(trainer, rewards)
    _plot_first_episode(vis_tmp, vis_act)


# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    # --------------------------------------------------------
    # 2. Global run configuration — module-level on purpose:
    #    the functions above read these as globals.
    # --------------------------------------------------------
    RANDOM_SEED = 2025
    TOTAL_TIMESTEPS = 300000
    IS_TIMES = False

    RESET_STATE_CSV = (
        PROJECT_ROOT
        / "datasets/UF_longting_data/rl_ready/output/reset_state_pool.csv"
    )
    ENV_CONFIG_PATH = PROJECT_ROOT / "longting" / "env_config.yaml"
    MODEL_CONFIG_PATH = PROJECT_ROOT / "longting" / "dqn_config.yaml"
    DIR_NAME = "longting48h"

    main()