# DQN_env.py
  1. import os
  2. import torch
  3. from pathlib import Path
  4. import numpy as np
  5. import gymnasium as gym
  6. from gymnasium import spaces
  7. from typing import Dict, Tuple, Optional
  8. import torch
  9. import torch.nn as nn
  10. from dataclasses import dataclass, asdict
  11. from UF_resistance_models import ResistanceIncreaseModel, ResistanceDecreaseModel # 导入模型类
  12. import copy
# =======================
# Membrane operating parameters: basic operating parameters of the UF membrane
# =======================
  16. @dataclass
  17. class UFParams:
  18. # —— 膜动态运行参数 ——
  19. q_UF: float = 360.0 # 过滤进水流量(m^3/h)
  20. TMP0: float = 0.03 # 初始跨膜压差
  21. temp: float = 25.0 # 水温,摄氏度
  22. # —— 膜阻力模型参数 ——
  23. nuK: float =4.92e+01 # 过滤阶段膜阻力增长模型参数
  24. slope: float = 3.44e-01 # 全周期不可逆污染阻力增长斜率
  25. power: float = 1.032 # 全周期不可逆污染阻力增长幂次
  26. tau_bw_s: float = 30.0 # 物洗时长影响时间尺度
  27. gamma_t: float = 1.0 # 物洗时长作用指数
  28. ceb_removal: float = 150 # CEB去除膜阻力
  29. # —— 膜运行约束参数 ——
  30. global_TMP_limit: float = 0.08 # TMP硬上限(MPa)
  31. TMP0_max: float = 0.035 # 初始TMP上限(MPa)
  32. TMP0_min: float = 0.01 # 初始TMP下限(MPa)
  33. q_UF_max: float = 400.0 # 进水流量上限(m^3/h)
  34. q_UF_min: float = 250.0 # 进水流量上限(m^3/h)
  35. temp_max: float = 40.0 # 温度上限(摄氏度)
  36. temp_min: float = 10.0 # 温度下限(摄氏度)
  37. nuK_max: float = 6e+01 # 物理周期总阻力增速上限(m^-1/s)
  38. nuK_min: float = 3e+01 # 物理周期总阻力增速下限(m^-1/s)
  39. slope_max: float = 10 # 化学周期长期阻力增速斜率上限
  40. slope_min: float = 0.1 # 化学周期长期阻力增速斜率下限
  41. power_max: float = 1.3 # 化学周期长期阻力增速幂次上限
  42. power_min: float = 0.8 # 化学周期长期阻力增速幂次下限
  43. ceb_removal_max: float = 150 # CEB去除阻力(已缩放)上限(m^-1)
  44. ceb_removal_min: float = 100 # CEB去除阻力(已缩放)下限(m^-1)
  45. # —— 反洗参数(固定) ——
  46. q_bw_m3ph: float = 1000.0 # 物理反洗流量(m^3/h)
  47. # —— CEB参数 ——
  48. T_ceb_interval_h: float = 60.0 # 固定每 k 小时做一次CEB
  49. v_ceb_m3: float = 30.0 # CEB用水体积(m^3)
  50. t_ceb_s: float = 40 * 60.0 # CEB时长(s)
  51. # —— 搜索范围(秒) ——
  52. L_min_s: float = 3800.0 # 过滤时长下限(s)
  53. L_max_s: float = 6000.0 # 过滤时长上限(s)
  54. t_bw_min_s: float = 40.0 # 物洗时长下限(s)
  55. t_bw_max_s: float = 60.0 # 物洗时长上限(s)
  56. # —— 网格 ——
  57. L_step_s: float = 60.0 # 过滤时长步长(s)
  58. t_bw_step_s: float = 5.0 # 物洗时长步长(s)
  59. # —— 奖励函数参数 ——
  60. k_rec = 5.0 # 回收率敏感度
  61. k_res = 10.0 # 残余污染敏感度
  62. rec_low, rec_high = 0.92, 0.99
  63. rr0 = 0.08
