run_dqn_train.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. """
  2. DQN 超滤强化学习训练与测试主脚本(工程化优化版)
  3. """
  4. import random
  5. from pathlib import Path
  6. import numpy as np
  7. import torch
# ============================================================
# 1. Module imports / path setup
# ============================================================
# Anchor all project paths on this file's location so the script works
# regardless of the current working directory.
CURRENT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_DIR.parents[2]  # uf-rl repository root (three levels up)
  13. # ---------- 数据 ----------
  14. from data_to_rl.data_splitter import ResetStatePoolLoader
  15. # ---------- 阻力模型 ----------
  16. from env.uf_resistance_models_load import load_resistance_models
  17. from env.uf_physics import UFPhysicsModel
  18. # ---------- 强化学习环境 ----------
  19. from env.env_params import (UFActionSpec, UFRewardParams, UFStateBounds)
  20. from env.env_config_loader import EnvConfigLoader, create_env_params_from_yaml
  21. from env.uf_env import UFSuperCycleEnv
  22. from env.env_visual import UFEpisodeRecorder, UFTrainingCallback
  23. from rl_model.DQN.dqn_model.dqn_config_loader import DQNConfigLoader
  24. from rl_model.DQN.uf_train.dqn_trainer import DQNTrainer
  25. # ---------- SB3 ----------
  26. from stable_baselines3.common.monitor import Monitor
  27. from stable_baselines3.common.vec_env import DummyVecEnv
  28. # ============================================================
  29. # 随机种子
  30. # ============================================================
  31. def set_global_seed(seed: int):
  32. random.seed(seed)
  33. np.random.seed(seed)
  34. torch.manual_seed(seed)
  35. torch.cuda.manual_seed_all(seed)
  36. torch.backends.cudnn.deterministic = True
  37. torch.backends.cudnn.benchmark = False
  38. print(f"[Seed] Global random seed = {seed}")
  39. # ============================================================
  40. # 4. Reset State Pool 加载与划分
  41. # ============================================================
  42. def load_reset_state_pools():
  43. loader = ResetStatePoolLoader(
  44. csv_path=RESET_STATE_CSV,
  45. train_ratio=0.8,
  46. shuffle=True,
  47. random_state=RANDOM_SEED,
  48. )
  49. train_pool, val_pool = loader.split()
  50. print("[Data] Reset state pool loaded")
  51. print(f" Train pool size: {len(train_pool)}")
  52. print(f" Val pool size: {len(val_pool)}")
  53. return train_pool, val_pool
  54. # ============================================================
  55. # 5. 环境构造函数
  56. # ============================================================
  57. def make_env(
  58. physics: UFPhysicsModel,
  59. reward_params: UFRewardParams,
  60. action_spec: UFActionSpec,
  61. statebounds: UFStateBounds,
  62. reset_state_pool,
  63. seed: int,
  64. ):
  65. def _init():
  66. env = UFSuperCycleEnv(
  67. physics=physics,
  68. reward_params=reward_params,
  69. action_spec=action_spec,
  70. statebounds=statebounds,
  71. real_state_pool=reset_state_pool,
  72. RANDOM_SEED=seed,
  73. )
  74. env.action_space.seed(seed)
  75. env.observation_space.seed(seed)
  76. return Monitor(env)
  77. return _init
