# train_entry.py
  1. """
  2. 通用强化学习训练入口(当前绑定 DQN,实现已验证)
  3. 仅负责:
  4. - 构造环境
  5. - 构造 Trainer
  6. - 启动训练并保存模型
  7. """
  8. import random
  9. from pathlib import Path
  10. import numpy as np
  11. import torch
  12. # ============================================================
  13. # 1. 路径解析
  14. # ============================================================
  15. CURRENT_DIR = Path(__file__).resolve().parent
  16. PROJECT_ROOT = CURRENT_DIR.parents[2] # uf_train / uf-rl
  17. # ============================================================
  18. # 2. 导入:数据 / 环境
  19. # ============================================================
  20. from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
  21. from uf_train.env.uf_resistance_models_load import load_resistance_models
  22. from uf_train.env.uf_physics import UFPhysicsModel
  23. from uf_train.env.env_params import (
  24. UFPhysicsParams,
  25. UFStateBounds,
  26. UFRewardParams,
  27. UFActionSpec,
  28. )
  29. from uf_train.env.uf_env import UFSuperCycleEnv
  30. from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
  31. # ============================================================
  32. # 3. 导入:算法(当前为 DQN)
  33. # ============================================================
  34. from uf_train.rl_model.DQN.dqn_params import DQNParams
  35. from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
  36. # ============================================================
  37. # 4. SB3 VecEnv
  38. # ============================================================
  39. from stable_baselines3.common.monitor import Monitor
  40. from stable_baselines3.common.vec_env import DummyVecEnv
  41. # ============================================================
  42. # 5. 随机种子
  43. # ============================================================
  44. def set_global_seed(seed: int):
  45. random.seed(seed)
  46. np.random.seed(seed)
  47. torch.manual_seed(seed)
  48. torch.cuda.manual_seed_all(seed)
  49. torch.backends.cudnn.deterministic = True
  50. torch.backends.cudnn.benchmark = False
  51. print(f"[Seed] Global random seed = {seed}")
  52. # ============================================================
  53. # 6. Reset State Pool 加载
  54. # ============================================================
  55. def load_reset_state_pool():
  56. loader = ResetStatePoolLoader(
  57. csv_path=RESET_STATE_CSV,
  58. train_ratio=0.8,
  59. shuffle=True,
  60. random_state=RANDOM_SEED,
  61. )
  62. train_pool, _ = loader.split()
  63. print("[Data] Reset state pool loaded")
  64. print(f" Train pool size: {len(train_pool)}")
  65. return train_pool
  66. # ============================================================
  67. # 7. 环境构造函数
  68. # ============================================================
  69. def make_env(
  70. physics: UFPhysicsModel,
  71. reward_params: UFRewardParams,
  72. action_spec: UFActionSpec,
  73. statebounds: UFStateBounds,
  74. reset_state_pool,
  75. seed: int,
  76. ):
  77. def _init():
  78. env = UFSuperCycleEnv(
  79. physics=physics,
  80. reward_params=reward_params,
  81. action_spec=action_spec,
  82. statebounds=statebounds,
  83. real_state_pool=reset_state_pool,
  84. RANDOM_SEED=seed,
  85. )
  86. env.action_space.seed(seed)
  87. env.observation_space.seed(seed)
  88. return Monitor(env)
  89. return _init
  90. # ============================================================
  91. # 8. 主训练流程
  92. # ============================================================
  93. def main():
  94. # ---------- Seed ----------
  95. set_global_seed(RANDOM_SEED)
  96. # ---------- Reset states ----------
  97. train_pool = load_reset_state_pool()
  98. # ---------- Resistance models ----------
  99. phys_params = UFPhysicsParams()
  100. res_fp, res_bw = load_resistance_models(phys_params)
  101. # ---------- Physics ----------
  102. physics_model = UFPhysicsModel(
  103. phys_params=phys_params,
  104. resistance_model_fp=res_fp,
  105. resistance_model_bw=res_bw,
  106. )
  107. # ---------- RL specs ----------
  108. reward_params = UFRewardParams()
  109. action_spec = UFActionSpec()
  110. state_bounds = UFStateBounds()
  111. # ---------- Training Env ----------
  112. train_env = DummyVecEnv([
  113. make_env(
  114. physics_model,
  115. reward_params,
  116. action_spec,
  117. state_bounds,
  118. train_pool,
  119. RANDOM_SEED,
  120. )
  121. ])
  122. # ---------- Callback ----------
  123. recorder = UFEpisodeRecorder()
  124. callback = UFTrainingCallback(recorder, verbose=1)
  125. # ---------- Trainer ----------
  126. algo_params = DQNParams(remark="uf_dqn_train_only")
  127. trainer = DQNTrainer(
  128. env=train_env,
  129. params=algo_params,
  130. callback=callback,
  131. PROJECT_ROOT=PROJECT_ROOT,
  132. )
  133. # ---------- Training ----------
  134. print("\n[Train] Start training")
  135. trainer.train(total_timesteps=TOTAL_TIMESTEPS)
  136. trainer.save()
  137. print("[Train] Finished")
  138. # ============================================================
  139. # 9. 入口
  140. # ============================================================
  141. if __name__ == "__main__":
  142. RANDOM_SEED = 2025
  143. TOTAL_TIMESTEPS = 1_500_000
  144. RESET_STATE_CSV = (
  145. PROJECT_ROOT
  146. / "datasets/rl_ready/output/reset_state_pool.csv"
  147. )
  148. main()