# =======================
# Helper functions: convert between membrane resistance and transmembrane pressure
# =======================
  67. def xishan_viscosity(temp):
  68. # temp: 水温,单位摄氏度
  69. """
  70. 锡山水厂 PLC水温校正因子经验公式(25摄氏度标准)
  71. 返回温度修正后的水粘度(纯水修正),TODO:水厂水质与纯水相差较大,对粘度有一定影响
  72. """
  73. x = (temp + 273.15) / 300
  74. factor = 890 / (280.68 * x ** -1.9 + 511.45 * x ** -7.7 + 61.131 * x ** -19.6 + 0.45903 * x ** -40)
  75. mu = 0.00089 / factor
  76. return mu
  77. def _calculate_resistance(tmp, q_UF, temp):
  78. """
  79. 计算超滤膜阻力 R = TMP / (J * μ)
  80. 返回缩小1e10的膜阻力(超滤原膜阻力量级为1e12,过大的绝对值容易导致平稳拟合)
  81. """
  82. A = 128 * 40 # m²,有效膜面积
  83. mu = xishan_viscosity(temp) # 温度修正后的水粘度
  84. TMP_Pa = tmp * 1e6 # 跨膜压差 MPa -> Pa
  85. J = q_UF / A / 3600 # 通量 m³/h -> m³/(m²·s)
  86. if J <= 0 or mu <= 0:
  87. return np.nan
  88. R = TMP_Pa / (J * mu) / 1e10 # 缩放膜阻力
  89. return float(R)
  90. def _calculate_tmp(R, q_UF, temp):
  91. """
  92. 还原超滤跨膜压差 TMP
  93. """
  94. A = 128 * 40 # m²,有效膜面积
  95. mu = xishan_viscosity(temp) # 温度修正后的水粘度
  96. J = q_UF / A / 3600 # 通量 m³/h -> m³/(m²·s)
  97. TMP_Pa = R * J * mu * 1e10
  98. tmp = TMP_Pa / 1e6
  99. return float(tmp)
# =======================
# Environment model loading function
# =======================
  103. def load_resistance_models():
  104. """加载阻力变化模型,仅在首次调用时执行"""
  105. global resistance_model_fp, resistance_model_bw
  106. # 如果全局模型已存在,则直接返回
  107. if "resistance_model_fp" in globals() and resistance_model_fp is not None:
  108. return resistance_model_fp, resistance_model_bw
  109. print("🔄 Loading resistance models...")
  110. # 初始化模型
  111. resistance_model_fp = ResistanceIncreaseModel()
  112. resistance_model_bw = ResistanceDecreaseModel()
  113. # 取得当前脚本所在目录(即 rl_dqn_env.py 或 check_initial_state.py 同目录)
  114. base_dir = Path(__file__).resolve().parent
  115. # 构造模型路径
  116. fp_path = base_dir / "resistance_model_fp.pth"
  117. bw_path = base_dir / "resistance_model_bw.pth"
  118. # 检查文件存在性
  119. assert fp_path.exists(), f"缺少 {fp_path.name}"
  120. assert bw_path.exists(), f"缺少 {bw_path.name}"
  121. # 加载权重
  122. resistance_model_fp.load_state_dict(torch.load(fp_path, map_location="cpu"))
  123. resistance_model_bw.load_state_dict(torch.load(bw_path, map_location="cpu"))
  124. # 设置推理模式
  125. resistance_model_fp.eval()
  126. resistance_model_bw.eval()
  127. print("✅ Resistance models loaded successfully from current directory.")
  128. return resistance_model_fp, resistance_model_bw
