import pandas as pd
import numpy as np
from typing import Optional


class ResetSampler:
    """Draw reset states from real, perturbed-real, or uniformly random sources.

    A reset state is described by the free variables in ``FREE_STATE_KEYS``.
    The derived variable ``R`` is computed from ``TMP``, ``q_UF`` and ``temp``
    through ``physics.resistance_from_tmp`` and inserted to form the full
    state described by ``FULL_STATE_KEYS``.
    """

    # ============================================================
    # State layout (single source of truth)
    # ============================================================
    FREE_STATE_KEYS = [
        "q_UF",
        "temp",
        "TMP",
        "nuK",
        "slope",
        "power",
        "ceb_removal",
    ]
    FULL_STATE_KEYS = [
        "q_UF",
        "temp",
        "TMP",
        "R",
        "nuK",
        "slope",
        "power",
        "ceb_removal",
    ]

    IDX_Q_UF = 0
    IDX_TEMP = 1
    IDX_TMP = 2

    def __init__(
        self,
        bounds: "UFStateBounds",
        physics,
        real_state_pool=None,
        max_resample_attempts: int = 50,
        random_state=None,
    ):
        """Configure bounds, the physics model, the real-state pool and the RNG.

        Args:
            bounds: Object exposing ``<name>_min`` / ``<name>_max`` attributes
                for ``q_UF``, ``temp``, ``TMP0``, ``nuK``, ``slope``, ``power``
                and ``ceb_removal``.
            physics: Object providing
                ``resistance_from_tmp(tmp=..., q_UF=..., temp=...)``.
            real_state_pool: Optional pool of real states. A DataFrame must
                contain every column of ``FREE_STATE_KEYS`` (an ``R`` column is
                dropped). Other inputs are converted with ``pd.DataFrame``;
                row-like input may be laid out as ``FULL_STATE_KEYS`` or as
                ``FREE_STATE_KEYS`` (without ``R``).
            max_resample_attempts: Maximum draws from the pool when looking
                for a row that lies inside the bounds.
            random_state: ``None`` (fresh unseeded RNG), an integer seed, or an
                existing ``np.random.RandomState``-compatible object.
        """
        self.bounds = bounds
        self.physics = physics
        self.max_resample_attempts = max_resample_attempts
        self.rng = self._as_rng(random_state)

        # Free-variable bounds; the order must match FREE_STATE_KEYS.
        self.low = np.array([
            bounds.q_UF_min,
            bounds.temp_min,
            bounds.TMP0_min,
            bounds.nuK_min,
            bounds.slope_min,
            bounds.power_min,
            bounds.ceb_removal_min,
        ])
        self.high = np.array([
            bounds.q_UF_max,
            bounds.temp_max,
            bounds.TMP0_max,
            bounds.nuK_max,
            bounds.slope_max,
            bounds.power_max,
            bounds.ceb_removal_max,
        ])
        self.state_dim = len(self.low)

        # Normalise the pool to a DataFrame holding only the free variables.
        if real_state_pool is None:
            self.real_state_pool = None
        else:
            if isinstance(real_state_pool, pd.DataFrame):
                df = real_state_pool.copy()
            else:
                df = pd.DataFrame(
                    real_state_pool,
                    columns=self._pool_columns(real_state_pool),
                )
            if "R" in df.columns:
                df = df.drop(columns=["R"])
            df = df[self.FREE_STATE_KEYS]
            self.real_state_pool = df.reset_index(drop=True)

    # ============================================================
    # Public interface
    # ============================================================
    def sample(self, progress: float) -> pd.Series:
        """Return one full reset state as a Series named ``"reset_state"``.

        Args:
            progress: Progress value forwarded to ``_get_sampling_config``.

        Returns:
            Series indexed by ``FULL_STATE_KEYS``.

        Raises:
            RuntimeError: If the computed resistance is not finite, or if no
                in-bounds real state could be drawn from the pool.
        """
        cfg = self._get_sampling_config(progress)

        sources, weights = [], []
        if self.real_state_pool is not None:
            sources += ["real", "perturb"]
            weights += [cfg["w_real"], cfg["w_perturb"]]
        sources.append("virtual")
        weights.append(cfg["w_virtual"])

        source = self.rng.choice(sources, p=self._normalize(weights))

        # 1. Sample the free variables as a single-row DataFrame.
        if source == "real":
            state_df = self._sample_real()
        elif source == "perturb":
            base = self._sample_real()
            noise = self.rng.normal(
                0.0, cfg["perturb_scale"], size=self.state_dim
            )
            vec = np.clip(base.values.squeeze() + noise, self.low, self.high)
            state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
        else:
            vec = self.rng.uniform(self.low, self.high)
            state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)

        # 2. Derive R from the sampled TMP, q_UF and temp.
        first = state_df.index[0]
        R = self.physics.resistance_from_tmp(
            tmp=state_df.at[first, "TMP"],
            q_UF=state_df.at[first, "q_UF"],
            temp=state_df.at[first, "temp"],
        )
        if not np.isfinite(R):
            raise RuntimeError(
                "Invalid resistance computed during reset sampling."
            )

        # 3. Insert R at its FULL_STATE_KEYS position to form the full state.
        state_df.insert(self.FULL_STATE_KEYS.index("R"), "R", R)
        state_df = state_df[self.FULL_STATE_KEYS]

        # 4. Single-row DataFrame -> Series (the only public return type).
        return state_df.iloc[0].rename("reset_state")

    # ============================================================
    # Internal helpers
    # ============================================================
    @staticmethod
    def _as_rng(random_state):
        """Coerce ``None`` or an integer seed into a ``RandomState``.

        Any other object is returned unchanged so callers can share an RNG.
        Integer seeds (including ``0``) give a reproducibly seeded RNG.
        """
        if random_state is None:
            return np.random.RandomState()
        if isinstance(random_state, (int, np.integer)):
            return np.random.RandomState(int(random_state))
        return random_state

    @classmethod
    def _pool_columns(cls, pool) -> list:
        """Pick the column labels for a non-DataFrame real-state pool.

        Row-like input whose rows have as many entries as ``FREE_STATE_KEYS``
        is labelled with the free keys; everything else is labelled with
        ``FULL_STATE_KEYS``.
        """
        width = None
        if isinstance(pool, np.ndarray):
            if pool.ndim == 2:
                width = pool.shape[1]
        elif isinstance(pool, (list, tuple)) and len(pool) > 0:
            first = pool[0]
            if isinstance(first, (list, tuple)) or (
                isinstance(first, np.ndarray) and first.ndim == 1
            ):
                width = len(first)
        if width == len(cls.FREE_STATE_KEYS):
            return cls.FREE_STATE_KEYS
        return cls.FULL_STATE_KEYS

    def _sample_real(self) -> pd.DataFrame:
        """Draw a random pool row lying inside ``[low, high]``.

        Returns:
            Single-row DataFrame with columns ``FREE_STATE_KEYS``.

        Raises:
            RuntimeError: If the pool is empty or no in-bounds row is found
                within ``max_resample_attempts`` draws.
        """
        pool_size = len(self.real_state_pool)
        if pool_size == 0:
            # randint(0) would fail with an unrelated ValueError.
            raise RuntimeError("No valid real reset state within bounds.")

        for _ in range(self.max_resample_attempts):
            idx = self.rng.randint(pool_size)
            row = self.real_state_pool.iloc[[idx]].reset_index(drop=True)
            vec = row.values.squeeze()
            if np.all(vec >= self.low) and np.all(vec <= self.high):
                return row
        raise RuntimeError("No valid real reset state within bounds.")

    @staticmethod
    def _normalize(w):
        """Scale the weights in ``w`` so they sum to one.

        The weights must not sum to zero; the division is not guarded.
        """
        s = sum(w)
        return [x / s for x in w]

    def _get_sampling_config(self, progress: float) -> dict:
        """Return the source weights and perturbation scale.

        Returns:
            Dict with keys ``w_real``, ``w_perturb``, ``w_virtual`` and
            ``perturb_scale``.
        """
        # Clipped value is currently unused; it feeds the disabled schedule.
        progress = np.clip(progress, 0.0, 1.0)

        # NOTE: the progress-dependent schedule below is disabled; every
        # reset is currently drawn from the virtual (uniform) source.
        #   w_real        = (1.0 - progress) ** 1.2   # real states decay
        #   w_perturb     = 0.5 * progress            # grows linearly
        #   w_virtual     = 0.3 * progress ** 1.5     # peaks near 0.3
        #   perturb_scale = 0.02 + 0.04 * progress
        return dict(
            w_real=0.0,
            w_perturb=0.0,
            w_virtual=1.0,
            perturb_scale=0.0,
        )