| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- import pandas as pd
- import numpy as np
- from typing import Optional
class ResetSampler:
    """Sample full initial (reset) states for an episode.

    Each call to :meth:`sample` picks one of three sources at random, with
    weights from :meth:`_get_sampling_config`:

    * ``"real"``    -- an in-bounds row taken from ``real_state_pool``;
    * ``"perturb"`` -- such a row plus Gaussian noise, clipped into bounds;
    * ``"virtual"`` -- a point drawn uniformly inside the bounds.

    Only the free variables (``FREE_STATE_KEYS``) are sampled. The
    resistance ``R`` is then derived from them with
    ``physics.resistance_from_tmp`` and inserted, so the result is laid out
    as ``FULL_STATE_KEYS``.
    """

    # ============================================================
    # State definitions (single source of truth)
    # ============================================================
    # Independently sampled variables; order matches self.low / self.high.
    FREE_STATE_KEYS = [
        "q_UF",
        "temp",
        "TMP",
        "nuK",
        "slope",
        "power",
        "ceb_removal",
    ]
    # Layout of the state returned by sample(): free variables plus derived "R".
    FULL_STATE_KEYS = [
        "q_UF",
        "temp",
        "TMP",
        "R",
        "nuK",
        "slope",
        "power",
        "ceb_removal",
    ]
    # Positional indices; identical in FREE_ and FULL_STATE_KEYS because
    # "R" comes after "TMP".
    IDX_Q_UF = 0
    IDX_TEMP = 1
    IDX_TMP = 2

    def __init__(
        self,
        bounds: "UFStateBounds",
        physics,
        real_state_pool=None,
        max_resample_attempts: int = 50,
        random_state=None,
    ):
        """Build the sampler.

        Args:
            bounds: Object exposing ``<name>_min`` / ``<name>_max`` for each
                free variable (TMP is read from ``TMP0_min`` / ``TMP0_max``).
            physics: Object providing
                ``resistance_from_tmp(tmp=..., q_UF=..., temp=...)``.
            real_state_pool: Optional pool of recorded states. Accepted forms:
                a DataFrame containing at least the ``FREE_STATE_KEYS``
                columns (an ``"R"`` column is dropped), or a 2-D array-like
                whose rows follow ``FULL_STATE_KEYS`` (8 columns) or
                ``FREE_STATE_KEYS`` (7 columns). An empty pool is treated
                like ``None``.
            max_resample_attempts: Random draws tried by ``_sample_real``
                before it falls back to scanning the pool for in-bounds rows.
            random_state: ``None`` (fresh unseeded RNG), an int seed, or an
                RNG object exposing the ``np.random.RandomState`` methods
                used here (``choice``, ``normal``, ``uniform``, ``randint``).
        """
        self.bounds = bounds
        self.physics = physics
        self.max_resample_attempts = max_resample_attempts
        # Honour int seeds: `random_state or RandomState()` ignored a seed of
        # 0 and stored any other int as the "RNG", which crashed on .choice().
        if random_state is None or isinstance(random_state, (int, np.integer)):
            self.rng = np.random.RandomState(random_state)
        else:
            self.rng = random_state

        # --- Free-variable bounds (order must match FREE_STATE_KEYS) ---
        self.low = np.array([
            bounds.q_UF_min,
            bounds.temp_min,
            bounds.TMP0_min,
            bounds.nuK_min,
            bounds.slope_min,
            bounds.power_min,
            bounds.ceb_removal_min,
        ])
        self.high = np.array([
            bounds.q_UF_max,
            bounds.temp_max,
            bounds.TMP0_max,
            bounds.nuK_max,
            bounds.slope_max,
            bounds.power_max,
            bounds.ceb_removal_max,
        ])
        self.state_dim = len(self.low)

        # Normalize the pool to a DataFrame of free variables (without R).
        self.real_state_pool = self._prepare_real_state_pool(real_state_pool)

    # ============================================================
    # Public interface
    # ============================================================
    def sample(self, progress: float) -> pd.Series:
        """Draw one full reset state.

        Args:
            progress: Training progress; clipped to [0, 1] and passed to
                ``_get_sampling_config`` (whose current weights are fixed).

        Returns:
            A Series indexed by ``FULL_STATE_KEYS`` and named
            ``"reset_state"``.

        Raises:
            RuntimeError: If the derived resistance is not finite, or a real
                state is requested but no pool row lies within bounds.
            ValueError: If the configured source weights do not have a
                positive sum.
        """
        cfg = self._get_sampling_config(progress)

        # Real/perturbed sources only exist when a pool is available.
        sources, weights = [], []
        if self.real_state_pool is not None:
            sources += ["real", "perturb"]
            weights += [cfg["w_real"], cfg["w_perturb"]]
        sources.append("virtual")
        weights.append(cfg["w_virtual"])
        source = self.rng.choice(sources, p=self._normalize(weights))

        # 1. Sample the free variables as a single-row DataFrame.
        if source == "real":
            state_df = self._sample_real()
        elif source == "perturb":
            base = self._sample_real()
            # NOTE(review): the noise std is the same absolute value for every
            # variable regardless of its range; confirm whether perturb_scale
            # was meant to be relative to (high - low).
            noise = self.rng.normal(0.0, cfg["perturb_scale"], size=self.state_dim)
            vec = np.clip(base.to_numpy().squeeze() + noise, self.low, self.high)
            state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
        else:
            vec = self.rng.uniform(self.low, self.high)
            state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)

        # 2. Physics-derived variable: compute R from TMP, q_UF and temp.
        row = state_df.iloc[0]
        R = self.physics.resistance_from_tmp(
            tmp=row["TMP"], q_UF=row["q_UF"], temp=row["temp"]
        )
        if not np.isfinite(R):
            raise RuntimeError("Invalid resistance computed during reset sampling.")

        # 3. Insert R at its FULL_STATE_KEYS position to form the full state.
        state_df.insert(self.FULL_STATE_KEYS.index("R"), "R", R)
        state_df = state_df[self.FULL_STATE_KEYS]

        # 4. Single-row DataFrame -> Series (the only public return type).
        state = state_df.iloc[0]
        state.name = "reset_state"
        return state

    # ============================================================
    # Internal methods
    # ============================================================
    def _prepare_real_state_pool(self, real_state_pool):
        """Return the pool as a DataFrame with FREE_STATE_KEYS columns, or None."""
        if real_state_pool is None:
            return None
        if isinstance(real_state_pool, pd.DataFrame):
            df = real_state_pool.copy()
        else:
            try:
                df = pd.DataFrame(real_state_pool, columns=self.FULL_STATE_KEYS)
            except ValueError:
                # Positional rows that already omit R (7 columns).
                df = pd.DataFrame(real_state_pool, columns=self.FREE_STATE_KEYS)
        if "R" in df.columns:
            df = df.drop(columns=["R"])
        df = df[self.FREE_STATE_KEYS].reset_index(drop=True)
        # An empty pool cannot supply real states; treat it as "no pool".
        return df if len(df) > 0 else None

    def _sample_real(self) -> pd.DataFrame:
        """Return a random in-bounds pool row as a 1-row DataFrame (index 0).

        First tries ``max_resample_attempts`` uniform random draws. If none
        is in bounds, it picks uniformly among all in-bounds rows, so a
        sparse-but-nonempty valid set never fails spuriously.

        Raises:
            RuntimeError: If the pool is missing or has no in-bounds row.
        """
        pool = self.real_state_pool
        if pool is None or len(pool) == 0:
            raise RuntimeError("No valid real reset state within bounds.")
        values = pool.to_numpy()
        n_rows = len(pool)

        for _ in range(self.max_resample_attempts):
            idx = self.rng.randint(n_rows)
            vec = values[idx]
            if np.all(vec >= self.low) and np.all(vec <= self.high):
                return pool.iloc[[idx]].reset_index(drop=True)

        # Fallback: exhaustive scan (NaN entries compare False -> rejected).
        in_bounds = np.all((values >= self.low) & (values <= self.high), axis=1)
        valid_idx = np.flatnonzero(in_bounds)
        if valid_idx.size == 0:
            raise RuntimeError("No valid real reset state within bounds.")
        idx = int(self.rng.choice(valid_idx))
        return pool.iloc[[idx]].reset_index(drop=True)

    @staticmethod
    def _normalize(w):
        """Scale weights so they sum to 1.

        Raises:
            ValueError: If the weights do not have a positive (finite) sum.
        """
        total = sum(w)
        if not total > 0:  # also rejects NaN
            raise ValueError(
                f"Sampling weights must have a positive sum, got {list(w)}."
            )
        return [x / total for x in w]

    def _get_sampling_config(self, progress: float) -> dict:
        """Return source weights and perturbation scale for ``progress``.

        The weights are currently fixed (virtual sampling only); ``progress``
        is clipped but not otherwise used.
        """
        progress = np.clip(progress, 0.0, 1.0)
        # -------------------------
        # Stage weighting design (nonlinear, with more virtual conditions),
        # currently disabled:
        # -------------------------
        # w_real = (1.0 - progress) ** 1.2   # historical conditions decay gradually
        # w_perturb = 0.5 * progress         # neighbourhood perturbations grow linearly
        # w_virtual = 0.3 * progress ** 1.5  # virtual conditions grow faster, up to ~0.3
        #
        # # perturbation magnitude
        # perturb_scale = 0.02 + 0.04 * progress
        w_real = 0.0
        w_perturb = 0.0
        w_virtual = 1.0
        # Perturbation magnitude (std of the Gaussian noise in "perturb").
        perturb_scale = 0.0
        return dict(
            w_real=w_real,
            w_perturb=w_perturb,
            w_virtual=w_virtual,
            perturb_scale=perturb_scale,
        )
|