run_dqn_train.py
  1. """
  2. DQN 超滤强化学习训练与测试主脚本(工程化优化版)
  3. """
  4. import random
  5. from pathlib import Path
  6. import numpy as np
  7. import torch
# ============================================================
# Imports
# ============================================================
  11. CURRENT_DIR = Path(__file__).resolve().parent
  12. PROJECT_ROOT = CURRENT_DIR.parents[2] # uf-rl
  13. # ---------- 数据 ----------
  14. from data_to_rl.data_splitter import ResetStatePoolLoader
  15. # ---------- 阻力模型 ----------
  16. from env.uf_resistance_models_load import load_resistance_models
  17. from env.uf_physics import UFPhysicsModel
  18. # ---------- 强化学习环境 ----------
  19. from env.env_params import (UFActionSpec, UFRewardParams, UFStateBounds)
  20. from env.env_config_loader import EnvConfigLoader, create_env_params_from_yaml
  21. from env.uf_env import UFSuperCycleEnv
  22. from env.env_visual import UFEpisodeRecorder, UFTrainingCallback
  23. from rl_model.DQN.dqn_model.dqn_config_loader import DQNConfigLoader
  24. from rl_model.DQN.uf_train.dqn_trainer import DQNTrainer
  25. # ---------- SB3 ----------
  26. from stable_baselines3.common.monitor import Monitor
  27. from stable_baselines3.common.vec_env import DummyVecEnv
# ============================================================
# Global random seed
# ============================================================
  31. def set_global_seed(seed: int):
  32. random.seed(seed)
  33. np.random.seed(seed)
  34. torch.manual_seed(seed)
  35. torch.cuda.manual_seed_all(seed)
  36. torch.backends.cudnn.deterministic = True
  37. torch.backends.cudnn.benchmark = False
  38. print(f"[Seed] Global random seed = {seed}")
# ============================================================
# Reset state pool: loading and train/val split
# ============================================================
  42. def load_reset_state_pools():
  43. loader = ResetStatePoolLoader(
  44. csv_path=RESET_STATE_CSV,
  45. train_ratio=0.8,
  46. shuffle=True,
  47. random_state=RANDOM_SEED,
  48. )
  49. train_pool, val_pool = loader.split()
  50. print("[Data] Reset state pool loaded")
  51. print(f" Train pool size: {len(train_pool)}")
  52. print(f" Val pool size: {len(val_pool)}")
  53. return train_pool, val_pool
# ============================================================
# Environment factory
# ============================================================
  57. def make_env(
  58. physics: UFPhysicsModel,
  59. reward_params: UFRewardParams,
  60. action_spec: UFActionSpec,
  61. statebounds: UFStateBounds,
  62. reset_state_pool,
  63. seed: int,
  64. ):
  65. def _init():
  66. env = UFSuperCycleEnv(
  67. physics=physics,
  68. reward_params=reward_params,
  69. action_spec=action_spec,
  70. statebounds=statebounds,
  71. real_state_pool=reset_state_pool,
  72. RANDOM_SEED=seed,
  73. )
  74. env.action_space.seed(seed)
  75. env.observation_space.seed(seed)
  76. return Monitor(env)
  77. return _init
