run_dqn_train.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. """
  2. DQN 超滤强化学习训练与测试主脚本(工程化优化版)
  3. """
  4. import os
  5. import sys
  6. import random
  7. from pathlib import Path
  8. import numpy as np
  9. import torch
  10. # ============================================================
  11. # 1. 导入模块
  12. # ============================================================
  13. CURRENT_DIR = Path(__file__).resolve().parent
  14. PROJECT_ROOT = CURRENT_DIR.parents[2] # uf_train # uf-rl
  15. # ---------- 数据 ----------
  16. from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
  17. # ---------- 阻力模型 ----------
  18. from uf_train.env.uf_resistance_models_load import load_resistance_models
  19. from uf_train.env.uf_physics import UFPhysicsModel
  20. # ---------- 强化学习环境 ----------
  21. from uf_train.env.env_params import (UFState,UFPhysicsParams,UFActionSpec,UFRewardParams,UFStateBounds)
  22. from uf_train.env.env_config_loader import EnvConfigLoader, create_env_params_from_yaml
  23. from uf_train.env.uf_env import UFSuperCycleEnv
  24. from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
  25. from uf_train.rl_model.DQN.dqn_params import DQNParams
  26. from uf_train.rl_model.DQN.dqn_config_loader import DQNConfigLoader, load_dqn_config_with_validation
  27. from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
  28. # ---------- SB3 ----------
  29. from stable_baselines3.common.monitor import Monitor
  30. from stable_baselines3.common.vec_env import DummyVecEnv
  31. # ============================================================
  32. # 随机种子
  33. # ============================================================
  34. def set_global_seed(seed: int):
  35. random.seed(seed)
  36. np.random.seed(seed)
  37. torch.manual_seed(seed)
  38. torch.cuda.manual_seed_all(seed)
  39. torch.backends.cudnn.deterministic = True
  40. torch.backends.cudnn.benchmark = False
  41. print(f"[Seed] Global random seed = {seed}")
  42. # ============================================================
  43. # 4. Reset State Pool 加载与划分
  44. # ============================================================
  45. def load_reset_state_pools():
  46. loader = ResetStatePoolLoader(
  47. csv_path=RESET_STATE_CSV,
  48. train_ratio=0.8,
  49. shuffle=True,
  50. random_state=RANDOM_SEED,
  51. )
  52. train_pool, val_pool = loader.split()
  53. print("[Data] Reset state pool loaded")
  54. print(f" Train pool size: {len(train_pool)}")
  55. print(f" Val pool size: {len(val_pool)}")
  56. return train_pool, val_pool
  57. # ============================================================
  58. # 5. 环境构造函数
  59. # ============================================================
  60. def make_env(
  61. physics: UFPhysicsModel,
  62. reward_params: UFRewardParams,
  63. action_spec: UFActionSpec,
  64. statebounds: UFStateBounds,
  65. reset_state_pool,
  66. seed: int,
  67. ):
  68. def _init():
  69. env = UFSuperCycleEnv(
  70. physics=physics,
  71. reward_params=reward_params,
  72. action_spec=action_spec,
  73. statebounds=statebounds,
  74. real_state_pool=reset_state_pool,
  75. RANDOM_SEED=seed,
  76. )
  77. env.action_space.seed(seed)
  78. env.observation_space.seed(seed)
  79. return Monitor(env)
  80. return _init
  81. # ============================================================
  82. # 6. 主流程
  83. # ============================================================
