|
@@ -2,8 +2,6 @@
|
|
|
DQN 超滤强化学习训练与测试主脚本(工程化优化版)
|
|
DQN 超滤强化学习训练与测试主脚本(工程化优化版)
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
-import os
|
|
|
|
|
-import sys
|
|
|
|
|
import random
|
|
import random
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
import numpy as np
|
|
import numpy as np
|
|
@@ -14,26 +12,25 @@ import torch
|
|
|
# ============================================================
|
|
# ============================================================
|
|
|
CURRENT_DIR = Path(__file__).resolve().parent
|
|
CURRENT_DIR = Path(__file__).resolve().parent
|
|
|
|
|
|
|
|
-PROJECT_ROOT = CURRENT_DIR.parents[2] # uf_train # uf-rl
|
|
|
|
|
|
|
+PROJECT_ROOT = CURRENT_DIR.parents[2] # uf-rl
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------- 数据 ----------
|
|
# ---------- 数据 ----------
|
|
|
-from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
|
|
|
|
|
|
|
+from data_to_rl.data_splitter import ResetStatePoolLoader
|
|
|
|
|
|
|
|
# ---------- 阻力模型 ----------
|
|
# ---------- 阻力模型 ----------
|
|
|
-from uf_train.env.uf_resistance_models_load import load_resistance_models
|
|
|
|
|
-from uf_train.env.uf_physics import UFPhysicsModel
|
|
|
|
|
|
|
+from env.uf_resistance_models_load import load_resistance_models
|
|
|
|
|
+from env.uf_physics import UFPhysicsModel
|
|
|
|
|
|
|
|
# ---------- 强化学习环境 ----------
|
|
# ---------- 强化学习环境 ----------
|
|
|
-from uf_train.env.env_params import (UFState,UFPhysicsParams,UFActionSpec,UFRewardParams,UFStateBounds)
|
|
|
|
|
-from uf_train.env.env_config_loader import EnvConfigLoader, create_env_params_from_yaml
|
|
|
|
|
-from uf_train.env.uf_env import UFSuperCycleEnv
|
|
|
|
|
|
|
+from env.env_params import (UFActionSpec, UFRewardParams, UFStateBounds)
|
|
|
|
|
+from env.env_config_loader import EnvConfigLoader, create_env_params_from_yaml
|
|
|
|
|
+from env.uf_env import UFSuperCycleEnv
|
|
|
|
|
|
|
|
-from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
|
|
|
|
|
|
|
+from env.env_visual import UFEpisodeRecorder, UFTrainingCallback
|
|
|
|
|
|
|
|
-from uf_train.rl_model.DQN.dqn_params import DQNParams
|
|
|
|
|
-from uf_train.rl_model.DQN.dqn_config_loader import DQNConfigLoader, load_dqn_config_with_validation
|
|
|
|
|
-from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
|
|
|
|
|
|
|
+from rl_model.DQN.dqn_model.dqn_config_loader import DQNConfigLoader
|
|
|
|
|
+from rl_model.DQN.uf_train.dqn_trainer import DQNTrainer
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------- SB3 ----------
|
|
# ---------- SB3 ----------
|
|
@@ -290,14 +287,14 @@ if __name__ == "__main__":
|
|
|
# ============================================================
|
|
# ============================================================
|
|
|
RANDOM_SEED = 2025
|
|
RANDOM_SEED = 2025
|
|
|
TOTAL_TIMESTEPS = 300000
|
|
TOTAL_TIMESTEPS = 300000
|
|
|
- IS_TIMES = True
|
|
|
|
|
|
|
+ IS_TIMES = False
|
|
|
|
|
|
|
|
RESET_STATE_CSV = (
|
|
RESET_STATE_CSV = (
|
|
|
PROJECT_ROOT
|
|
PROJECT_ROOT
|
|
|
- / "datasets/UF_xishan_data/rl_ready/output/reset_state_pool.csv"
|
|
|
|
|
|
|
+ / "datasets/UF_longting_data/rl_ready/output/reset_state_pool.csv"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- ENV_CONFIG_PATH = PROJECT_ROOT / "xishan" / "env_config.yaml"
|
|
|
|
|
- MODEL_CONFIG_PATH = PROJECT_ROOT / "xishan" / "dqn_config.yaml"
|
|
|
|
|
|
|
+ ENV_CONFIG_PATH = PROJECT_ROOT / "longting" / "env_config.yaml"
|
|
|
|
|
+ MODEL_CONFIG_PATH = PROJECT_ROOT / "longting" / "dqn_config.yaml"
|
|
|
|
|
|
|
|
main()
|
|
main()
|