# ============================================================
# Main procedure
# ============================================================
  81. def main():
  82. # 创建配置加载器
  83. config_loader = EnvConfigLoader(ENV_CONFIG_PATH)
  84. # 验证配置
  85. config_loader.validate_config()
  86. config_loader.print_config_summary()
  87. # 加载所有参数类
  88. (
  89. uf_state_default, # UFState默认值(可用于reset)
  90. phys_params, # UFPhysicsParams
  91. action_spec, # UFActionSpec
  92. reward_params, # UFRewardParams
  93. state_bounds # UFStateBounds
  94. ) = create_env_params_from_yaml(ENV_CONFIG_PATH)
  95. # ---------- Seed ----------
  96. set_global_seed(RANDOM_SEED)
  97. # ---------- Reset states ----------
  98. train_pool, val_pool = load_reset_state_pools()
  99. # ---------- Resistance models ----------
  100. res_fp, res_bw = load_resistance_models(phys_params)
  101. # ---------- Physics ----------
  102. physics_model = UFPhysicsModel(
  103. phys_params=phys_params,
  104. resistance_model_fp=res_fp,
  105. resistance_model_bw=res_bw,
  106. IS_TIMES=IS_TIMES
  107. )
  108. # ---------- Environments ----------
  109. train_env = DummyVecEnv([
  110. make_env(
  111. physics_model,
  112. reward_params,
  113. action_spec,
  114. state_bounds,
  115. train_pool,
  116. RANDOM_SEED,
  117. )
  118. ])
  119. val_env = DummyVecEnv([
  120. make_env(
  121. physics_model,
  122. reward_params,
  123. action_spec,
  124. state_bounds,
  125. val_pool,
  126. RANDOM_SEED,
  127. )
  128. ])
  129. # ---------- Callback ----------
  130. recorder = UFEpisodeRecorder()
  131. callback = UFTrainingCallback(recorder, verbose=1)
  132. # ---------- Trainer ----------
  133. # ========== 2. 加载DQN配置 ==========
  134. dqn_loader = DQNConfigLoader(MODEL_CONFIG_PATH)
  135. dqn_loader.validate_config()
  136. dqn_params = dqn_loader.load_params()
  137. dqn_loader.print_config_summary()
  138. trainer = DQNTrainer(
  139. env=train_env,
  140. params=dqn_params,
  141. callback=callback,
  142. PROJECT_ROOT=PROJECT_ROOT
  143. )
  144. # ---------- Training ----------
  145. print("\n Start training")
  146. trainer.train(total_timesteps=TOTAL_TIMESTEPS)
  147. trainer.save()
  148. # ========================================================
  149. # 验证
  150. # ========================================================
  151. print("\n[Eval] Start validation rollout")
  152. TMP0_min = 0.01
  153. TMP0_max = 0.08
  154. rewards = []
  155. # ---------- 用于可视化的容器(只记录第一个 episode) ----------
  156. vis_tmp_series = []
  157. vis_action_series = []
  158. for ep_idx in range(len(val_pool)):
  159. obs = val_env.reset()
  160. episode_reward = 0.0
  161. for step in range(10):
  162. # ====================================================
  163. # 可视化:只记录第一个 validation episode
  164. # ====================================================
  165. if ep_idx == 0:
  166. TMP0_norm = obs[0]
  167. TMP0 = (
  168. TMP0_norm * (TMP0_max - TMP0_min)
  169. + TMP0_min
  170. )
  171. vis_tmp_series.append(TMP0)
  172. # ---------------- 策略决策 ----------------
  173. action, _ = trainer.model.predict(
  174. obs, deterministic=True
  175. )
  176. if ep_idx == 0:
  177. vis_action_series.append(action[0])
  178. # ---------------- 环境推进 ----------------
  179. obs, reward, done, _ = val_env.step(action)
  180. episode_reward += reward[0]
  181. if done:
  182. break
  183. rewards.append(episode_reward)
  184. # ========================================================
  185. # 验证结果保存
  186. # ========================================================
  187. rewards = np.asarray(rewards)
  188. save_path = Path(trainer.log_dir) / "val_rewards.npy"
  189. np.save(save_path, rewards)
  190. print(f"[Eval] Saved to {save_path}")
  191. print(f"[Eval] Mean reward = {rewards.mean():.3f}")
  192. # ========================================================
  193. # 可视化(第一个 validation episode)
  194. # ========================================================
  195. import matplotlib.pyplot as plt
  196. vis_tmp_series = np.asarray(vis_tmp_series)
  197. vis_action_series = np.asarray(vis_action_series)
  198. steps = np.arange(len(vis_tmp_series))
  199. # ---------- TMP 曲线 ----------
  200. plt.figure()
  201. plt.plot(steps, vis_tmp_series, marker="o")
  202. plt.axhline(
  203. TMP0_max,
  204. linestyle="--",
  205. label="TMP Upper Limit"
  206. )
  207. plt.xlabel("Step")
  208. plt.ylabel("TMP (MPa)")
  209. plt.title("Validation Episode TMP Evolution")
  210. plt.legend()
  211. plt.grid(True)
  212. plt.show()
  213. # ---------- Action 曲线 ----------
  214. plt.figure()
  215. plt.plot(steps, vis_action_series, marker="o")
  216. plt.xlabel("Step")
  217. plt.ylabel("Action")
  218. plt.title("Validation Episode Action Output")
  219. plt.grid(True)
  220. plt.show()
# ============================================================
# Entry point
# ============================================================
  224. if __name__ == "__main__":
  225. # ============================================================
  226. # 2. 全局配置
  227. # ============================================================
  228. RANDOM_SEED = 2025
  229. TOTAL_TIMESTEPS = 300000
  230. IS_TIMES = False
  231. RESET_STATE_CSV = (
  232. PROJECT_ROOT
  233. / "datasets/UF_longting_data/rl_ready/output/reset_state_pool.csv"
  234. )
  235. ENV_CONFIG_PATH = PROJECT_ROOT / "longting" / "env_config.yaml"
  236. MODEL_CONFIG_PATH = PROJECT_ROOT / "longting" / "dqn_config.yaml"
  237. main()