# run_dqn_train.py
"""
DQN ultrafiltration reinforcement-learning training and evaluation main
script (engineering-hardened version).
"""
import os
import sys
import random
from pathlib import Path
import numpy as np
import torch
# ============================================================
# 1. Module imports and path resolution
# ============================================================
CURRENT_DIR = Path(__file__).resolve().parent
# Two levels above this script: the uf_train / uf-rl repository root.
PROJECT_ROOT = CURRENT_DIR.parents[2]  # uf_train # uf-rl
# ---------- Data ----------
from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
# ---------- Resistance models ----------
from uf_train.env.uf_resistance_models_load import load_resistance_models
from uf_train.env.uf_physics import UFPhysicsModel
from uf_train.env.env_params import UFPhysicsParams, UFStateBounds, UFRewardParams, UFActionSpec
from uf_train.env.uf_env import UFSuperCycleEnv
from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
from uf_train.rl_model.DQN.dqn_params import DQNParams
from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
# ---------- SB3 ----------
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
  31. def set_global_seed(seed: int):
  32. random.seed(seed)
  33. np.random.seed(seed)
  34. torch.manual_seed(seed)
  35. torch.cuda.manual_seed_all(seed)
  36. torch.backends.cudnn.deterministic = True
  37. torch.backends.cudnn.benchmark = False
  38. print(f"[Seed] Global random seed = {seed}")
  39. # ============================================================
  40. # 4. Reset State Pool 加载与划分
  41. # ============================================================
  42. def load_reset_state_pools():
  43. loader = ResetStatePoolLoader(
  44. csv_path=RESET_STATE_CSV,
  45. train_ratio=0.8,
  46. shuffle=True,
  47. random_state=RANDOM_SEED,
  48. )
  49. train_pool, val_pool = loader.split()
  50. print("[Data] Reset state pool loaded")
  51. print(f" Train pool size: {len(train_pool)}")
  52. print(f" Val pool size: {len(val_pool)}")
  53. return train_pool, val_pool
  54. # ============================================================
  55. # 5. 环境构造函数
  56. # ============================================================
  57. def make_env(
  58. physics: UFPhysicsModel,
  59. reward_params: UFRewardParams,
  60. action_spec: UFActionSpec,
  61. statebounds: UFStateBounds,
  62. reset_state_pool,
  63. seed: int,
  64. ):
  65. def _init():
  66. env = UFSuperCycleEnv(
  67. physics=physics,
  68. reward_params=reward_params,
  69. action_spec=action_spec,
  70. statebounds=statebounds,
  71. real_state_pool=reset_state_pool,
  72. RANDOM_SEED=seed,
  73. )
  74. env.action_space.seed(seed)
  75. env.observation_space.seed(seed)
  76. return Monitor(env)
  77. return _init