# =======================
# Environment model simulation functions
# =======================
  132. def _delta_resistance(p, L_h: float) -> float:
  133. """
  134. 过滤时段膜阻力上升量:调用 resistance_model_fp.pth 模型
  135. """
  136. return resistance_model_fp(p, L_h)
  137. def phi_bw_of(p, R0: float, R_end: float, L_h_start: float, L_h_next_start: float, t_bw_s: float) -> float:
  138. """
  139. 物理冲洗去除膜阻力值:调用 resistance_model_bw 模型
  140. """
  141. return resistance_model_bw(p, R0, R_end, L_h_start, L_h_next_start, t_bw_s)
  142. def _v_bw_m3(p, t_bw_s: float) -> float:
  143. """
  144. 物理反洗水耗
  145. """
  146. return float(p.q_bw_m3ph * (float(t_bw_s) / 3600.0))
  147. def simulate_one_supercycle(p: UFParams, L_s: float, t_bw_s: float):
  148. """
  149. 模拟一个超级周期(多次物理反洗 + 一次化学反洗)
  150. 返回: (info, next_params)
  151. """
  152. L_h = float(L_s) / 3600.0 # 小周期过滤时间(h)
  153. tmp = p.TMP0
  154. R0 = _calculate_resistance(p.TMP0, p.q_UF, p.temp)
  155. max_tmp_during_filtration = tmp
  156. min_tmp_during_filtration = tmp
  157. max_residual_increase = 0.0
  158. t_small_cycle_h = (L_s + t_bw_s) / 3600.0
  159. k_bw_per_ceb = int(np.floor(p.T_ceb_interval_h / t_small_cycle_h))
  160. if k_bw_per_ceb < 1:
  161. k_bw_per_ceb = 1
  162. energy_lookup = {
  163. 3600: 0.1034, 3660: 0.1031, 3720: 0.1029, 3780: 0.1026,
  164. 3840: 0.1023, 3900: 0.1021, 3960: 0.1019, 4020: 0.1017,
  165. 4080: 0.1015, 4140: 0.1012, 4200: 0.1011
  166. }
  167. # --- 循环模拟物理反洗 ---
  168. for idx in range(k_bw_per_ceb):
  169. tmp_run_start = tmp
  170. q_UF = p.q_UF
  171. temp = p.temp
  172. R_run_start = _calculate_resistance(tmp_run_start, q_UF, temp)
  173. d_R = _delta_resistance(p, L_s)
  174. R_peak = R_run_start + d_R
  175. tmp_peak = _calculate_tmp(R_peak, q_UF, temp)
  176. max_tmp_during_filtration = max(max_tmp_during_filtration, tmp_peak)
  177. min_tmp_during_filtration = min(min_tmp_during_filtration, tmp_run_start)
  178. # 物洗膜阻力减小
  179. L_h_start = (L_s + t_bw_s) / 3600.0 * idx
  180. L_h_next_start = (L_s + t_bw_s) / 3600.0 * (idx + 1)
  181. reversible_R = phi_bw_of(p, R_run_start, R_peak, L_h_start, L_h_next_start, t_bw_s)
  182. R_after_bw = R_peak - reversible_R
  183. tmp_after_bw = _calculate_tmp(R_after_bw, q_UF, temp)
  184. residual_inc = tmp_after_bw - tmp_run_start
  185. max_residual_increase = max(max_residual_increase, residual_inc)
  186. tmp = tmp_after_bw
  187. # --- CEB反洗 ---
  188. R_after_ceb = R_peak - p.ceb_removal
  189. tmp_after_ceb = _calculate_tmp(R_after_ceb, q_UF, temp)
  190. # ============================================================
  191. # 生成本周期指标
  192. # ============================================================
  193. # --- 体积与能耗 ---
  194. V_feed_super = k_bw_per_ceb * p.q_UF * L_h
  195. V_loss_super = k_bw_per_ceb * _v_bw_m3(p, t_bw_s) + p.v_ceb_m3
  196. V_net = max(0.0, V_feed_super - V_loss_super)
  197. recovery = max(0.0, V_net / max(V_feed_super, 1e-12))
  198. T_super_h = k_bw_per_ceb * (L_s + t_bw_s) / 3600.0 + p.t_ceb_s / 3600.0
  199. daily_prod_time_h = k_bw_per_ceb * L_h / T_super_h * 24.0
  200. closest_L = min(energy_lookup.keys(), key=lambda x: abs(x - L_s))
  201. ton_water_energy = energy_lookup[closest_L] #TODO:需确认新过滤时间范围下的吨水电耗
  202. # --- 信息输出 ---
  203. info = {
  204. "q_UF": p.q_UF,
  205. "temp": p.temp,
  206. "recovery": recovery,
  207. "V_feed_super_m3": V_feed_super,
  208. "V_loss_super_m3": V_loss_super,
  209. "V_net_super_m3": V_net,
  210. "supercycle_time_h": T_super_h,
  211. "max_TMP_during_filtration": max_tmp_during_filtration,
  212. "min_TMP_during_filtration": min_tmp_during_filtration,
  213. "global_TMP_limit":p.global_TMP_limit,
  214. "max_residual_increase_per_run": max_residual_increase,
  215. "R0": R0,
  216. "R_after_ceb": R_after_ceb,
  217. "TMP0":p.TMP0,
  218. "TMP_after_ceb": tmp_after_ceb,
  219. "daily_prod_time_h": daily_prod_time_h,
  220. "ton_water_energy_kWh_per_m3": ton_water_energy,
  221. "k_bw_per_ceb": k_bw_per_ceb
  222. }
  223. # ============================================================
  224. # 状态更新:生成 next_params(新状态)
  225. # ============================================================
  226. next_params = copy.deepcopy(p)
  227. # 更新跨膜压差(TMP)
  228. next_params.TMP0 = tmp_after_ceb
  229. # 可选参数(当前保持不变,未来可扩展更新逻辑)
  230. next_params.slope = p.slope
  231. next_params.power = p.power
  232. next_params.ceb_removal = p.ceb_removal
  233. next_params.nuK = p.nuK
  234. next_params.q_UF = p.q_UF
  235. next_params.temp = p.temp
  236. return info, next_params
  237. def calculate_reward(p: UFParams, info: dict) -> float:
  238. """
  239. TMP不参与奖励计算,仅考虑回收率与残余污染比例之间的权衡。
  240. 满足:
  241. - 当 recovery=0.97, residual_ratio=0.1 → reward = 0
  242. - 当 recovery=0.90, residual_ratio=0.0 → reward = 0
  243. - 在两者之间平衡(如 recovery≈0.94, residual_ratio≈0.05)→ reward > 0
  244. """
  245. recovery = info["recovery"]
  246. residual_ratio = (info["R_after_ceb"] - info["R0"]) / info["R0"]
  247. # 回收率奖励(在 [rec_low, rec_high] 内平滑上升)
  248. rec_norm = (recovery - p.rec_low) / (p.rec_high - p.rec_low)
  249. rec_reward = np.clip(np.tanh(p.k_rec * (rec_norm - 0.5)), -1, 1)
  250. # 残余比惩罚(超过rr0时快速变为负值)
  251. res_penalty = -np.tanh(p.k_res * (residual_ratio / p.rr0 - 1))
  252. # 组合逻辑:权衡二者
  253. total_reward = rec_reward + res_penalty
  254. # 再平移使指定点为零:
  255. # recovery=0.97, residual=0.1 → 0
  256. # recovery=0.90, residual=0.0 → 0
  257. # 经验上,这两点几乎对称,因此无需额外线性偏移
  258. # 若希望严格归零,可用线性校正:
  259. total_reward -= 0.0
  260. return total_reward
  261. def is_dead_cycle(info: dict) -> bool:
  262. """
  263. 判断当前循环是否为成功循环(True)或失败循环(False)
  264. 失败条件:
  265. 1. 最大TMP超过设定上限;
  266. 2. 回收率低于75%;
  267. 3. 化学反冲洗后膜阻力上升超过10%。
  268. 参数:
  269. info: dict
  270. simulate_one_supercycle() 返回的指标字典,需包含:
  271. - max_TMP_during_filtration
  272. - recovery
  273. - R_after_ceb
  274. - R_run_start
  275. - TMP_limit(如果有定义)
  276. 返回:
  277. bool: True 表示成功循环,False 表示失败循环。
  278. """
  279. TMP_limit = info.get("global_TMP_limit", 0.08) # 默认硬约束上限
  280. max_tmp = info.get("max_TMP_during_filtration", 0)
  281. recovery = info.get("recovery", 1.0)
  282. R_after_ceb = info.get("R_after_ceb", 0)
  283. R0 = info.get("R0", 1e-6)
  284. # 判断条件
  285. if max_tmp > TMP_limit:
  286. return False
  287. if recovery < 0.75:
  288. return False
  289. if (R_after_ceb - R0) / R0 > 0.1:
  290. return False
  291. return True
  292. class UFSuperCycleEnv(gym.Env):
  293. """超滤系统环境(超级周期级别决策)"""
  294. metadata = {"render_modes": ["human"]}
  295. def __init__(self, base_params, resistance_models=None, max_episode_steps: int = 15):
  296. super(UFSuperCycleEnv, self).__init__()
  297. self.base_params = base_params
  298. self.current_params = copy.deepcopy(base_params)
  299. self.max_episode_steps = max_episode_steps
  300. self.current_step = 0
  301. if resistance_models is None:
  302. self.resistance_model_fp, self.resistance_model_bw = load_resistance_models()
  303. else:
  304. self.resistance_model_fp, self.resistance_model_bw = resistance_models
  305. # 计算离散动作空间
  306. self.L_values = np.arange(
  307. self.base_params.L_min_s,
  308. self.base_params.L_max_s,
  309. self.base_params.L_step_s
  310. )
  311. self.t_bw_values = np.arange(
  312. self.base_params.t_bw_min_s,
  313. self.base_params.t_bw_max_s + self.base_params.t_bw_step_s,
  314. self.base_params.t_bw_step_s
  315. )
  316. self.num_L = len(self.L_values)
  317. self.num_bw = len(self.t_bw_values)
  318. # 单一离散动作空间
  319. self.action_space = spaces.Discrete(self.num_L * self.num_bw)
  320. # 状态空间,归一化在 _get_obs 中处理
  321. self.observation_space = spaces.Box(
  322. low=np.zeros(8, dtype=np.float32),
  323. high=np.ones(8, dtype=np.float32),
  324. dtype=np.float32
  325. )
  326. # 初始化环境
  327. self.reset(seed=None)
  328. def generate_initial_state(self):
  329. """
  330. 随机生成一个初始状态,不进行死状态判断
  331. """
  332. self.current_params.TMP0 = np.random.uniform(
  333. self.current_params.TMP0_min, self.current_params.TMP0_max
  334. )
  335. self.current_params.q_UF = np.random.uniform(
  336. self.current_params.q_UF_min, self.current_params.q_UF_max
  337. )
  338. self.current_params.temp = np.random.uniform(
  339. self.current_params.temp_min, self.current_params.temp_max
  340. )
  341. self.current_params.R0 = _calculate_resistance(
  342. self.current_params.TMP0,
  343. self.current_params.q_UF,
  344. self.current_params.temp
  345. )
  346. self.current_params.nuK = np.random.uniform(
  347. self.current_params.nuK_min, self.current_params.nuK_max
  348. )
  349. self.current_params.slope = np.random.uniform(
  350. self.current_params.slope_min, self.current_params.slope_max
  351. )
  352. self.current_params.power = np.random.uniform(
  353. self.current_params.power_min, self.current_params.power_max
  354. )
  355. self.current_params.ceb_removal = np.random.uniform(
  356. self.current_params.ceb_removal_min, self.current_params.ceb_removal_max
  357. )
  358. return self._get_state_copy()
  359. def reset(self, seed=None, options=None, max_attempts: int = 200):
  360. super().reset(seed=seed)
  361. attempts = 0
  362. while attempts < max_attempts:
  363. attempts += 1
  364. self.generate_initial_state() # 生成随机初始状态
  365. if self.check_dead_initial_state(max_steps=getattr(self, "max_episode_steps", 15),
  366. L_s=3800, t_bw_s=60):
  367. # True 表示可行,退出循环
  368. break
  369. else:
  370. # 超过最大尝试次数仍未生成可行状态
  371. raise RuntimeError(f"在 {max_attempts} 次尝试后仍无法生成可行初始状态。")
  372. # 初始化步数、动作、最大 TMP
  373. self.current_step = 0
  374. self.last_action = (self.base_params.L_min_s, self.base_params.t_bw_min_s)
  375. self.max_TMP_during_filtration = self.current_params.TMP0
  376. return self._get_obs(), {}
  377. def check_dead_initial_state(self, max_steps: int = None,
  378. L_s: int = 4900, t_bw_s: int = 50) -> bool:
  379. """
  380. 判断当前环境生成的初始状态是否为可行(non-dead)。
  381. 使用最保守策略连续模拟 max_steps 次:
  382. 若任意一次 is_dead_cycle(info) 返回 False,则视为必死状态。
  383. 参数:
  384. max_steps: 模拟步数,默认使用 self.max_episode_steps
  385. L_s: 过滤时长(s),默认 3800
  386. t_bw_s: 物理反洗时长(s),默认 60
  387. 返回:
  388. bool: True 表示可行状态(non-dead),False 表示必死状态
  389. """
  390. if max_steps is None:
  391. max_steps = getattr(self, "max_episode_steps", 15)
  392. # 生成初始状态
  393. self.generate_initial_state()
  394. if not hasattr(self, "current_params"):
  395. raise AttributeError("generate_initial_state() 未设置 current_params。")
  396. import copy
  397. curr_p = copy.deepcopy(self.current_params)
  398. # 逐步模拟
  399. for step in range(max_steps):
  400. try:
  401. info, next_params = simulate_one_supercycle(curr_p, L_s, t_bw_s)
  402. except Exception:
  403. # 异常即视为不可行
  404. return False
  405. if not is_dead_cycle(info):
  406. # 任意一次失败即为必死状态
  407. return False
  408. curr_p = next_params
  409. return True
  410. def _get_state_copy(self):
  411. return copy.deepcopy(self.current_params)
  412. def _get_obs(self):
  413. """
  414. 构建当前环境归一化状态向量
  415. """
  416. # === 1. 从 current_params 读取动态参数 ===
  417. TMP0 = self.current_params.TMP0
  418. q_UF = self.current_params.q_UF
  419. temp = self.current_params.temp
  420. # === 2. 计算本周期初始膜阻力 ===
  421. R0 = _calculate_resistance(TMP0, q_UF, temp)
  422. # === 3. 从 current_params 读取膜阻力增长模型参数 ===
  423. nuk = self.current_params.nuK
  424. slope = self.current_params.slope
  425. power = self.current_params.power
  426. ceb_removal = self.current_params.ceb_removal
  427. # === 4. 从 current_params 动态读取上下限 ===
  428. TMP0_min, TMP0_max = self.current_params.TMP0_min, self.current_params.TMP0_max
  429. q_UF_min, q_UF_max = self.current_params.q_UF_min, self.current_params.q_UF_max
  430. temp_min, temp_max = self.current_params.temp_min, self.current_params.temp_max
  431. nuK_min, nuK_max = self.current_params.nuK_min, self.current_params.nuK_max
  432. slope_min, slope_max = self.current_params.slope_min, self.current_params.slope_max
  433. power_min, power_max = self.current_params.power_min, self.current_params.power_max
  434. ceb_min, ceb_max = self.current_params.ceb_removal_min, self.current_params.ceb_removal_max
  435. # === 5. 归一化计算(clip防止越界) ===
  436. TMP0_norm = np.clip((TMP0 - TMP0_min) / (TMP0_max - TMP0_min), 0, 1)
  437. q_UF_norm = np.clip((q_UF - q_UF_min) / (q_UF_max - q_UF_min), 0, 1)
  438. temp_norm = np.clip((temp - temp_min) / (temp_max - temp_min), 0, 1)
  439. # R0 不在 current_params 中定义上下限,设定经验范围
  440. R0_norm = np.clip((R0 - 100.0) / (800.0 - 100.0), 0, 1)
  441. short_term_norm = np.clip((nuk - nuK_min) / (nuK_max - nuK_min), 0, 1)
  442. long_term_slope_norm = np.clip((slope - slope_min) / (slope_max - slope_min), 0, 1)
  443. long_term_power_norm = np.clip((power - power_min) / (power_max - power_min), 0, 1)
  444. ceb_removal_norm = np.clip((ceb_removal - ceb_min) / (ceb_max - ceb_min), 0, 1)
  445. # === 6. 构建观测向量 ===
  446. obs = np.array([
  447. TMP0_norm,
  448. q_UF_norm,
  449. temp_norm,
  450. R0_norm,
  451. short_term_norm,
  452. long_term_slope_norm,
  453. long_term_power_norm,
  454. ceb_removal_norm
  455. ], dtype=np.float32)
  456. return obs
  457. def _get_action_values(self, action):
  458. """
  459. 将动作还原为实际时长
  460. """
  461. L_idx = action // self.num_bw
  462. t_bw_idx = action % self.num_bw
  463. return self.L_values[L_idx], self.t_bw_values[t_bw_idx]
  464. def step(self, action):
  465. self.current_step += 1
  466. L_s, t_bw_s = self._get_action_values(action)
  467. L_s = np.clip(L_s, self.base_params.L_min_s, self.base_params.L_max_s)
  468. t_bw_s = np.clip(t_bw_s, self.base_params.t_bw_min_s, self.base_params.t_bw_max_s)
  469. # 模拟超级周期
  470. info, next_params = simulate_one_supercycle(self.current_params, L_s, t_bw_s)
  471. # 根据 info 判断是否成功
  472. feasible = is_dead_cycle(info) # True 表示成功循环,False 表示失败
  473. if feasible:
  474. reward = calculate_reward(self.current_params, info)
  475. self.current_params = next_params
  476. terminated = False
  477. else:
  478. reward = -10
  479. terminated = True
  480. truncated = self.current_step >= self.max_episode_steps
  481. self.last_action = (L_s, t_bw_s)
  482. next_obs = self._get_obs()
  483. info["feasible"] = feasible
  484. info["step"] = self.current_step
  485. return next_obs, reward, terminated, truncated, info