|
@@ -0,0 +1,186 @@
|
|
|
|
|
+"""
|
|
|
|
|
+通用强化学习训练入口(当前绑定 DQN,实现已验证)
|
|
|
|
|
+仅负责:
|
|
|
|
|
+- 构造环境
|
|
|
|
|
+- 构造 Trainer
|
|
|
|
|
+- 启动训练并保存模型
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import random
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import torch
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 1. 路径解析
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+CURRENT_DIR = Path(__file__).resolve().parent
|
|
|
|
|
+PROJECT_ROOT = CURRENT_DIR.parents[2] # uf_train / uf-rl
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 2. 导入:数据 / 环境
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
|
|
|
|
|
+
|
|
|
|
|
+from uf_train.env.uf_resistance_models_load import load_resistance_models
|
|
|
|
|
+from uf_train.env.uf_physics import UFPhysicsModel
|
|
|
|
|
+from uf_train.env.env_params import (
|
|
|
|
|
+ UFPhysicsParams,
|
|
|
|
|
+ UFStateBounds,
|
|
|
|
|
+ UFRewardParams,
|
|
|
|
|
+ UFActionSpec,
|
|
|
|
|
+)
|
|
|
|
|
+from uf_train.env.uf_env import UFSuperCycleEnv
|
|
|
|
|
+
|
|
|
|
|
+from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 3. 导入:算法(当前为 DQN)
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+from uf_train.rl_model.DQN.dqn_params import DQNParams
|
|
|
|
|
+from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 4. SB3 VecEnv
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+from stable_baselines3.common.monitor import Monitor
|
|
|
|
|
+from stable_baselines3.common.vec_env import DummyVecEnv
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 5. 随机种子
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+def set_global_seed(seed: int):
|
|
|
|
|
+ random.seed(seed)
|
|
|
|
|
+ np.random.seed(seed)
|
|
|
|
|
+ torch.manual_seed(seed)
|
|
|
|
|
+ torch.cuda.manual_seed_all(seed)
|
|
|
|
|
+
|
|
|
|
|
+ torch.backends.cudnn.deterministic = True
|
|
|
|
|
+ torch.backends.cudnn.benchmark = False
|
|
|
|
|
+
|
|
|
|
|
+ print(f"[Seed] Global random seed = {seed}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 6. Reset State Pool 加载
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+def load_reset_state_pool():
|
|
|
|
|
+ loader = ResetStatePoolLoader(
|
|
|
|
|
+ csv_path=RESET_STATE_CSV,
|
|
|
|
|
+ train_ratio=0.8,
|
|
|
|
|
+ shuffle=True,
|
|
|
|
|
+ random_state=RANDOM_SEED,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ train_pool, _ = loader.split()
|
|
|
|
|
+
|
|
|
|
|
+ print("[Data] Reset state pool loaded")
|
|
|
|
|
+ print(f" Train pool size: {len(train_pool)}")
|
|
|
|
|
+
|
|
|
|
|
+ return train_pool
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 7. 环境构造函数
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+def make_env(
|
|
|
|
|
+ physics: UFPhysicsModel,
|
|
|
|
|
+ reward_params: UFRewardParams,
|
|
|
|
|
+ action_spec: UFActionSpec,
|
|
|
|
|
+ statebounds: UFStateBounds,
|
|
|
|
|
+ reset_state_pool,
|
|
|
|
|
+ seed: int,
|
|
|
|
|
+):
|
|
|
|
|
+ def _init():
|
|
|
|
|
+ env = UFSuperCycleEnv(
|
|
|
|
|
+ physics=physics,
|
|
|
|
|
+ reward_params=reward_params,
|
|
|
|
|
+ action_spec=action_spec,
|
|
|
|
|
+ statebounds=statebounds,
|
|
|
|
|
+ real_state_pool=reset_state_pool,
|
|
|
|
|
+ RANDOM_SEED=seed,
|
|
|
|
|
+ )
|
|
|
|
|
+ env.action_space.seed(seed)
|
|
|
|
|
+ env.observation_space.seed(seed)
|
|
|
|
|
+ return Monitor(env)
|
|
|
|
|
+
|
|
|
|
|
+ return _init
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 8. 主训练流程
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+def main():
|
|
|
|
|
+ # ---------- Seed ----------
|
|
|
|
|
+ set_global_seed(RANDOM_SEED)
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- Reset states ----------
|
|
|
|
|
+ train_pool = load_reset_state_pool()
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- Resistance models ----------
|
|
|
|
|
+ phys_params = UFPhysicsParams()
|
|
|
|
|
+ res_fp, res_bw = load_resistance_models(phys_params)
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- Physics ----------
|
|
|
|
|
+ physics_model = UFPhysicsModel(
|
|
|
|
|
+ phys_params=phys_params,
|
|
|
|
|
+ resistance_model_fp=res_fp,
|
|
|
|
|
+ resistance_model_bw=res_bw,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- RL specs ----------
|
|
|
|
|
+ reward_params = UFRewardParams()
|
|
|
|
|
+ action_spec = UFActionSpec()
|
|
|
|
|
+ state_bounds = UFStateBounds()
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- Training Env ----------
|
|
|
|
|
+ train_env = DummyVecEnv([
|
|
|
|
|
+ make_env(
|
|
|
|
|
+ physics_model,
|
|
|
|
|
+ reward_params,
|
|
|
|
|
+ action_spec,
|
|
|
|
|
+ state_bounds,
|
|
|
|
|
+ train_pool,
|
|
|
|
|
+ RANDOM_SEED,
|
|
|
|
|
+ )
|
|
|
|
|
+ ])
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- Callback ----------
|
|
|
|
|
+ recorder = UFEpisodeRecorder()
|
|
|
|
|
+ callback = UFTrainingCallback(recorder, verbose=1)
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- Trainer ----------
|
|
|
|
|
+ algo_params = DQNParams(remark="uf_dqn_train_only")
|
|
|
|
|
+
|
|
|
|
|
+ trainer = DQNTrainer(
|
|
|
|
|
+ env=train_env,
|
|
|
|
|
+ params=algo_params,
|
|
|
|
|
+ callback=callback,
|
|
|
|
|
+ PROJECT_ROOT=PROJECT_ROOT,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # ---------- Training ----------
|
|
|
|
|
+ print("\n[Train] Start training")
|
|
|
|
|
+ trainer.train(total_timesteps=TOTAL_TIMESTEPS)
|
|
|
|
|
+ trainer.save()
|
|
|
|
|
+
|
|
|
|
|
+ print("[Train] Finished")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 9. 入口
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ RANDOM_SEED = 2025
|
|
|
|
|
+ TOTAL_TIMESTEPS = 1_500_000
|
|
|
|
|
+
|
|
|
|
|
+ RESET_STATE_CSV = (
|
|
|
|
|
+ PROJECT_ROOT
|
|
|
|
|
+ / "datasets/rl_ready/output/reset_state_pool.csv"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ main()
|