# ============================================================
# 6. Main pipeline
# ============================================================
def main():
    """Train a DQN agent on the UF super-cycle env, then run a validation
    rollout, save the per-episode rewards, and plot the first episode.

    Reads module-level globals RANDOM_SEED, TOTAL_TIMESTEPS and (via
    load_reset_state_pools) RESET_STATE_CSV; these are assigned in the
    __main__ guard before main() is called.
    """
    # ---------- Seed ----------
    set_global_seed(RANDOM_SEED)
    # ---------- Reset states ----------
    train_pool, val_pool = load_reset_state_pools()
    # ---------- Resistance models ----------
    phys_params = UFPhysicsParams()
    res_fp, res_bw = load_resistance_models(phys_params)
    # ---------- Physics ----------
    physics_model = UFPhysicsModel(
        phys_params=phys_params,
        resistance_model_fp=res_fp,
        resistance_model_bw=res_bw,
    )
    # ---------- RL specs ----------
    reward_params = UFRewardParams()
    action_spec = UFActionSpec()
    state_bounds = UFStateBounds()
    # ---------- Environments ----------
    # Single-process vectorized envs; train/val differ only in reset pool.
    train_env = DummyVecEnv([
        make_env(
            physics_model,
            reward_params,
            action_spec,
            state_bounds,
            train_pool,
            RANDOM_SEED,
        )
    ])
    val_env = DummyVecEnv([
        make_env(
            physics_model,
            reward_params,
            action_spec,
            state_bounds,
            val_pool,
            RANDOM_SEED,
        )
    ])
    # ---------- Callback ----------
    recorder = UFEpisodeRecorder()
    callback = UFTrainingCallback(recorder, verbose=1)
    # ---------- Trainer ----------
    dqn_params = DQNParams(remark="uf_dqn_real_reset")
    trainer = DQNTrainer(
        env=train_env,
        params=dqn_params,
        callback=callback,
        PROJECT_ROOT=PROJECT_ROOT
    )
    # ---------- Training ----------
    print("\n Start training")
    trainer.train(total_timesteps=TOTAL_TIMESTEPS)
    trainer.save()
    # ========================================================
    # Validation
    # ========================================================
    print("\n[Eval] Start validation rollout")
    # De-normalization bounds for TMP (transmembrane pressure), MPa.
    # NOTE(review): presumably these mirror UFStateBounds — confirm they
    # stay in sync with the env's normalization.
    TMP0_min = 0.01
    TMP0_max = 0.08
    rewards = []
    # ---------- Containers for visualization (first episode only) ----------
    vis_tmp_series = []
    vis_action_series = []
    # One rollout per validation reset state; episodes are capped at 10 steps.
    for ep_idx in range(len(val_pool)):
        obs = val_env.reset()
        episode_reward = 0.0
        for step in range(10):
            # ====================================================
            # Visualization: record only the first validation episode
            # ====================================================
            if ep_idx == 0:
                # NOTE(review): VecEnv observations are batched, so obs[0]
                # is the whole observation row for env 0, not a scalar —
                # confirm the intent is obs[0][0] (normalized TMP feature).
                TMP0_norm = obs[0]
                # De-normalize back to physical units (MPa).
                TMP0 = (
                    TMP0_norm * (TMP0_max - TMP0_min)
                    + TMP0_min
                )
                vis_tmp_series.append(TMP0)
            # ---------------- Policy decision ----------------
            action, _ = trainer.model.predict(
                obs, deterministic=True
            )
            if ep_idx == 0:
                vis_action_series.append(action[0])
            # ---------------- Environment step ----------------
            obs, reward, done, _ = val_env.step(action)
            episode_reward += reward[0]
            # NOTE(review): done is a length-1 array from the VecEnv;
            # truthiness works for a single env but done[0] is cleaner.
            if done:
                break
        rewards.append(episode_reward)
    # ========================================================
    # Save validation results
    # ========================================================
    rewards = np.asarray(rewards)
    save_path = Path(trainer.log_dir) / "val_rewards.npy"
    np.save(save_path, rewards)
    print(f"[Eval] Saved to {save_path}")
    print(f"[Eval] Mean reward = {rewards.mean():.3f}")
    # ========================================================
    # Visualization (first validation episode)
    # ========================================================
    # Imported lazily so headless training runs do not need matplotlib.
    import matplotlib.pyplot as plt
    vis_tmp_series = np.asarray(vis_tmp_series)
    vis_action_series = np.asarray(vis_action_series)
    steps = np.arange(len(vis_tmp_series))
    # ---------- TMP curve ----------
    plt.figure()
    plt.plot(steps, vis_tmp_series, marker="o")
    plt.axhline(
        TMP0_max,
        linestyle="--",
        label="TMP Upper Limit"
    )
    plt.xlabel("Step")
    plt.ylabel("TMP (MPa)")
    plt.title("Validation Episode TMP Evolution")
    plt.legend()
    plt.grid(True)
    plt.show()
    # ---------- Action curve ----------
    plt.figure()
    plt.plot(steps, vis_action_series, marker="o")
    plt.xlabel("Step")
    plt.ylabel("Action")
    plt.title("Validation Episode Action Output")
    plt.grid(True)
    plt.show()
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    # ============================================================
    # 2. Global configuration
    # These assignments are module-level globals (the if-block is at
    # module scope), so main() and load_reset_state_pools() can read
    # them when invoked below.
    # ============================================================
    RANDOM_SEED = 2025
    TOTAL_TIMESTEPS = 1500000
    RESET_STATE_CSV = (
        PROJECT_ROOT
        / "datasets/rl_ready/output/reset_state_pool.csv"
    )
    main()