# env_reset.py
  1. import pandas as pd
  2. import numpy as np
  3. from typing import Optional
  4. class ResetSampler:
  5. # ============================================================
  6. # 状态定义(唯一真源)
  7. # ============================================================
  8. FREE_STATE_KEYS = [
  9. "q_UF",
  10. "temp",
  11. "TMP",
  12. "nuK",
  13. "slope",
  14. "power",
  15. "ceb_removal",
  16. ]
  17. FULL_STATE_KEYS = [
  18. "q_UF",
  19. "temp",
  20. "TMP",
  21. "R",
  22. "nuK",
  23. "slope",
  24. "power",
  25. "ceb_removal",
  26. ]
  27. IDX_Q_UF = 0
  28. IDX_TEMP = 1
  29. IDX_TMP = 2
  30. def __init__(
  31. self,
  32. bounds: "UFStateBounds",
  33. physics,
  34. real_state_pool=None,
  35. max_resample_attempts: int = 50,
  36. random_state=None,
  37. ):
  38. self.bounds = bounds
  39. self.physics = physics
  40. self.max_resample_attempts = max_resample_attempts
  41. self.rng = random_state or np.random.RandomState()
  42. # --- 自由变量边界(顺序必须与 FREE_STATE_KEYS 一致)---
  43. self.low = np.array([
  44. bounds.q_UF_min,
  45. bounds.temp_min,
  46. bounds.TMP0_min,
  47. bounds.nuK_min,
  48. bounds.slope_min,
  49. bounds.power_min,
  50. bounds.ceb_removal_min,
  51. ])
  52. self.high = np.array([
  53. bounds.q_UF_max,
  54. bounds.temp_max,
  55. bounds.TMP0_max,
  56. bounds.nuK_max,
  57. bounds.slope_max,
  58. bounds.power_max,
  59. bounds.ceb_removal_max,
  60. ])
  61. self.state_dim = len(self.low)
  62. # ========================================================
  63. # 统一 real_state_pool 为 DataFrame(不含 R)
  64. # ========================================================
  65. if real_state_pool is not None:
  66. if isinstance(real_state_pool, pd.DataFrame):
  67. df = real_state_pool.copy()
  68. else:
  69. df = pd.DataFrame(real_state_pool, columns=self.FULL_STATE_KEYS)
  70. if "R" in df.columns:
  71. df = df.drop(columns=["R"])
  72. df = df[self.FREE_STATE_KEYS]
  73. self.real_state_pool = df.reset_index(drop=True)
  74. else:
  75. self.real_state_pool = None
  76. # ============================================================
  77. # 对外接口
  78. # ============================================================
  79. def sample(self, progress: float) -> pd.Series:
  80. cfg = self._get_sampling_config(progress)
  81. sources, weights = [], []
  82. if self.real_state_pool is not None:
  83. sources += ["real", "perturb"]
  84. weights += [cfg["w_real"], cfg["w_perturb"]]
  85. sources.append("virtual")
  86. weights.append(cfg["w_virtual"])
  87. source = self.rng.choice(sources, p=self._normalize(weights))
  88. # ========================================================
  89. # 1. 采样自由变量(DataFrame 单行)
  90. # ========================================================
  91. if source == "real":
  92. state_df = self._sample_real() # 必须返回 DataFrame(1 row)
  93. elif source == "perturb":
  94. base = self._sample_real()
  95. noise = self.rng.normal(0.0, cfg["perturb_scale"], size=self.state_dim)
  96. vec = base.values.squeeze() + noise
  97. vec = np.clip(vec, self.low, self.high)
  98. state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
  99. else:
  100. vec = self.rng.uniform(self.low, self.high)
  101. state_df = pd.DataFrame([vec], columns=self.FREE_STATE_KEYS)
  102. # ========================================================
  103. # 2. 物理派生:计算 R
  104. # ========================================================
  105. TMP = state_df.at[state_df.index[0], "TMP"]
  106. q_UF = state_df.at[state_df.index[0], "q_UF"]
  107. temp = state_df.at[state_df.index[0], "temp"]
  108. R = self.physics.resistance_from_tmp(tmp=TMP, q_UF=q_UF, temp=temp)
  109. if not np.isfinite(R):
  110. raise RuntimeError("Invalid resistance computed during reset sampling.")
  111. # ========================================================
  112. # 3. 插入 R,形成完整状态(仍是 DataFrame)
  113. # ========================================================
  114. state_df.insert(3, "R", R)
  115. state_df = state_df[self.FULL_STATE_KEYS]
  116. # ========================================================
  117. # 4. DataFrame(1 row) → Series(对外唯一返回)
  118. # ========================================================
  119. state = state_df.iloc[0]
  120. state.name = "reset_state" # 可选,但强烈推荐
  121. return state
  122. # ============================================================
  123. # 内部方法
  124. # ============================================================
  125. def _sample_real(self) -> pd.DataFrame:
  126. for _ in range(self.max_resample_attempts):
  127. idx = self.rng.randint(len(self.real_state_pool))
  128. row = self.real_state_pool.iloc[[idx]].reset_index(drop=True)
  129. vec = row.values.squeeze()
  130. if np.all(vec >= self.low) and np.all(vec <= self.high):
  131. return row
  132. raise RuntimeError("No valid real reset state within bounds.")
  133. @staticmethod
  134. def _normalize(w):
  135. s = sum(w)
  136. return [x / s for x in w]
  137. def _get_sampling_config(self, progress: float) -> dict:
  138. progress = np.clip(progress, 0.0, 1.0)
  139. # # -------------------------
  140. # # 阶段权重设计(非线性 + 提高虚拟工况)
  141. # # -------------------------
  142. # w_real = (1.0 - progress) ** 1.2 # 历史工况逐渐衰减
  143. # w_perturb = 0.5 * progress # 周边扰动按线性增加
  144. # w_virtual = 0.3 * progress ** 1.5 # 虚拟工况加快增长,后期最大约 0.3
  145. #
  146. # # perturb 扰动幅度
  147. # perturb_scale = 0.02 + 0.04 * progress
  148. w_real = 0.0
  149. w_perturb = 0.0
  150. w_virtual = 1.0
  151. # perturb 扰动幅度
  152. perturb_scale = 0.0
  153. return dict(
  154. w_real=w_real,
  155. w_perturb=w_perturb,
  156. w_virtual=w_virtual,
  157. perturb_scale=perturb_scale,
  158. )