def main():
    """End-to-end DQN workflow: load configs, build envs, train, validate, plot.

    Reads module-level globals (ENV_CONFIG_PATH, MODEL_CONFIG_PATH,
    RANDOM_SEED, TOTAL_TIMESTEPS, IS_TIMES) assigned in the ``__main__``
    guard, so it must run as a script entry point.
    """
    # Build the environment-config loader and sanity-check the YAML file.
    config_loader = EnvConfigLoader(ENV_CONFIG_PATH)
    config_loader.validate_config()
    config_loader.print_config_summary()
    # Materialize all environment parameter objects from the YAML config.
    (
        uf_state_default,  # UFState defaults (usable for reset)
        phys_params,       # UFPhysicsParams
        action_spec,       # UFActionSpec
        reward_params,     # UFRewardParams
        state_bounds       # UFStateBounds
    ) = create_env_params_from_yaml(ENV_CONFIG_PATH)
    # ---------- Seed ----------
    set_global_seed(RANDOM_SEED)
    # ---------- Reset states ----------
    train_pool, val_pool = load_reset_state_pools()
    # ---------- Resistance models ----------
    res_fp, res_bw = load_resistance_models(phys_params)
    # ---------- Physics ----------
    physics_model = UFPhysicsModel(
        phys_params=phys_params,
        resistance_model_fp=res_fp,
        resistance_model_bw=res_bw,
        IS_TIMES=IS_TIMES
    )
    # ---------- Environments ----------
    # Separate single-env VecEnvs for training and validation; both use the
    # same seed but draw reset states from disjoint pools.
    train_env = DummyVecEnv([
        make_env(
            physics_model,
            reward_params,
            action_spec,
            state_bounds,
            train_pool,
            RANDOM_SEED,
        )
    ])
    val_env = DummyVecEnv([
        make_env(
            physics_model,
            reward_params,
            action_spec,
            state_bounds,
            val_pool,
            RANDOM_SEED,
        )
    ])
    # ---------- Callback ----------
    recorder = UFEpisodeRecorder()
    callback = UFTrainingCallback(recorder, verbose=1)
    # ---------- Trainer ----------
    # Load and validate the DQN hyper-parameter config.
    dqn_loader = DQNConfigLoader(MODEL_CONFIG_PATH)
    dqn_loader.validate_config()
    dqn_params = dqn_loader.load_params()
    dqn_loader.print_config_summary()
    trainer = DQNTrainer(
        env=train_env,
        params=dqn_params,
        callback=callback,
        PROJECT_ROOT=PROJECT_ROOT
    )
    # ---------- Training ----------
    print("\n Start training")
    trainer.train(total_timesteps=TOTAL_TIMESTEPS)
    trainer.save()
    # ========================================================
    # Validation
    # ========================================================
    print("\n[Eval] Start validation rollout")
    # De-normalization bounds for TMP (MPa), used only for plotting.
    # NOTE(review): these must match the env's internal normalization range
    # for the plotted values to be meaningful — confirm against UFStateBounds.
    TMP0_min = 0.01
    TMP0_max = 0.08
    rewards = []
    # Containers for visualization (only the first episode is recorded).
    vis_tmp_series = []
    vis_action_series = []
    # One rollout per validation reset state; episodes are capped at 10 steps.
    for ep_idx in range(len(val_pool)):
        obs = val_env.reset()
        episode_reward = 0.0
        for step in range(10):
            # ====================================================
            # Visualization: record only the first validation episode
            # ====================================================
            if ep_idx == 0:
                # NOTE(review): obs comes from a VecEnv, so obs[0] is the whole
                # observation vector of env 0. If TMP is a single component of
                # a multi-dimensional observation, this likely wants obs[0][0]
                # — confirm the observation layout in UFSuperCycleEnv.
                TMP0_norm = obs[0]
                TMP0 = (
                    TMP0_norm * (TMP0_max - TMP0_min)
                    + TMP0_min
                )
                vis_tmp_series.append(TMP0)
            # ---------------- Policy decision ----------------
            action, _ = trainer.model.predict(
                obs, deterministic=True
            )
            if ep_idx == 0:
                vis_action_series.append(action[0])
            # ---------------- Environment step ----------------
            obs, reward, done, _ = val_env.step(action)
            episode_reward += reward[0]  # single-env VecEnv: take element 0
            if done:
                break
        rewards.append(episode_reward)
    # ========================================================
    # Save validation results
    # ========================================================
    rewards = np.asarray(rewards)
    save_path = Path(trainer.log_dir) / "val_rewards.npy"
    np.save(save_path, rewards)
    print(f"[Eval] Saved to {save_path}")
    print(f"[Eval] Mean reward = {rewards.mean():.3f}")
    # ========================================================
    # Visualization (first validation episode)
    # ========================================================
    # Imported lazily so training-only runs don't require a display backend.
    import matplotlib.pyplot as plt
    vis_tmp_series = np.asarray(vis_tmp_series)
    vis_action_series = np.asarray(vis_action_series)
    steps = np.arange(len(vis_tmp_series))
    # ---------- TMP curve ----------
    plt.figure()
    plt.plot(steps, vis_tmp_series, marker="o")
    plt.axhline(
        TMP0_max,
        linestyle="--",
        label="TMP Upper Limit"
    )
    plt.xlabel("Step")
    plt.ylabel("TMP (MPa)")
    plt.title("Validation Episode TMP Evolution")
    plt.legend()
    plt.grid(True)
    plt.show()
    # ---------- Action curve ----------
    plt.figure()
    plt.plot(steps, vis_action_series, marker="o")
    plt.xlabel("Step")
    plt.ylabel("Action")
    plt.title("Validation Episode Action Output")
    plt.grid(True)
    plt.show()
  224. # ============================================================
  225. # 入口
  226. # ============================================================
  227. if __name__ == "__main__":
  228. # ============================================================
  229. # 2. 全局配置
  230. # ============================================================
  231. RANDOM_SEED = 2025
  232. TOTAL_TIMESTEPS = 300000
  233. IS_TIMES = True
  234. RESET_STATE_CSV = (
  235. PROJECT_ROOT
  236. / "datasets/UF_xishan_data/rl_ready/output/reset_state_pool.csv"
  237. )
  238. ENV_CONFIG_PATH = PROJECT_ROOT / "xishan" / "env_config.yaml"
  239. MODEL_CONFIG_PATH = PROJECT_ROOT / "xishan" / "dqn_config.yaml"
  240. main()