- """
- DQN 超滤强化学习训练与测试主脚本(工程化优化版)
- """
- import random
- from pathlib import Path
- import numpy as np
- import torch
- # ============================================================
- # 1. 导入模块
- # ============================================================
# Resolve filesystem anchors from this file's on-disk location.
CURRENT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_DIR.parents[2]  # three levels up -> uf-rl repository root
- # ---------- 数据 ----------
- from data_to_rl.data_splitter import ResetStatePoolLoader
- # ---------- 阻力模型 ----------
- from env.uf_resistance_models_load import load_resistance_models
- from env.uf_physics import UFPhysicsModel
- # ---------- 强化学习环境 ----------
- from env.env_params import (UFActionSpec, UFRewardParams, UFStateBounds)
- from env.env_config_loader import EnvConfigLoader, create_env_params_from_yaml
- from env.uf_env import UFSuperCycleEnv
- from env.env_visual import UFEpisodeRecorder, UFTrainingCallback
- from rl_model.DQN.dqn_model.dqn_config_loader import DQNConfigLoader
- from rl_model.DQN.uf_train.dqn_trainer import DQNTrainer
- # ---------- SB3 ----------
- from stable_baselines3.common.monitor import Monitor
- from stable_baselines3.common.vec_env import DummyVecEnv
- # ============================================================
- # 随机种子
- # ============================================================
def set_global_seed(seed: int):
    """Seed every RNG used in training (python, numpy, torch) with *seed*.

    Also pins cuDNN to deterministic kernels and disables its autotuner so
    repeated runs with the same seed reproduce the same trajectories.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[Seed] Global random seed = {seed}")
- # ============================================================
- # 4. Reset State Pool 加载与划分
- # ============================================================
def load_reset_state_pools(csv_path=None, train_ratio: float = 0.8,
                           shuffle: bool = True, seed=None):
    """Load the reset-state pool CSV and split it into train/validation pools.

    Previously the CSV path, split ratio and random state were hard-coded /
    pulled from module globals; they are now parameters with backward-compatible
    defaults, so calling with no arguments behaves exactly as before.

    Args:
        csv_path: path to the reset-state CSV; defaults to the module-level
            ``RESET_STATE_CSV`` (set in the ``__main__`` block).
        train_ratio: fraction of rows assigned to the training pool.
        shuffle: whether rows are shuffled before splitting.
        seed: random state for the split; defaults to the module-level
            ``RANDOM_SEED``.

    Returns:
        Tuple ``(train_pool, val_pool)`` as produced by
        ``ResetStatePoolLoader.split()``.
    """
    loader = ResetStatePoolLoader(
        csv_path=RESET_STATE_CSV if csv_path is None else csv_path,
        train_ratio=train_ratio,
        shuffle=shuffle,
        random_state=RANDOM_SEED if seed is None else seed,
    )
    train_pool, val_pool = loader.split()
    print("[Data] Reset state pool loaded")
    print(f" Train pool size: {len(train_pool)}")
    print(f" Val pool size: {len(val_pool)}")
    return train_pool, val_pool
- # ============================================================
- # 5. 环境构造函数
- # ============================================================
def make_env(
    physics: UFPhysicsModel,
    reward_params: UFRewardParams,
    action_spec: UFActionSpec,
    statebounds: UFStateBounds,
    reset_state_pool,
    seed: int,
):
    """Build a zero-argument env factory suitable for ``DummyVecEnv``.

    The returned thunk constructs a ``UFSuperCycleEnv`` with the given physics
    model, reward/action/state specs and reset-state pool, seeds both gym
    spaces for reproducibility, and wraps the env in an SB3 ``Monitor``.
    """
    def _init():
        raw_env = UFSuperCycleEnv(
            physics=physics,
            reward_params=reward_params,
            action_spec=action_spec,
            statebounds=statebounds,
            real_state_pool=reset_state_pool,
            RANDOM_SEED=seed,
        )
        # Seed both spaces so sampled actions/observations are reproducible.
        for space in (raw_env.action_space, raw_env.observation_space):
            space.seed(seed)
        return Monitor(raw_env)

    return _init
- # ============================================================
- # 6. 主流程
- # ============================================================
# De-normalization range for TMP readings (MPa) used during validation.
# Previously duplicated as in-function magic numbers TMP0_min / TMP0_max.
_TMP0_MIN = 0.01
_TMP0_MAX = 0.08
# Safety cap on steps per validation episode (was a bare literal ``10``).
_MAX_VAL_STEPS = 10


def _run_validation(trainer, val_env, n_episodes):
    """Roll out the trained policy on the validation vec-env.

    Runs ``n_episodes`` episodes (each capped at ``_MAX_VAL_STEPS`` steps)
    with deterministic actions.

    Returns:
        ``(rewards, vis_tmp_series, vis_action_series)`` — per-episode total
        rewards, plus TMP and action traces recorded for the first episode
        only (for plotting).
    """
    rewards = []
    vis_tmp_series = []
    vis_action_series = []
    for ep_idx in range(n_episodes):
        obs = val_env.reset()
        episode_reward = 0.0
        for _ in range(_MAX_VAL_STEPS):
            if ep_idx == 0:
                # NOTE(review): obs from DummyVecEnv has shape (1, obs_dim),
                # so obs[0] is the whole observation vector. This matches the
                # original code but only yields a scalar TMP if obs_dim == 1
                # — confirm against UFSuperCycleEnv's observation space.
                TMP0_norm = obs[0]
                # De-normalize from [0, 1] back to MPa.
                vis_tmp_series.append(
                    TMP0_norm * (_TMP0_MAX - _TMP0_MIN) + _TMP0_MIN
                )
            # Greedy (deterministic) policy action.
            action, _ = trainer.model.predict(obs, deterministic=True)
            if ep_idx == 0:
                vis_action_series.append(action[0])
            obs, reward, done, _ = val_env.step(action)
            episode_reward += reward[0]
            if done:
                break
        rewards.append(episode_reward)
    return rewards, vis_tmp_series, vis_action_series


def _plot_validation(vis_tmp_series, vis_action_series):
    """Plot TMP evolution and policy actions for the first validation episode."""
    # Local import keeps matplotlib optional at module import time
    # (same placement as the original in-function import).
    import matplotlib.pyplot as plt

    vis_tmp_series = np.asarray(vis_tmp_series)
    vis_action_series = np.asarray(vis_action_series)
    steps = np.arange(len(vis_tmp_series))

    # ---------- TMP curve ----------
    plt.figure()
    plt.plot(steps, vis_tmp_series, marker="o")
    plt.axhline(_TMP0_MAX, linestyle="--", label="TMP Upper Limit")
    plt.xlabel("Step")
    plt.ylabel("TMP (MPa)")
    plt.title("Validation Episode TMP Evolution")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ---------- Action curve ----------
    plt.figure()
    plt.plot(steps, vis_action_series, marker="o")
    plt.xlabel("Step")
    plt.ylabel("Action")
    plt.title("Validation Episode Action Output")
    plt.grid(True)
    plt.show()


def main():
    """End-to-end DQN pipeline: load configs, build envs, train, validate, plot."""
    # ---------- Environment configuration ----------
    config_loader = EnvConfigLoader(ENV_CONFIG_PATH)
    config_loader.validate_config()
    config_loader.print_config_summary()
    (
        uf_state_default,  # UFState defaults (usable for reset)
        phys_params,       # UFPhysicsParams
        action_spec,       # UFActionSpec
        reward_params,     # UFRewardParams
        state_bounds,      # UFStateBounds
    ) = create_env_params_from_yaml(ENV_CONFIG_PATH)

    # ---------- Seed / data ----------
    set_global_seed(RANDOM_SEED)
    train_pool, val_pool = load_reset_state_pools()

    # ---------- Resistance + physics models ----------
    res_fp, res_bw = load_resistance_models(phys_params)
    physics_model = UFPhysicsModel(
        phys_params=phys_params,
        state_bounds=state_bounds,
        resistance_model_fp=res_fp,
        resistance_model_bw=res_bw,
        IS_TIMES=IS_TIMES,
    )

    # ---------- Vectorized train / validation environments ----------
    train_env = DummyVecEnv([
        make_env(physics_model, reward_params, action_spec,
                 state_bounds, train_pool, RANDOM_SEED)
    ])
    val_env = DummyVecEnv([
        make_env(physics_model, reward_params, action_spec,
                 state_bounds, val_pool, RANDOM_SEED)
    ])

    # ---------- Callback + trainer ----------
    recorder = UFEpisodeRecorder()
    callback = UFTrainingCallback(recorder, verbose=1)
    dqn_loader = DQNConfigLoader(MODEL_CONFIG_PATH)
    dqn_loader.validate_config()
    dqn_params = dqn_loader.load_params()
    dqn_loader.print_config_summary()
    trainer = DQNTrainer(
        env=train_env,
        params=dqn_params,
        callback=callback,
        PROJECT_ROOT=PROJECT_ROOT,
    )

    # ---------- Training ----------
    print("\n Start training")
    trainer.train(total_timesteps=TOTAL_TIMESTEPS)
    trainer.save()

    # ---------- Validation ----------
    print("\n[Eval] Start validation rollout")
    rewards, vis_tmp_series, vis_action_series = _run_validation(
        trainer, val_env, len(val_pool)
    )

    # ---------- Persist validation results ----------
    rewards = np.asarray(rewards)
    save_path = Path(trainer.log_dir) / "val_rewards.npy"
    np.save(save_path, rewards)
    print(f"[Eval] Saved to {save_path}")
    print(f"[Eval] Mean reward = {rewards.mean():.3f}")

    # ---------- Visualization (first validation episode) ----------
    _plot_validation(vis_tmp_series, vis_action_series)
- # ============================================================
- # 入口
- # ============================================================
if __name__ == "__main__":
    # ---------- Global run configuration ----------
    # Module-level names: read by main(), load_reset_state_pools(), etc.
    RANDOM_SEED = 2025
    TOTAL_TIMESTEPS = 300_000
    IS_TIMES = False
    RESET_STATE_CSV = (
        PROJECT_ROOT
        / "datasets/UF_anzhen_data/rl_ready/output/reset_state_pool.csv"
    )
    ENV_CONFIG_PATH = PROJECT_ROOT / "anzhen" / "env_config.yaml"
    MODEL_CONFIG_PATH = PROJECT_ROOT / "anzhen" / "dqn_config.yaml"

    main()
|