Selaa lähdekoodia

feat:真实工厂数据测试版
- 新增工厂提取数据构建训练集测试集的逻辑
- 新增真实数据采样训练和测试逻辑
- 优化强化学习代码结构

junc_WHU 3 kuukautta sitten
vanhempi
commit
e273c714eb

+ 14 - 4
models/uf-rl/uf_train/env/env_params.py

@@ -215,6 +215,13 @@ class UFPhysicsParams:
     【物理与工艺固定参数】
     描述超滤系统的客观工艺条件,在整个 episode 中保持不变。
     """
+    global_TMP_hard_limit: float = 0.08
+    # TMP 硬上限(MPa)
+    # 说明:超过此值将导致episode失败,需立即停机
+
+    global_TMP_soft_limit: float = 0.06
+    # TMP 软上限 (MPa)
+    # 说明:此上限用于指导奖励函数中膜阻力允许上升值,越接近该上限,系统对膜阻力上升的控制越严格
 
     # ========== 反洗参数(固定配置) ==========
     tau_bw_s: float = 20.0
@@ -306,10 +313,6 @@ class UFStateBounds:
     仅用于 reset() 时随机初始化环境状态,
     不参与 step() 中的状态演化。
     """
-    # --- 初始TMP约束 ---
-    TMP0_max: float = 0.035  # 初始TMP上限(MPa)
-    TMP0_min: float = 0.01  # 初始TMP下限(MPa)
-
     # --- 流量约束 ---
     q_UF_max: float = 380.0  # 进水流量上限(m³/h)
     q_UF_min: float = 210.0  # 进水流量下限(m³/h)
@@ -318,6 +321,13 @@ class UFStateBounds:
     temp_max: float = 32.0  # 温度上限(℃)
     temp_min: float = 16.0  # 温度下限(℃)
 
+    # --- 初始TMP约束 ---
+    TMP0_max: float = 0.045  # 初始TMP上限(MPa)
+    TMP0_min: float = 0.01  # 初始TMP下限(MPa)
+
+    # --- TMP上限约束 ---
+    global_TMP_hard_limit: float = 0.08
+
     # --- 短期污染模型参数约束 ---
     nuK_max: float = 2.6e+02  # 阻力增长系数上限
     nuK_min: float = 4e+01  # 阻力增长系数下限

+ 83 - 105
models/uf-rl/uf_train/env/env_reset.py

@@ -1,73 +1,54 @@
+import pandas as pd
 import numpy as np
 from typing import Optional
 
-
 class ResetSampler:
-    """
-    ResetSampler
-    -------------
-    负责在 reset() 时生成“合法初始状态”。
-
-    核心原则:
-    - UFStateBounds 定义的是【硬可行域】(仅作用于自由采样变量)
-    - 所有 reset 状态必须严格落在该域内
-    - sampler 不依赖 env,只依赖 bounds + physics + 数据
-    """
-
-    # === 自由采样变量(不含 R)===
-    STATE_KEYS = (
-        "TMP0",
+    # ============================================================
+    # 状态定义(唯一真源)
+    # ============================================================
+    FREE_STATE_KEYS = [
+        "q_UF",
+        "temp",
+        "TMP",
+        "nuK",
+        "slope",
+        "power",
+        "ceb_removal",
+    ]
+
+    FULL_STATE_KEYS = [
         "q_UF",
         "temp",
+        "TMP",
+        "R",
         "nuK",
         "slope",
         "power",
         "ceb_removal",
-    )
+    ]
 
-    # === 向量索引(显式声明,避免隐式 bug)===
-    IDX_TMP0 = 0
-    IDX_Q_UF = 1
-    IDX_TEMP = 2
+    IDX_Q_UF = 0
+    IDX_TEMP = 1
+    IDX_TMP = 2
 
     def __init__(
         self,
         bounds: "UFStateBounds",
-        physics,                      # ← 新增
-        real_state_pool: Optional[np.ndarray] = None,
+        physics,
+        real_state_pool=None,
         max_resample_attempts: int = 50,
-        random_state: Optional[np.random.RandomState] = None,
+        random_state=None,
     ):
