| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import pandas as pd
- from pathlib import Path
- from typing import Tuple
- class ResetStatePoolLoader:
- """
- Reset 初始状态池加载与划分类
- 功能:
- - 从 CSV 文件读取 RL-ready 的 reset 状态池
- - 做基础数据合法性检查
- - 按比例划分为 train / val 两个状态池
- 设计原则:
- - 不涉及 RL 算法
- - 不涉及环境逻辑
- - 仅负责「数据 → reset 可用状态池」
- """
- def __init__(
- self,
- csv_path: str | Path,
- train_ratio: float = 0.8,
- shuffle: bool = True,
- random_state: int = 42,
- ):
- self.csv_path = Path(csv_path)
- self.train_ratio = train_ratio
- self.shuffle = shuffle
- self.random_state = random_state
- self._validate_path()
- def _validate_path(self):
- if not self.csv_path.exists():
- raise FileNotFoundError(
- f"Reset state pool 文件不存在: {self.csv_path}"
- )
- def load(self) -> pd.DataFrame:
- """
- 读取 reset 状态池 CSV
- """
- df = pd.read_csv(self.csv_path)
- if df.empty:
- raise ValueError("reset_state_pool.csv 为空")
- # 基础合法性检查
- if df.isnull().any().any():
- raise ValueError(
- "reset_state_pool.csv 中存在 NaN,请在 data_to_rl 阶段处理"
- )
- return df
- def split(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
- """
- 按比例划分 train / val reset 状态池
- 返回:
- train_pool, val_pool
- """
- df = self.load()
- if self.shuffle:
- df = df.sample(
- frac=1.0,
- random_state=self.random_state
- ).reset_index(drop=True)
- split_idx = int(len(df) * self.train_ratio)
- train_pool = df.iloc[:split_idx].reset_index(drop=True)
- val_pool = df.iloc[split_idx:].reset_index(drop=True)
- if len(train_pool) == 0 or len(val_pool) == 0:
- raise ValueError(
- f"数据量不足以划分 train/val,样本数={len(df)}"
- )
- return train_pool, val_pool
|