import pandas as pd from pathlib import Path from typing import Tuple class ResetStatePoolLoader: """ Reset 初始状态池加载与划分类 功能: - 从 CSV 文件读取 RL-ready 的 reset 状态池 - 做基础数据合法性检查 - 按比例划分为 train / val 两个状态池 设计原则: - 不涉及 RL 算法 - 不涉及环境逻辑 - 仅负责「数据 → reset 可用状态池」 """ def __init__( self, csv_path: str | Path, train_ratio: float = 0.8, shuffle: bool = True, random_state: int = 42, ): self.csv_path = Path(csv_path) self.train_ratio = train_ratio self.shuffle = shuffle self.random_state = random_state self._validate_path() def _validate_path(self): if not self.csv_path.exists(): raise FileNotFoundError( f"Reset state pool 文件不存在: {self.csv_path}" ) def load(self) -> pd.DataFrame: """ 读取 reset 状态池 CSV """ df = pd.read_csv(self.csv_path) if df.empty: raise ValueError("reset_state_pool.csv 为空") # 基础合法性检查 if df.isnull().any().any(): raise ValueError( "reset_state_pool.csv 中存在 NaN,请在 data_to_rl 阶段处理" ) return df def split(self) -> Tuple[pd.DataFrame, pd.DataFrame]: """ 按比例划分 train / val reset 状态池 返回: train_pool, val_pool """ df = self.load() if self.shuffle: df = df.sample( frac=1.0, random_state=self.random_state ).reset_index(drop=True) split_idx = int(len(df) * self.train_ratio) train_pool = df.iloc[:split_idx].reset_index(drop=True) val_pool = df.iloc[split_idx:].reset_index(drop=True) if len(train_pool) == 0 or len(val_pool) == 0: raise ValueError( f"数据量不足以划分 train/val,样本数={len(df)}" ) return train_pool, val_pool