-        """
-        Parameters
-        ----------
-        bounds : UFStateBounds
-            reset 状态的硬边界(5% / 95% 分位数)
-
-        physics : UFPhysics
-            用于 TMP → R 的物理模型(Darcy 定律)
-
-        real_state_pool : np.ndarray, optional
-            真实状态池,shape = (N, state_dim)
-            仅包含 STATE_KEYS 中的变量(不含 R)
-
-        max_resample_attempts : int
-            真实采样失败时最大重试次数
-
-        random_state : np.random.RandomState
-            随机数生成器
-        """
         self.bounds = bounds
         self.physics = physics
-        self.real_state_pool = real_state_pool
         self.max_resample_attempts = max_resample_attempts
         self.rng = random_state or np.random.RandomState()
 
-        # --- 构造向量化边界(仅自由变量)---
+        # --- 自由变量边界(顺序必须与 FREE_STATE_KEYS 一致)---
         self.low = np.array([
-            bounds.TMP0_min,
             bounds.q_UF_min,
             bounds.temp_min,
+            bounds.TMP0_min,
             bounds.nuK_min,
             bounds.slope_min,
             bounds.power_min,
@@ -75,9 +56,9 @@ class ResetSampler:
         ])
 
         self.high = np.array([
-            bounds.TMP0_max,
             bounds.q_UF_max,
             bounds.temp_max,
+            bounds.TMP0_max,
             bounds.nuK_max,
             bounds.slope_max,
             bounds.power_max,
@@ -86,107 +67,104 @@ class ResetSampler:
 
         self.state_dim = len(self.low)
 
+        # ========================================================
+        # 统一 real_state_pool 为 DataFrame(不含 R)
+        # ========================================================
         if real_state_pool is not None:
-            assert real_state_pool.shape[1] == self.state_dim
+            if isinstance(real_state_pool, pd.DataFrame):
+                df = real_state_pool.copy()
+            else:
+                df = pd.DataFrame(real_state_pool, columns=self.FULL_STATE_KEYS)
+
+            if "R" in df.columns:
+                df = df.drop(columns=["R"])
+
+            df = df[self.FREE_STATE_KEYS]
+            self.real_state_pool = df.reset_index(drop=True)
+        else:
+            self.real_state_pool = None
 
     # ============================================================
-    # 对外唯一接口
+    # 对外接口
     # ============================================================
-    def sample(self, progress: float) -> np.ndarray:
-        """
-        根据训练进度采样 reset 初始状态
-
-        Returns
-        -------
-        full_state : np.ndarray
-            完整初始状态:
-            [TMP0, q_UF, temp, R, nuK, slope, power, ceb_removal]
-        """
+    def sample(self, progress: float) -> pd.Series:
         cfg = self._get_sampling_config(progress)
 
-        sources = []
-        weights = []
+        sources, weights = [], []
 
         if self.real_state_pool is not None:
-            sources.extend(["real", "perturb"])
-            weights.extend([cfg["w_real"], cfg["w_perturb"]])
+            sources += ["real", "perturb"]
+            weights += [cfg["w_real"], cfg["w_perturb"]]
 
         sources.append("virtual")
         weights.append(cfg["w_virtual"])
 
         source = self.rng.choice(sources, p=self._normalize(weights))
 
+        # ========================================================
+        # 1. 采样自由变量(DataFrame 单行)
+        # ========================================================
         if source == "real":
-            state = self._sample_real()
+            state_df = self._sample_real()  # 必须返回 DataFrame(1 row)
 
         elif source == "perturb":
             base = self._sample_real()
-            state = base + self.rng.normal(
-                0.0, cfg["perturb_scale"], size=self.state_dim
-            )
+            noise = self.rng.normal(0.0, cfg["perturb_scale"], size=self.state_dim)
+            vec = base.values.squeeze() + noise
+            vec = np.clip(vec, self.low, self.high)
+            state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
 
         else:
-            state = self._sample_virtual()
-
-        # ---- 强制投影(只对自由变量)----
-        state = self._clip(state)
+            vec = self.rng.uniform(self.low, self.high)
+            state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
 
-        # ---- 物理派生:计算初始膜阻力 R ----
-        TMP0 = state[self.IDX_TMP0]
-        q_UF = state[self.IDX_Q_UF]
-        temp = state[self.IDX_TEMP]
+        # ========================================================
+        # 2. 物理派生:计算 R
+        # ========================================================
+        TMP = state_df.at[state_df.index[0], "TMP"]
+        q_UF = state_df.at[state_df.index[0], "q_UF"]
+        temp = state_df.at[state_df.index[0], "temp"]
 
-        R = self.physics.resistance_from_tmp(
-            tmp=TMP0,
-            q_UF=q_UF,
-            temp=temp,
-        )
+        R = self.physics.resistance_from_tmp(tmp=TMP, q_UF=q_UF, temp=temp)
 
         if not np.isfinite(R):
             raise RuntimeError("Invalid resistance computed during reset sampling.")
 
-        # ---- 返回完整状态向量(顺序固定)----
-        full_state = np.insert(state, 3, R)
-        return full_state
+        # ========================================================
+        # 3. 插入 R,形成完整状态(仍是 DataFrame)
+        # ========================================================
+        state_df.insert(3, "R", R)
+        state_df = state_df[self.FULL_STATE_KEYS]
+
+        # ========================================================
+        # 4. DataFrame(1 row) → Series(对外唯一返回)
+        # ========================================================
+        state = state_df.iloc[0]
+        state.name = "reset_state"  # 可选,但强烈推荐
+
+        return state
 
     # ============================================================
-    # 各类采样
+    # 内部方法
     # ============================================================
-    def _sample_real(self) -> np.ndarray:
-        """从真实数据池采样,拒绝越界样本"""
+    def _sample_real(self) -> pd.DataFrame:
         for _ in range(self.max_resample_attempts):
             idx = self.rng.randint(len(self.real_state_pool))
-            state = self.real_state_pool[idx]
+            row = self.real_state_pool.iloc[[idx]].reset_index(drop=True)
 
-            if self._is_valid(state):
-                return state
+            vec = row.values.squeeze()
+            if np.all(vec >= self.low) and np.all(vec <= self.high):
+                return row
 
         raise RuntimeError("No valid real reset state within bounds.")
 
-    def _sample_virtual(self) -> np.ndarray:
-        """在硬边界内均匀采样"""
-        return self.rng.uniform(self.low, self.high)
-
-    # ============================================================
-    # 工具函数
-    # ============================================================
-    def _is_valid(self, state: np.ndarray) -> bool:
-        return np.all(state >= self.low) and np.all(state <= self.high)
-
-    def _clip(self, state: np.ndarray) -> np.ndarray:
-        return np.clip(state, self.low, self.high)
-
     @staticmethod
     def _normalize(w):
         s = sum(w)
         return [x / s for x in w]
 
     def _get_sampling_config(self, progress: float) -> dict:
-        """
-        reset curriculum(不改变边界,只改变来源比例)
-        """
         progress = np.clip(progress, 0.0, 1.0)
-
         return dict(
             w_real=1.0 - 0.7 * progress,
             w_perturb=0.5 * progress,

+ 93 - 0
models/uf-rl/uf_train/env/env_visual.py

@@ -0,0 +1,93 @@
+import numpy as np
+from stable_baselines3.common.callbacks import BaseCallback
+
+
+class UFEpisodeRecorder:
+    """记录episode中的决策和结果"""
+
+    def __init__(self):
+        self.episode_data = []
+        self.current_episode = []
+
+    def record_step(self, obs, action, reward, done, info):
+        """记录单步信息"""
+        step_data = {
+            "obs": obs.copy(),
+            "action": action.copy(),
+            "reward": reward,
+            "done": done,
+            "info": info.copy() if info else {}
+        }
+        self.current_episode.append(step_data)
+
+        if done:
+            self.episode_data.append(self.current_episode)
+            self.current_episode = []
+
+    def get_episode_stats(self, episode_idx=-1):
+        """获取episode统计信息"""
+        if not self.episode_data:
+            return {}
+
+        episode = self.episode_data[episode_idx]
+        total_reward = sum(step["reward"] for step in episode)
+        avg_recovery = np.mean([step["info"].get("recovery", 0) for step in episode if "recovery" in step["info"]])
+        feasible_steps = sum(1 for step in episode if step["info"].get("feasible", False))
+
+        return {
+            "total_reward": total_reward,
+            "avg_recovery": avg_recovery,
+            "feasible_steps": feasible_steps,
+            "total_steps": len(episode)
+        }
+
+
+# ==== 定义强化学习训练回调器 ====
+class UFTrainingCallback(BaseCallback):
+    """
+    强化学习训练回调,用于记录每一步的数据到 recorder。
+    1. 不依赖环境内部 last_* 属性
+    2. 使用环境接口提供的 obs、actions、rewards、dones、infos
+    3. 自动处理 episode 结束时的统计
+    """
+
+    def __init__(self, recorder, verbose=0):
+        super(UFTrainingCallback, self).__init__(verbose)
+        self.recorder = recorder
+
+    def _on_step(self) -> bool:
+        try:
+            new_obs = self.locals.get("new_obs")
+            actions = self.locals.get("actions")
+            rewards = self.locals.get("rewards")
+            dones = self.locals.get("dones")
+            infos = self.locals.get("infos")
+
+            if len(new_obs) > 0:
+                step_obs = new_obs[0]
+                step_action = actions[0] if actions is not None else None
+                step_reward = rewards[0] if rewards is not None else 0.0
+                step_done = dones[0] if dones is not None else False
+                step_info = infos[0] if infos is not None else {}
+
+                # 打印当前 step 的信息
+                if self.verbose:
+                    print(f"[Step {self.num_timesteps}] 动作={step_action}, 奖励={step_reward:.3f}, Done={step_done}")
+
+                # 记录数据
+                self.recorder.record_step(
+                    obs=step_obs,
+                    action=step_action,
+                    reward=step_reward,
+                    done=step_done,
+                    info=step_info,
+                )
+
+        except Exception as e:
+            if self.verbose:
+                print(f"[Callback Error] {e}")
+
+        return True
+
+
+

+ 17 - 11
models/uf-rl/uf_train/env/uf_env.py

@@ -190,6 +190,12 @@ class UFSuperCycleEnv(gym.Env):
             ceb_removal=ceb_removal,
         )
 
+    def _get_training_progress(self) -> float:
+        """
+        返回训练进度,用于 reset_sampler 的 curriculum sampling
+        """
+        return min(1.0, self.current_step / self.max_episode_steps )
+
     def reset(self, seed=None, options=None, max_attempts: int = 1000):
         super().reset(seed=seed)
 
@@ -227,12 +233,12 @@ class UFSuperCycleEnv(gym.Env):
         构建当前环境归一化状态向量
         """
         # === 1. 从 self.state 读取动态参数 ===
-        TMP0 = self.state.TMP0
+        TMP = self.state.TMP
         q_UF = self.state.q_UF
         temp = self.state.temp
 
         # === 2. 计算本周期初始膜阻力 ===
-        R0 = self.state.R0
+        R = self.state.R
 
         # === 3. 从 self.state 读取膜阻力增长模型参数 ===
         nuk = self.state.nuK
@@ -241,21 +247,21 @@ class UFSuperCycleEnv(gym.Env):
         ceb_removal = self.state.ceb_removal
 
         # === 4. 从 current_params 动态读取上下限 ===
-        TMP0_min, TMP0_max = self.statebounds.TMP0_min, self.statebounds.global_TMP_hard_limit
-        q_UF_min, q_UF_max = self.statebounds.q_UF_min, self.statebounds.q_UF_max
-        temp_min, temp_max = self.statebounds.temp_min, self.statebounds.temp_max
-        nuK_min, nuK_max = self.statebounds.nuK_min, self.statebounds.nuK_max
-        slope_min, slope_max = self.statebounds.slope_min, self.statebounds.slope_max
-        power_min, power_max = self.statebounds.power_min, self.statebounds.power_max
-        ceb_min, ceb_max = self.statebounds.ceb_removal_min, self.statebounds.ceb_removal_max
+        TMP0_min, TMP0_max = self.state_bounds.TMP0_min, self.state_bounds.global_TMP_hard_limit
+        q_UF_min, q_UF_max = self.state_bounds.q_UF_min, self.state_bounds.q_UF_max
+        temp_min, temp_max = self.state_bounds.temp_min, self.state_bounds.temp_max
+        nuK_min, nuK_max = self.state_bounds.nuK_min, self.state_bounds.nuK_max
+        slope_min, slope_max = self.state_bounds.slope_min, self.state_bounds.slope_max
+        power_min, power_max = self.state_bounds.power_min, self.state_bounds.power_max
+        ceb_min, ceb_max = self.state_bounds.ceb_removal_min, self.state_bounds.ceb_removal_max
 
         # === 5. 归一化计算(clip防止越界) ===
-        TMP0_norm = np.clip((TMP0 - TMP0_min) / (TMP0_max - TMP0_min), 0, 1)
+        TMP0_norm = np.clip((TMP - TMP0_min) / (TMP0_max - TMP0_min), 0, 1)
         q_UF_norm = np.clip((q_UF - q_UF_min) / (q_UF_max - q_UF_min), 0, 1)
         temp_norm = np.clip((temp - temp_min) / (temp_max - temp_min), 0, 1)
 
         # R0 不在 current_params 中定义上下限,设定经验范围
-        R0_norm = np.clip((R0 - 100.0) / (600.0 - 100.0), 0, 1)
+        R0_norm = np.clip((R - 100.0) / (800.0 - 100.0), 0, 1)
 
         short_term_norm = np.clip((nuk - nuK_min) / (nuK_max - nuK_min), 0, 1)
         long_term_slope_norm = np.clip((slope - slope_min) / (slope_max - slope_min), 0, 1)

+ 8 - 7
models/uf-rl/uf_train/env/uf_physics.py

@@ -195,7 +195,7 @@ class UFPhysicsModel:
 
     def delta_resistance_filter(
         self,
-        state: UFState,
+        state,
         L_s: float
     ) -> float:
         """
@@ -215,7 +215,7 @@ class UFPhysicsModel:
 
     def delta_resistance_backwash(
         self,
-        state: UFState,
+        state,
         R0: float,
         R_end: float,
         L_h_next_start: float,
@@ -262,8 +262,8 @@ class UFPhysicsModel:
 
         # 初始状态(统一用 TMP / R,与当前 state 保持一致)
         initial_tmp = state.TMP # 记录周期初始跨膜压差
-        initial_R = state.R0 # 记录周期初始膜阻力
-        tmp = state.TMP # 当前跨膜压差
+        initial_R = state.R # 记录周期初始膜阻力
+        tmp = initial_tmp # 当前跨膜压差
 
 
         # 跟踪变量(用于记录周期内的极值)
@@ -349,6 +349,7 @@ class UFPhysicsModel:
 
         # ===== 新指标:膜阻力允许上升空间 =====
         # 该指标根据当前最大跨膜压差距离软约束跨膜压差的距离,动态计算当前周期允许上升的膜阻力值,用于后续清洗效果奖励计算
+
         delta_R_allow = max(
             self.resistance_from_tmp(self.p.global_TMP_soft_limit, state.q_UF, state.temp) -
             self.resistance_from_tmp(max_tmp_during_filtration, state.q_UF, state.temp),
@@ -390,8 +391,8 @@ class UFPhysicsModel:
 
         # 更新 state
         next_state = copy.deepcopy(state)
-        next_state.TMP0 = tmp_after_ceb
-        next_state.R0 = R_after_ceb
+        next_state.TMP = tmp_after_ceb
+        next_state.R = R_after_ceb
 
         # ========== 可选更新的参数(当前保持不变) ==========
         # 这些参数可根据实际情况动态调整,预留扩展接口
@@ -510,7 +511,7 @@ class UFPhysicsModel:
                 return False
 
             # 物理硬约束:TMP 不允许为负
-            if next_state.TMP0 < 0:
+            if next_state.TMP < 0:
                 return False
 
             # 进入下一步

+ 4 - 3
models/uf-rl/uf_train/env/uf_resistance_models_load.py

@@ -3,8 +3,9 @@ import torch
 from pathlib import Path
 from uf_train.env.uf_resistance_models_define import ResistanceDecreaseModel, ResistanceIncreaseModel
 
+
 # ==================== 膜阻力模型加载函数 ====================
-def load_resistance_models():
+def load_resistance_models(phys):
     """
     加载膜阻力预测模型(单例模式)
 
@@ -32,8 +33,8 @@ def load_resistance_models():
     print("🔄 正在加载膜阻力模型...")
 
     # 初始化模型对象
-    resistance_model_fp = ResistanceIncreaseModel()
-    resistance_model_bw = ResistanceDecreaseModel()
+    resistance_model_fp = ResistanceIncreaseModel(phys)
+    resistance_model_bw = ResistanceDecreaseModel(phys)
 
     # 获取当前脚本所在目录
     base_dir = Path(__file__).resolve().parent

+ 3 - 0
models/uf-rl/uf_train/rl_model/DQN/dqn_params.py

@@ -1,4 +1,7 @@
+from dataclasses import dataclass
+
 # ==================== DQN超参数配置类 ====================
+@dataclass
 class DQNParams:
     """
     DQN 超参数配置类

+ 26 - 18
models/uf-rl/uf_train/rl_model/DQN/run_dqn_train.py

@@ -12,6 +12,10 @@ import torch
 # ============================================================
 # 1. 导入模块
 # ============================================================
+CURRENT_DIR = Path(__file__).resolve().parent
+
+PROJECT_ROOT = CURRENT_DIR.parents[2]     # parents[2]: DQN -> rl_model -> uf_train -> uf-rl (project root)
+
 
 # ---------- 数据 ----------
 from uf_train.data_to_rl.data_splitter import ResetStatePoolLoader
@@ -22,7 +26,7 @@ from uf_train.env.uf_physics import UFPhysicsModel
 from uf_train.env.env_params import  UFPhysicsParams, UFStateBounds, UFRewardParams, UFActionSpec
 from uf_train.env.uf_env import UFSuperCycleEnv
 
-from uf_train.env.isualization import UFEpisodeRecorder, UFTrainingCallback
+from uf_train.env.env_visual import UFEpisodeRecorder, UFTrainingCallback
 
 from uf_train.rl_model.DQN.dqn_params import DQNParams
 from uf_train.rl_model.DQN.dqn_trainer import DQNTrainer
@@ -33,20 +37,11 @@ from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.vec_env import DummyVecEnv
 
 
-# ============================================================
-# 2. 全局配置
-# ============================================================
-RANDOM_SEED = 2025
-TOTAL_TIMESTEPS = 300_000
 
-RESET_STATE_CSV = (
-    PROJECT_ROOT
-    / "datasets/rl_ready/output/reset_state_pool.csv"
-)
 
 
 # ============================================================
-# 3. 随机种子
+#  随机种子
 # ============================================================
 def set_global_seed(seed: int):
     random.seed(seed)
@@ -87,7 +82,7 @@ def make_env(
     physics: UFPhysicsModel,
     reward_params: UFRewardParams,
     action_spec: UFActionSpec,
-    state_bounds: UFStateBounds,
+    statebounds: UFStateBounds,
     reset_state_pool,
     seed: int,
 ):
@@ -96,17 +91,17 @@ def make_env(
             physics=physics,
             reward_params=reward_params,
             action_spec=action_spec,
-            state_bounds=state_bounds,
+            statebounds=statebounds,
             real_state_pool=reset_state_pool,
-            random_seed=seed,
+            RANDOM_SEED=seed,
         )
         env.action_space.seed(seed)
         env.observation_space.seed(seed)
         return Monitor(env)
-
     return _init
 
 
+
 # ============================================================
 # 6. 主流程
 # ============================================================
@@ -118,10 +113,10 @@ def main():
     train_pool, val_pool = load_reset_state_pools()
 
     # ---------- Resistance models ----------
-    res_fp, res_bw = load_resistance_models()
+    phys_params = UFPhysicsParams()
+    res_fp, res_bw = load_resistance_models(phys_params)
 
     # ---------- Physics ----------
-    phys_params = UFPhysicsParams()
     physics_model = UFPhysicsModel(
         phys_params=phys_params,
         resistance_model_fp=res_fp,
@@ -168,12 +163,14 @@ def main():
         callback=callback,
     )
 
+
     # ---------- Training ----------
+    print("\n Start training")
     trainer.train(total_timesteps=TOTAL_TIMESTEPS)
     trainer.save()
 
     # ========================================================
-    # Validation rollout(明确为 validation,而非“测试”)
+    # 验证
     # ========================================================
     print("\n[Eval] Start validation rollout")
 
@@ -207,4 +204,15 @@ def main():
 # 入口
 # ============================================================
 if __name__ == "__main__":
+    # ============================================================
+    # 2. 全局配置
+    # ============================================================
+    RANDOM_SEED = 2025
+    TOTAL_TIMESTEPS = 1000000
+
+    RESET_STATE_CSV = (
+            PROJECT_ROOT
+            / "datasets/rl_ready/output/reset_state_pool.csv"
+    )
+
     main()