"""
Generic reinforcement-learning training entry point (currently bound to DQN;
the implementation has been verified).

This script is only responsible for:
- building the environment
- building the Trainer
- launching training and saving the model
"""
import random
from pathlib import Path

import numpy as np
import torch

# ============================================================
# 1. SB3 VecEnv
# ============================================================
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

# ============================================================
# 2. Imports: data / environment
# ============================================================
from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
from uf_train.env.uf_resistance_models_load import load_resistance_models
from uf_train.env.uf_physics import UFPhysicsModel
from uf_train.env.env_params import (
    UFPhysicsParams,
    UFStateBounds,
    UFRewardParams,
    UFActionSpec,
)
from uf_train.env.uf_env import UFSuperCycleEnv
from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback

# ============================================================
# 3. Imports: algorithm (currently DQN)
# ============================================================
from uf_train.rl_model.DQN.dqn_params import DQNParams
from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer

# ============================================================
# 4. Path resolution
# ============================================================
# Directory containing this script.
CURRENT_DIR = Path(__file__).resolve().parent
# Three directory levels above CURRENT_DIR. The original note reads
# "uf_train / uf-rl", presumably meaning the uf-rl project root that
# contains the uf_train package — TODO confirm.
PROJECT_ROOT = CURRENT_DIR.parents[2]


# ============================================================
# 5. Random seed
# ============================================================
def set_global_seed(seed: int) -> None:
    """Seed every random number generator used in training.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices). It then switches cuDNN into deterministic mode.

    Args:
        seed: Seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored by PyTorch when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade speed for reproducibility: benchmark mode selects kernels by
    # timing them, which can differ from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[Seed] Global random seed = {seed}")
# ============================================================
# 6. Reset state pool loading
# ============================================================
def load_reset_state_pool(csv_path=None, random_state=None, train_ratio: float = 0.8):
    """Load the reset-state CSV and return its training split.

    Args:
        csv_path: CSV file passed to ``ResetStatePoolLoader``. When None,
            the module-level ``RESET_STATE_CSV`` is used.
        random_state: Seed for the shuffled split. When None, the
            module-level ``RANDOM_SEED`` is used.
        train_ratio: Fraction of the data assigned to the returned split.
            Defaults to 0.8.

    Returns:
        The first value returned by ``ResetStatePoolLoader.split()``. The
        second value (the remaining split) is discarded.

    Raises:
        NameError: if an argument is omitted and the matching module-level
            global has not been defined.
    """
    # Passing the values explicitly makes this function usable when the
    # module is imported instead of executed as a script. The globals are
    # only the fallback.
    if csv_path is None:
        csv_path = RESET_STATE_CSV
    if random_state is None:
        random_state = RANDOM_SEED

    loader = ResetStatePoolLoader(
        csv_path=csv_path,
        train_ratio=train_ratio,
        shuffle=True,
        random_state=random_state,
    )
    train_pool, _ = loader.split()

    print("[Data] Reset state pool loaded")
    print(f" Train pool size: {len(train_pool)}")
    return train_pool


# ============================================================
# 7. Environment factory
# ============================================================
def make_env(
    physics: UFPhysicsModel,
    reward_params: UFRewardParams,
    action_spec: UFActionSpec,
    statebounds: UFStateBounds,
    reset_state_pool,
    seed: int,
):
    """Return a zero-argument callable that builds one training environment.

    The returned callable is the form expected in the list passed to
    ``DummyVecEnv``. Each call creates a new ``UFSuperCycleEnv``, seeds its
    action and observation spaces, and wraps it in stable-baselines3's
    ``Monitor``.

    Args:
        physics: Physics model passed through to the environment.
        reward_params: Reward configuration passed through to the environment.
        action_spec: Action specification passed through to the environment.
        statebounds: State bounds passed through to the environment.
        reset_state_pool: Pool passed to the environment as
            ``real_state_pool``.
        seed: Seed forwarded to the environment constructor and used to seed
            its spaces.

    Returns:
        A callable taking no arguments and returning the Monitor-wrapped env.
    """
    def _init():
        env = UFSuperCycleEnv(
            physics=physics,
            reward_params=reward_params,
            action_spec=action_spec,
            statebounds=statebounds,
            real_state_pool=reset_state_pool,
            RANDOM_SEED=seed,
        )
        # Seed the spaces as well so that sampling from them is reproducible.
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return Monitor(env)

    return _init
# ============================================================
# 8. Run configuration
# ============================================================
# These are defined at module level rather than inside the __main__ guard.
# That way main() and load_reset_state_pool() also work when this module is
# imported instead of run as a script.
RANDOM_SEED = 2025
TOTAL_TIMESTEPS = 1_500_000
RESET_STATE_CSV = (
    PROJECT_ROOT / "datasets/rl_ready/output/reset_state_pool.csv"
)


# ============================================================
# 9. Main training flow
# ============================================================
def main():
    """Run the full training pipeline and save the trained model.

    Steps:
    1. Seed all RNGs with ``RANDOM_SEED``.
    2. Load the training reset-state pool.
    3. Build the physics model from the loaded resistance models.
    4. Wrap a single environment in a ``DummyVecEnv``.
    5. Train with ``DQNTrainer`` for ``TOTAL_TIMESTEPS`` steps, then save.

    The vectorized environment is closed afterwards, even if training raises.
    """
    # ---------- Seed ----------
    set_global_seed(RANDOM_SEED)

    # ---------- Reset states ----------
    train_pool = load_reset_state_pool()

    # ---------- Resistance models ----------
    phys_params = UFPhysicsParams()
    res_fp, res_bw = load_resistance_models(phys_params)

    # ---------- Physics ----------
    physics_model = UFPhysicsModel(
        phys_params=phys_params,
        resistance_model_fp=res_fp,
        resistance_model_bw=res_bw,
    )

    # ---------- RL specs ----------
    reward_params = UFRewardParams()
    action_spec = UFActionSpec()
    state_bounds = UFStateBounds()

    # ---------- Training Env ----------
    train_env = DummyVecEnv([
        make_env(
            physics_model,
            reward_params,
            action_spec,
            state_bounds,
            train_pool,
            RANDOM_SEED,
        )
    ])

    try:
        # ---------- Callback ----------
        recorder = UFEpisodeRecorder()
        callback = UFTrainingCallback(recorder, verbose=1)

        # ---------- Trainer ----------
        algo_params = DQNParams(remark="uf_dqn_train_only")
        trainer = DQNTrainer(
            env=train_env,
            params=algo_params,
            callback=callback,
            PROJECT_ROOT=PROJECT_ROOT,
        )

        # ---------- Training ----------
        print("\n[Train] Start training")
        trainer.train(total_timesteps=TOTAL_TIMESTEPS)
        trainer.save()
        print("[Train] Finished")
    finally:
        # Release the environment's resources on both success and failure.
        train_env.close()


# ============================================================
# 10. Entry point
# ============================================================
if __name__ == "__main__":
    main()