"""
Generic RL training entry point (currently bound to DQN; implementation verified).

Responsibilities only:
- build the environment
- build the Trainer
- start training and save the model
"""
- import random
- from pathlib import Path
- import numpy as np
- import torch
- # ============================================================
- # 1. 路径解析
- # ============================================================
- CURRENT_DIR = Path(__file__).resolve().parent
- PROJECT_ROOT = CURRENT_DIR.parents[2] # uf_train / uf-rl
- # ============================================================
- # 2. 导入:数据 / 环境
- # ============================================================
- from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
- from uf_train.env.uf_resistance_models_load import load_resistance_models
- from uf_train.env.uf_physics import UFPhysicsModel
- from uf_train.env.env_params import (
- UFPhysicsParams,
- UFStateBounds,
- UFRewardParams,
- UFActionSpec,
- )
- from uf_train.env.uf_env import UFSuperCycleEnv
- from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
- # ============================================================
- # 3. 导入:算法(当前为 DQN)
- # ============================================================
- from uf_train.rl_model.DQN.dqn_params import DQNParams
- from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
- # ============================================================
- # 4. SB3 VecEnv
- # ============================================================
- from stable_baselines3.common.monitor import Monitor
- from stable_baselines3.common.vec_env import DummyVecEnv
- # ============================================================
- # 5. 随机种子
- # ============================================================
- def set_global_seed(seed: int):
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- print(f"[Seed] Global random seed = {seed}")
- # ============================================================
- # 6. Reset State Pool 加载
- # ============================================================
- def load_reset_state_pool():
- loader = ResetStatePoolLoader(
- csv_path=RESET_STATE_CSV,
- train_ratio=0.8,
- shuffle=True,
- random_state=RANDOM_SEED,
- )
- train_pool, _ = loader.split()
- print("[Data] Reset state pool loaded")
- print(f" Train pool size: {len(train_pool)}")
- return train_pool
- # ============================================================
- # 7. 环境构造函数
- # ============================================================
- def make_env(
- physics: UFPhysicsModel,
- reward_params: UFRewardParams,
- action_spec: UFActionSpec,
- statebounds: UFStateBounds,
- reset_state_pool,
- seed: int,
- ):
- def _init():
- env = UFSuperCycleEnv(
- physics=physics,
- reward_params=reward_params,
- action_spec=action_spec,
- statebounds=statebounds,
- real_state_pool=reset_state_pool,
- RANDOM_SEED=seed,
- )
- env.action_space.seed(seed)
- env.observation_space.seed(seed)
- return Monitor(env)
- return _init
- # ============================================================
- # 8. 主训练流程
- # ============================================================
- def main():
- # ---------- Seed ----------
- set_global_seed(RANDOM_SEED)
- # ---------- Reset states ----------
- train_pool = load_reset_state_pool()
- # ---------- Resistance models ----------
- phys_params = UFPhysicsParams()
- res_fp, res_bw = load_resistance_models(phys_params)
- # ---------- Physics ----------
- physics_model = UFPhysicsModel(
- phys_params=phys_params,
- resistance_model_fp=res_fp,
- resistance_model_bw=res_bw,
- )
- # ---------- RL specs ----------
- reward_params = UFRewardParams()
- action_spec = UFActionSpec()
- state_bounds = UFStateBounds()
- # ---------- Training Env ----------
- train_env = DummyVecEnv([
- make_env(
- physics_model,
- reward_params,
- action_spec,
- state_bounds,
- train_pool,
- RANDOM_SEED,
- )
- ])
- # ---------- Callback ----------
- recorder = UFEpisodeRecorder()
- callback = UFTrainingCallback(recorder, verbose=1)
- # ---------- Trainer ----------
- algo_params = DQNParams(remark="uf_dqn_train_only")
- trainer = DQNTrainer(
- env=train_env,
- params=algo_params,
- callback=callback,
- PROJECT_ROOT=PROJECT_ROOT,
- )
- # ---------- Training ----------
- print("\n[Train] Start training")
- trainer.train(total_timesteps=TOTAL_TIMESTEPS)
- trainer.save()
- print("[Train] Finished")
- # ============================================================
- # 9. 入口
- # ============================================================
- if __name__ == "__main__":
- RANDOM_SEED = 2025
- TOTAL_TIMESTEPS = 1_500_000
- RESET_STATE_CSV = (
- PROJECT_ROOT
- / "datasets/rl_ready/output/reset_state_pool.csv"
- )
- main()