|
|
@@ -1,73 +1,54 @@
|
|
|
+import pandas as pd
|
|
|
import numpy as np
|
|
|
from typing import Optional
|
|
|
|
|
|
-
|
|
|
class ResetSampler:
|
|
|
- """
|
|
|
- ResetSampler
|
|
|
- -------------
|
|
|
- 负责在 reset() 时生成“合法初始状态”。
|
|
|
-
|
|
|
- 核心原则:
|
|
|
- - UFStateBounds 定义的是【硬可行域】(仅作用于自由采样变量)
|
|
|
- - 所有 reset 状态必须严格落在该域内
|
|
|
- - sampler 不依赖 env,只依赖 bounds + physics + 数据
|
|
|
- """
|
|
|
-
|
|
|
- # === 自由采样变量(不含 R)===
|
|
|
- STATE_KEYS = (
|
|
|
- "TMP0",
|
|
|
+ # ============================================================
|
|
|
+ # 状态定义(唯一真源)
|
|
|
+ # ============================================================
|
|
|
+ FREE_STATE_KEYS = [
|
|
|
+ "q_UF",
|
|
|
+ "temp",
|
|
|
+ "TMP",
|
|
|
+ "nuK",
|
|
|
+ "slope",
|
|
|
+ "power",
|
|
|
+ "ceb_removal",
|
|
|
+ ]
|
|
|
+
|
|
|
+ FULL_STATE_KEYS = [
|
|
|
"q_UF",
|
|
|
"temp",
|
|
|
+ "TMP",
|
|
|
+ "R",
|
|
|
"nuK",
|
|
|
"slope",
|
|
|
"power",
|
|
|
"ceb_removal",
|
|
|
- )
|
|
|
+ ]
|
|
|
|
|
|
- # === 向量索引(显式声明,避免隐式 bug)===
|
|
|
- IDX_TMP0 = 0
|
|
|
- IDX_Q_UF = 1
|
|
|
- IDX_TEMP = 2
|
|
|
+ IDX_Q_UF = 0
|
|
|
+ IDX_TEMP = 1
|
|
|
+ IDX_TMP = 2
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
bounds: "UFStateBounds",
|
|
|
- physics, # ← 新增
|
|
|
- real_state_pool: Optional[np.ndarray] = None,
|
|
|
+ physics,
|
|
|
+ real_state_pool=None,
|
|
|
max_resample_attempts: int = 50,
|
|
|
- random_state: Optional[np.random.RandomState] = None,
|
|
|
+ random_state=None,
|
|
|
):
|
|
|
- """
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- bounds : UFStateBounds
|
|
|
- reset 状态的硬边界(5% / 95% 分位数)
|
|
|
-
|
|
|
- physics : UFPhysics
|
|
|
- 用于 TMP → R 的物理模型(Darcy 定律)
|
|
|
-
|
|
|
- real_state_pool : np.ndarray, optional
|
|
|
- 真实状态池,shape = (N, state_dim)
|
|
|
- 仅包含 STATE_KEYS 中的变量(不含 R)
|
|
|
-
|
|
|
- max_resample_attempts : int
|
|
|
- 真实采样失败时最大重试次数
|
|
|
-
|
|
|
- random_state : np.random.RandomState
|
|
|
- 随机数生成器
|
|
|
- """
|
|
|
self.bounds = bounds
|
|
|
self.physics = physics
|
|
|
- self.real_state_pool = real_state_pool
|
|
|
self.max_resample_attempts = max_resample_attempts
|
|
|
self.rng = random_state or np.random.RandomState()
|
|
|
|
|
|
- # --- 构造向量化边界(仅自由变量)---
|
|
|
+ # --- 自由变量边界(顺序必须与 FREE_STATE_KEYS 一致)---
|
|
|
self.low = np.array([
|
|
|
- bounds.TMP0_min,
|
|
|
bounds.q_UF_min,
|
|
|
bounds.temp_min,
|
|
|
+ bounds.TMP0_min,
|
|
|
bounds.nuK_min,
|
|
|
bounds.slope_min,
|
|
|
bounds.power_min,
|
|
|
@@ -75,9 +56,9 @@ class ResetSampler:
|
|
|
])
|
|
|
|
|
|
self.high = np.array([
|
|
|
- bounds.TMP0_max,
|
|
|
bounds.q_UF_max,
|
|
|
bounds.temp_max,
|
|
|
+ bounds.TMP0_max,
|
|
|
bounds.nuK_max,
|
|
|
bounds.slope_max,
|
|
|
bounds.power_max,
|
|
|
@@ -86,107 +67,104 @@ class ResetSampler:
|
|
|
|
|
|
self.state_dim = len(self.low)
|
|
|
|
|
|
+ # ========================================================
|
|
|
+ # 统一 real_state_pool 为 DataFrame(不含 R)
|
|
|
+ # ========================================================
|
|
|
if real_state_pool is not None:
|
|
|
- assert real_state_pool.shape[1] == self.state_dim
|
|
|
+ if isinstance(real_state_pool, pd.DataFrame):
|
|
|
+ df = real_state_pool.copy()
|
|
|
+ else:
|
|
|
+ df = pd.DataFrame(real_state_pool, columns=self.FULL_STATE_KEYS)
|
|
|
+
|
|
|
+ if "R" in df.columns:
|
|
|
+ df = df.drop(columns=["R"])
|
|
|
+
|
|
|
+ df = df[self.FREE_STATE_KEYS]
|
|
|
+ self.real_state_pool = df.reset_index(drop=True)
|
|
|
+ else:
|
|
|
+ self.real_state_pool = None
|
|
|
|
|
|
# ============================================================
|
|
|
- # 对外唯一接口
|
|
|
+ # 对外接口
|
|
|
# ============================================================
|
|
|
- def sample(self, progress: float) -> np.ndarray:
|
|
|
- """
|
|
|
- 根据训练进度采样 reset 初始状态
|
|
|
-
|
|
|
- Returns
|
|
|
- -------
|
|
|
- full_state : np.ndarray
|
|
|
- 完整初始状态:
|
|
|
- [TMP0, q_UF, temp, R, nuK, slope, power, ceb_removal]
|
|
|
- """
|
|
|
+ def sample(self, progress: float) -> pd.Series:
|
|
|
cfg = self._get_sampling_config(progress)
|
|
|
|
|
|
- sources = []
|
|
|
- weights = []
|
|
|
+ sources, weights = [], []
|
|
|
|
|
|
if self.real_state_pool is not None:
|
|
|
- sources.extend(["real", "perturb"])
|
|
|
- weights.extend([cfg["w_real"], cfg["w_perturb"]])
|
|
|
+ sources += ["real", "perturb"]
|
|
|
+ weights += [cfg["w_real"], cfg["w_perturb"]]
|
|
|
|
|
|
sources.append("virtual")
|
|
|
weights.append(cfg["w_virtual"])
|
|
|
|
|
|
source = self.rng.choice(sources, p=self._normalize(weights))
|
|
|
|
|
|
+ # ========================================================
|
|
|
+ # 1. 采样自由变量(DataFrame 单行)
|
|
|
+ # ========================================================
|
|
|
if source == "real":
|
|
|
- state = self._sample_real()
|
|
|
+ state_df = self._sample_real() # 必须返回 DataFrame(1 row)
|
|
|
|
|
|
elif source == "perturb":
|
|
|
base = self._sample_real()
|
|
|
- state = base + self.rng.normal(
|
|
|
- 0.0, cfg["perturb_scale"], size=self.state_dim
|
|
|
- )
|
|
|
+ noise = self.rng.normal(0.0, cfg["perturb_scale"], size=self.state_dim)
|
|
|
+ vec = base.values.squeeze() + noise
|
|
|
+ vec = np.clip(vec, self.low, self.high)
|
|
|
+ state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
|
|
|
|
|
|
else:
|
|
|
- state = self._sample_virtual()
|
|
|
-
|
|
|
- # ---- 强制投影(只对自由变量)----
|
|
|
- state = self._clip(state)
|
|
|
+ vec = self.rng.uniform(self.low, self.high)
|
|
|
+ state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
|
|
|
|
|
|
- # ---- 物理派生:计算初始膜阻力 R ----
|
|
|
- TMP0 = state[self.IDX_TMP0]
|
|
|
- q_UF = state[self.IDX_Q_UF]
|
|
|
- temp = state[self.IDX_TEMP]
|
|
|
+ # ========================================================
|
|
|
+ # 2. 物理派生:计算 R
|
|
|
+ # ========================================================
|
|
|
+ TMP = state_df.at[state_df.index[0], "TMP"]
|
|
|
+ q_UF = state_df.at[state_df.index[0], "q_UF"]
|
|
|
+ temp = state_df.at[state_df.index[0], "temp"]
|
|
|
|
|
|
- R = self.physics.resistance_from_tmp(
|
|
|
- tmp=TMP0,
|
|
|
- q_UF=q_UF,
|
|
|
- temp=temp,
|
|
|
- )
|
|
|
+ R = self.physics.resistance_from_tmp(tmp=TMP, q_UF=q_UF, temp=temp)
|
|
|
|
|
|
if not np.isfinite(R):
|
|
|
raise RuntimeError("Invalid resistance computed during reset sampling.")
|
|
|
|
|
|
- # ---- 返回完整状态向量(顺序固定)----
|
|
|
- full_state = np.insert(state, 3, R)
|
|
|
- return full_state
|
|
|
+ # ========================================================
|
|
|
+ # 3. 插入 R,形成完整状态(仍是 DataFrame)
|
|
|
+ # ========================================================
|
|
|
+ state_df.insert(3, "R", R)
|
|
|
+ state_df = state_df[self.FULL_STATE_KEYS]
|
|
|
+
|
|
|
+ # ========================================================
|
|
|
+ # 4. DataFrame(1 row) → Series(对外唯一返回)
|
|
|
+ # ========================================================
|
|
|
+ state = state_df.iloc[0]
|
|
|
+ state.name = "reset_state" # 可选,但强烈推荐
|
|
|
+
|
|
|
+ return state
|
|
|
|
|
|
# ============================================================
|
|
|
- # 各类采样
|
|
|
+ # 内部方法
|
|
|
# ============================================================
|
|
|
- def _sample_real(self) -> np.ndarray:
|
|
|
- """从真实数据池采样,拒绝越界样本"""
|
|
|
+ def _sample_real(self) -> pd.DataFrame:
|
|
|
for _ in range(self.max_resample_attempts):
|
|
|
idx = self.rng.randint(len(self.real_state_pool))
|
|
|
- state = self.real_state_pool[idx]
|
|
|
+ row = self.real_state_pool.iloc[[idx]].reset_index(drop=True)
|
|
|
|
|
|
- if self._is_valid(state):
|
|
|
- return state
|
|
|
+ vec = row.values.squeeze()
|
|
|
+ if np.all(vec >= self.low) and np.all(vec <= self.high):
|
|
|
+ return row
|
|
|
|
|
|
raise RuntimeError("No valid real reset state within bounds.")
|
|
|
|
|
|
- def _sample_virtual(self) -> np.ndarray:
|
|
|
- """在硬边界内均匀采样"""
|
|
|
- return self.rng.uniform(self.low, self.high)
|
|
|
-
|
|
|
- # ============================================================
|
|
|
- # 工具函数
|
|
|
- # ============================================================
|
|
|
- def _is_valid(self, state: np.ndarray) -> bool:
|
|
|
- return np.all(state >= self.low) and np.all(state <= self.high)
|
|
|
-
|
|
|
- def _clip(self, state: np.ndarray) -> np.ndarray:
|
|
|
- return np.clip(state, self.low, self.high)
|
|
|
-
|
|
|
@staticmethod
|
|
|
def _normalize(w):
|
|
|
s = sum(w)
|
|
|
return [x / s for x in w]
|
|
|
|
|
|
def _get_sampling_config(self, progress: float) -> dict:
|
|
|
- """
|
|
|
- reset curriculum(不改变边界,只改变来源比例)
|
|
|
- """
|
|
|
progress = np.clip(progress, 0.0, 1.0)
|
|
|
-
|
|
|
return dict(
|
|
|
w_real=1.0 - 0.7 * progress,
|
|
|
w_perturb=0.5 * progress,
|