data_splitter.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import pandas as pd
  2. from pathlib import Path
  3. from typing import Tuple
  4. class ResetStatePoolLoader:
  5. """
  6. Reset 初始状态池加载与划分类
  7. 功能:
  8. - 从 CSV 文件读取 RL-ready 的 reset 状态池
  9. - 做基础数据合法性检查
  10. - 按比例划分为 train / val 两个状态池
  11. 设计原则:
  12. - 不涉及 RL 算法
  13. - 不涉及环境逻辑
  14. - 仅负责「数据 → reset 可用状态池」
  15. """
  16. def __init__(
  17. self,
  18. csv_path: str | Path,
  19. train_ratio: float = 0.8,
  20. shuffle: bool = True,
  21. random_state: int = 42,
  22. ):
  23. self.csv_path = Path(csv_path)
  24. self.train_ratio = train_ratio
  25. self.shuffle = shuffle
  26. self.random_state = random_state
  27. self._validate_path()
  28. def _validate_path(self):
  29. if not self.csv_path.exists():
  30. raise FileNotFoundError(
  31. f"Reset state pool 文件不存在: {self.csv_path}"
  32. )
  33. def load(self) -> pd.DataFrame:
  34. """
  35. 读取 reset 状态池 CSV
  36. """
  37. df = pd.read_csv(self.csv_path)
  38. if df.empty:
  39. raise ValueError("reset_state_pool.csv 为空")
  40. # 基础合法性检查
  41. if df.isnull().any().any():
  42. raise ValueError(
  43. "reset_state_pool.csv 中存在 NaN,请在 data_to_rl 阶段处理"
  44. )
  45. return df
  46. def split(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
  47. """
  48. 按比例划分 train / val reset 状态池
  49. 返回:
  50. train_pool, val_pool
  51. """
  52. df = self.load()
  53. if self.shuffle:
  54. df = df.sample(
  55. frac=1.0,
  56. random_state=self.random_state
  57. ).reset_index(drop=True)
  58. split_idx = int(len(df) * self.train_ratio)
  59. train_pool = df.iloc[:split_idx].reset_index(drop=True)
  60. val_pool = df.iloc[split_idx:].reset_index(drop=True)
  61. if len(train_pool) == 0 or len(val_pool) == 0:
  62. raise ValueError(
  63. f"数据量不足以划分 train/val,样本数={len(df)}"
  64. )
  65. return train_pool, val_pool