# run_dqn_train.py
"""
Main script for training and testing a DQN agent for ultrafiltration (UF)
reinforcement learning (engineering-optimized version).
"""
  4. import os
  5. import sys
  6. import random
  7. from pathlib import Path
  8. import numpy as np
  9. import torch
# ============================================================
# 1. Module imports
# ============================================================
  13. CURRENT_DIR = Path(__file__).resolve().parent
  14. PROJECT_ROOT = CURRENT_DIR.parents[2] # uf_train # uf-rl
# ---------- Data ----------
  16. from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
# ---------- Resistance models ----------
  18. from uf_train.env.uf_resistance_models_load import load_resistance_models
  19. from uf_train.env.uf_physics import UFPhysicsModel
  20. from uf_train.env.env_params import UFPhysicsParams, UFStateBounds, UFRewardParams, UFActionSpec
  21. from uf_train.env.uf_env import UFSuperCycleEnv
  22. from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
  23. from uf_train.rl_model.DQN.dqn_params import DQNParams
  24. from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
  25. # ---------- SB3 ----------
  26. from stable_baselines3.common.monitor import Monitor
  27. from stable_baselines3.common.vec_env import DummyVecEnv
# ============================================================
# Random seed
# ============================================================
  31. def set_global_seed(seed: int):
  32. random.seed(seed)
  33. np.random.seed(seed)
  34. torch.manual_seed(seed)
  35. torch.cuda.manual_seed_all(seed)
  36. torch.backends.cudnn.deterministic = True
  37. torch.backends.cudnn.benchmark = False
  38. print(f"[Seed] Global random seed = {seed}")
# ============================================================
# 4. Reset state pool loading and splitting
# ============================================================
  42. def load_reset_state_pools():
  43. loader = ResetStatePoolLoader(
  44. csv_path=RESET_STATE_CSV,
  45. train_ratio=0.8,
  46. shuffle=True,
  47. random_state=RANDOM_SEED,
  48. )
  49. train_pool, val_pool = loader.split()
  50. print("[Data] Reset state pool loaded")
  51. print(f" Train pool size: {len(train_pool)}")
  52. print(f" Val pool size: {len(val_pool)}")
  53. return train_pool, val_pool
# ============================================================
# 5. Environment factory
# ============================================================
  57. def make_env(
  58. physics: UFPhysicsModel,
  59. reward_params: UFRewardParams,
  60. action_spec: UFActionSpec,
  61. statebounds: UFStateBounds,
  62. reset_state_pool,
  63. seed: int,
  64. ):
  65. def _init():
  66. env = UFSuperCycleEnv(
  67. physics=physics,
  68. reward_params=reward_params,
  69. action_spec=action_spec,
  70. statebounds=statebounds,
  71. real_state_pool=reset_state_pool,
  72. RANDOM_SEED=seed,
  73. )
  74. env.action_space.seed(seed)
  75. env.observation_space.seed(seed)
  76. return Monitor(env)
  77. return _init
# ============================================================
# 6. Main pipeline
# ============================================================
  81. def main():
  82. # ---------- Seed ----------
  83. set_global_seed(RANDOM_SEED)
  84. # ---------- Reset states ----------
  85. train_pool, val_pool = load_reset_state_pools()
  86. # ---------- Resistance models ----------
  87. phys_params = UFPhysicsParams()
  88. res_fp, res_bw = load_resistance_models(phys_params)
  89. # ---------- Physics ----------
  90. physics_model = UFPhysicsModel(
  91. phys_params=phys_params,
  92. resistance_model_fp=res_fp,
  93. resistance_model_bw=res_bw,
  94. )
  95. # ---------- RL specs ----------
  96. reward_params = UFRewardParams()
  97. action_spec = UFActionSpec()
  98. state_bounds = UFStateBounds()
  99. # ---------- Environments ----------
  100. train_env = DummyVecEnv([
  101. make_env(
  102. physics_model,
  103. reward_params,
  104. action_spec,
  105. state_bounds,
  106. train_pool,
  107. RANDOM_SEED,
  108. )
  109. ])
  110. val_env = DummyVecEnv([
  111. make_env(
  112. physics_model,
  113. reward_params,
  114. action_spec,
  115. state_bounds,
  116. val_pool,
  117. RANDOM_SEED,
  118. )
  119. ])
  120. # ---------- Callback ----------
  121. recorder = UFEpisodeRecorder()
  122. callback = UFTrainingCallback(recorder, verbose=1)
  123. # ---------- Trainer ----------
  124. dqn_params = DQNParams(remark="uf_dqn_real_reset")
  125. trainer = DQNTrainer(
  126. env=train_env,
  127. params=dqn_params,
  128. callback=callback,
  129. PROJECT_ROOT=PROJECT_ROOT
  130. )
  131. # ---------- Training ----------
  132. print("\n Start training")
  133. trainer.train(total_timesteps=TOTAL_TIMESTEPS)
  134. trainer.save()
  135. # ========================================================
  136. # 验证
  137. # ========================================================
  138. print("\n[Eval] Start validation rollout")
  139. rewards = []
  140. for _ in range(len(val_pool)):
  141. obs = val_env.reset()
  142. episode_reward = 0.0
  143. for _ in range(10):
  144. action, _ = trainer.model.predict(
  145. obs, deterministic=True
  146. )
  147. obs, reward, done, _ = val_env.step(action)
  148. episode_reward += reward[0]
  149. if done:
  150. break
  151. rewards.append(episode_reward)
  152. rewards = np.asarray(rewards)
  153. save_path = Path(trainer.log_dir) / "val_rewards.npy"
  154. np.save(save_path, rewards)
  155. print(f"[Eval] Saved to {save_path}")
  156. print(f"[Eval] Mean reward = {rewards.mean():.3f}")
# ============================================================
# Entry point
# ============================================================
  160. if __name__ == "__main__":
  161. # ============================================================
  162. # 2. 全局配置
  163. # ============================================================
  164. RANDOM_SEED = 2025
  165. TOTAL_TIMESTEPS = 1500000
  166. RESET_STATE_CSV = (
  167. PROJECT_ROOT
  168. / "datasets/rl_ready/output/reset_state_pool.csv"
  169. )
  170. main()