# ============================================================
# 6. Main pipeline
# ============================================================
def main():
    """End-to-end DQN pipeline: load configs, build envs, train, validate, plot.

    Relies on the module-level constants assigned under ``__main__``
    (RANDOM_SEED, TOTAL_TIMESTEPS, IS_TIMES, RESET_STATE_CSV,
    ENV_CONFIG_PATH, MODEL_CONFIG_PATH, DIR_NAME).
    """
    # ---------- Environment configuration ----------
    config_loader = EnvConfigLoader(ENV_CONFIG_PATH)
    # Validate the YAML before anything else so bad configs fail fast.
    config_loader.validate_config()
    config_loader.print_config_summary()
    # Load all parameter objects from the YAML file.
    (
        uf_state_default,  # UFState defaults (usable for reset); unused below
        phys_params,       # UFPhysicsParams
        action_spec,       # UFActionSpec
        reward_params,     # UFRewardParams
        state_bounds       # UFStateBounds
    ) = create_env_params_from_yaml(ENV_CONFIG_PATH)
    # ---------- Seed ----------
    set_global_seed(RANDOM_SEED)
    # ---------- Reset states ----------
    train_pool, val_pool = load_reset_state_pools()
    # ---------- Resistance models ----------
    # res_fp/res_bw: forward-pass and backwash resistance models
    # (exact semantics defined in env.uf_resistance_models_load).
    res_fp, res_bw = load_resistance_models(phys_params)
    # ---------- Physics ----------
    physics_model = UFPhysicsModel(
        phys_params=phys_params,
        state_bounds=state_bounds,
        resistance_model_fp=res_fp,
        resistance_model_bw=res_bw,
        IS_TIMES=IS_TIMES
    )
    # ---------- Environments ----------
    # Single-env DummyVecEnv for both training and validation; the same
    # physics model is shared, only the reset-state pool differs.
    train_env = DummyVecEnv([
        make_env(
            physics_model,
            reward_params,
            action_spec,
            state_bounds,
            train_pool,
            RANDOM_SEED,
        )
    ])
    val_env = DummyVecEnv([
        make_env(
            physics_model,
            reward_params,
            action_spec,
            state_bounds,
            val_pool,
            RANDOM_SEED,
        )
    ])
    # ---------- Callback ----------
    recorder = UFEpisodeRecorder()
    callback = UFTrainingCallback(recorder, verbose=1)
    # ---------- Trainer ----------
    # Load and validate the DQN hyperparameter config.
    dqn_loader = DQNConfigLoader(MODEL_CONFIG_PATH)
    dqn_loader.validate_config()
    dqn_params = dqn_loader.load_params()
    dqn_loader.print_config_summary()
    trainer = DQNTrainer(
        env=train_env,
        params=dqn_params,
        callback=callback,
        PROJECT_ROOT=PROJECT_ROOT,
        DIR_NAME=DIR_NAME,
    )
    # ---------- Training ----------
    print("\n Start training")
    trainer.train(total_timesteps=TOTAL_TIMESTEPS)
    trainer.save()
    # ========================================================
    # Validation
    # ========================================================
    print("\n[Eval] Start validation rollout")
    # De-normalization range for TMP (MPa).
    # NOTE(review): these duplicate whatever bounds the env config defines —
    # confirm against state_bounds rather than hard-coding here.
    TMP0_min = 0.01
    TMP0_max = 0.08
    rewards = []
    # ---------- Containers for visualization (first episode only) ----------
    vis_tmp_series = []
    vis_action_series = []
    # One rollout per validation reset state; each episode is capped
    # at 10 steps (or ends early when the env reports done).
    for ep_idx in range(len(val_pool)):
        obs = val_env.reset()
        episode_reward = 0.0
        for step in range(10):
            # ====================================================
            # Visualization: record only the first validation episode
            # ====================================================
            if ep_idx == 0:
                # NOTE(review): obs comes from a DummyVecEnv, so obs[0] is
                # the first env's observation *vector*, not necessarily a
                # scalar normalized TMP — if the observation has more than
                # one component this appends an array; verify whether
                # obs[0][0] was intended.
                TMP0_norm = obs[0]
                # Undo min-max normalization back to physical TMP (MPa).
                TMP0 = (
                    TMP0_norm * (TMP0_max - TMP0_min)
                    + TMP0_min
                )
                vis_tmp_series.append(TMP0)
            # ---------------- Policy decision ----------------
            # Greedy (deterministic) action from the trained model.
            action, _ = trainer.model.predict(
                obs, deterministic=True
            )
            if ep_idx == 0:
                vis_action_series.append(action[0])
            # ---------------- Environment step ----------------
            obs, reward, done, _ = val_env.step(action)
            # VecEnv returns per-env arrays; index 0 is our single env.
            episode_reward += reward[0]
            if done:
                break
        rewards.append(episode_reward)
    # ========================================================
    # Save validation results
    # ========================================================
    rewards = np.asarray(rewards)
    save_path = Path(trainer.log_dir) / "val_rewards.npy"
    np.save(save_path, rewards)
    print(f"[Eval] Saved to {save_path}")
    print(f"[Eval] Mean reward = {rewards.mean():.3f}")
    # ========================================================
    # Visualization (first validation episode)
    # ========================================================
    # Deferred import: matplotlib is only needed for this final plotting
    # step, so training machines without a display aren't forced to load it.
    import matplotlib.pyplot as plt
    vis_tmp_series = np.asarray(vis_tmp_series)
    vis_action_series = np.asarray(vis_action_series)
    steps = np.arange(len(vis_tmp_series))
    # ---------- TMP curve ----------
    plt.figure()
    plt.plot(steps, vis_tmp_series, marker="o")
    plt.axhline(
        TMP0_max,
        linestyle="--",
        label="TMP Upper Limit"
    )
    plt.xlabel("Step")
    plt.ylabel("TMP (MPa)")
    plt.title("Validation Episode TMP Evolution")
    plt.legend()
    plt.grid(True)
    plt.show()
    # ---------- Action curve ----------
    plt.figure()
    plt.plot(steps, vis_action_series, marker="o")
    plt.xlabel("Step")
    plt.ylabel("Action")
    plt.title("Validation Episode Action Output")
    plt.grid(True)
    plt.show()
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    # ============================================================
    # 2. Global configuration
    # ============================================================
    # These module-level constants are read as globals by main(),
    # load_reset_state_pools() and set_global_seed() above.
    RANDOM_SEED = 2025
    TOTAL_TIMESTEPS = 300000
    # Flag forwarded to UFPhysicsModel; semantics defined there.
    IS_TIMES = False
    RESET_STATE_CSV = (
        PROJECT_ROOT
        / "datasets/UF_longting_data/rl_ready/output/reset_state_pool.csv"
    )
    ENV_CONFIG_PATH = PROJECT_ROOT / "longting" / "env_config.yaml"
    MODEL_CONFIG_PATH = PROJECT_ROOT / "longting" / "dqn_config.yaml"
    # Output directory name passed to DQNTrainer for logs/checkpoints.
    DIR_NAME = "longting48h"
    main()