dqn_statebuilder.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import re
  2. from pathlib import Path
  3. from typing import Dict, Any
  4. import yaml
  5. import numpy as np
  6. import pandas as pd
  7. from oauthlib.uri_validate import segment
  8. # -------------------------------
  9. # 引入环境状态模板(最终输出)
  10. # -------------------------------
  11. from uf_train.env.env_params import UFState
  12. # -------------------------------
  13. # 引入分析类
  14. # -------------------------------
  15. from uf_data_process.load import UFConfigLoader
  16. from uf_data_process.label import UFEventClassifier, PostBackwashInletMarker
  17. from uf_data_process.filter import ConstantFlowFilter, EventQualityFilter, InletSegmentFilter
  18. from uf_data_process.calculate import UFResistanceCalculator, UFResistanceAnalyzer
  19. from uf_data_process.fit import ShortTermCycleFoulingFitter, LongTermFoulingFitter
  20. class DQNStateBuilder:
  21. """
  22. 在 DQN 决策前构建状态的工具类
  23. 相关数据:
  24. * CSV1 = 上一完整化学周期
  25. * CSV2 = 新周期初始进水段
  26. """
  27. def __init__(self, config_path: str):
  28. """
  29. Parameters
  30. ----------
  31. config_path : str
  32. uf_analyze_config.yaml 路径
  33. """
  34. self.cfg = UFConfigLoader(config_path)
  35. self.uf_cfg = self.cfg.uf
  36. self.params = self.cfg.params
  37. self.units = self.uf_cfg["units"]
  38. self.area_m2 = self.uf_cfg["area_m2"]
  39. self.scale_factor = self.params.get("scale_factor", 1e10)
  40. self.segment_head_n = self.params.get("segment_head_n", 10)
  41. self.segment_tail_n = self.params.get("segment_tail_n", 10)
  42. # ======================================================================
  43. # 对外主接口
  44. # ======================================================================
  45. def build_from_csv_pair(
  46. self,
  47. prev_cycle_csv: str,
  48. init_cycle_csv: str,
  49. ) -> UFState:
  50. """
  51. 使用【上一完整化学周期 CSV】+【当前周期初始 CSV】构建 UFState
  52. """
  53. df_prev = pd.read_csv(prev_cycle_csv)
  54. df_init = pd.read_csv(init_cycle_csv)
  55. # 自动识别 UF 单元编号(UF1 / UF2 / ...)
  56. unit_id = self._infer_unit_id(df_prev)
  57. # 分别处理两个 CSV
  58. prev_features = self._analyze_previous_cycle_csv(df_prev, unit_id)
  59. init_features = self._analyze_init_cycle_csv(df_init, unit_id)
  60. # 化学清洗去除阻力(上一周期末 - 当前初始)
  61. ceb_removal = max(
  62. prev_features["R_end"] - init_features["R_start"],
  63. 0.0
  64. )
  65. # 构建 UFState
  66. state = UFState(
  67. q_UF=init_features["q_mean"],
  68. TMP=init_features["tmp_mean"],
  69. temp=init_features["temp_mean"],
  70. R=init_features["R_start"],
  71. nuK=prev_features["nuK"],
  72. slope=prev_features["slope"],
  73. power=prev_features["power"],
  74. ceb_removal=ceb_removal,
  75. )
  76. return state
  77. # ======================================================================
  78. # 上一完整化学周期分析
  79. # ======================================================================
  80. def _analyze_previous_cycle_csv(
  81. self,
  82. df: pd.DataFrame,
  83. unit_id: str,
  84. ) -> Dict[str, float]:
  85. """
  86. 上一完整化学周期分析逻辑
  87. 步骤:
  88. 1. 事件标注
  89. 2. 进水段过滤(质量过滤)
  90. 3. 膜阻力计算
  91. 4. 提取周期末稳定阻力
  92. 5. 拟合 nuK
  93. 6. 拟合长期不可逆污染(slope / power)
  94. """
  95. # -------- 事件标注 --------
  96. df = self._label_events(df, unit_id)
  97. # -------- 保留过滤进水段 --------
  98. inlet_filter = InletSegmentFilter(
  99. control_col=f"C.M.{unit_id}_DB@word_control",
  100. stable_value=self.uf_cfg["stable_inlet_code"],
  101. min_points=self.params["min_stable_points"],
  102. )
  103. segments = inlet_filter.extract(df)
  104. quality_filter = EventQualityFilter(
  105. min_points=self.params["min_stable_points"]
  106. )
  107. segments = quality_filter.filter(segments)
  108. if len(segments) == 0:
  109. raise ValueError("上一周期无有效稳定进水段,无法构建状态")
  110. # -------- 3️⃣ 膜阻力计算 --------
  111. res_calc = UFResistanceCalculator(
  112. units=[unit_id],
  113. area_m2=self.area_m2,
  114. scale_factor=self.scale_factor,
  115. )
  116. segments = res_calc.calculate_for_segments(
  117. segments,
  118. temp_col=self.uf_cfg["temp_col"],
  119. flow_col=self.uf_cfg["flow_col_template"].format(unit=unit_id),
  120. )
  121. # -------- 膜阻力统计 --------
  122. res_col = f"{unit_id}_R_scaled"
  123. ura = UFResistanceAnalyzer(
  124. resistance_col=res_col,
  125. head_n=self.segment_head_n,
  126. tail_n=self.segment_tail_n
  127. )
  128. segments = ura.analyze_segments(segments)
  129. df_all = segments[-1]
  130. R_end = df_all["R_scaled_end"].iloc[0]
  131. # ===== 确保 time 为 datetime =====
  132. for i, seg in enumerate(segments):
  133. if not pd.api.types.is_datetime64_any_dtype(seg["time"]):
  134. seg = seg.copy()
  135. seg["time"] = pd.to_datetime(seg["time"], errors="coerce")
  136. seg = seg.dropna(subset=["time"])
  137. segments[i] = seg
  138. # -------- 5️⃣ 短期污染拟合(nuK)--------
  139. st_fitter = ShortTermCycleFoulingFitter(unit_id)
  140. nuK, _ = st_fitter.fit_cycle(segments)
  141. # -------- 6️⃣ 长期不可逆污染拟合 --------
  142. lt_fitter = LongTermFoulingFitter(unit_id)
  143. slope, power, _ = lt_fitter.fit_cycle(segments)
  144. return {
  145. "R_end": R_end,
  146. "nuK": float(nuK),
  147. "slope": float(slope),
  148. "power": float(power),
  149. }
  150. # ======================================================================
  151. # 当前周期初始进水段分析
  152. # ======================================================================
  153. def _analyze_init_cycle_csv(
  154. self,
  155. df: pd.DataFrame,
  156. unit_id: str,
  157. ) -> Dict[str, float]:
  158. """
  159. 当前周期初始进水段分析
  160. 特点:
  161. - 不切段
  162. - 不过滤
  163. - 只计算均值
  164. """
  165. # -------- 1️⃣ 事件标注 --------
  166. df = self._label_events(df, unit_id)
  167. # -------- 2️⃣ 仅保留进水行 --------
  168. df = df[df["event_type"] == "inlet"].copy()
  169. if df.empty:
  170. raise ValueError("初始 CSV 中无进水数据")
  171. # -------- 3️⃣ 膜阻力计算 --------
  172. res_calc = UFResistanceCalculator(
  173. units=[unit_id],
  174. area_m2=self.area_m2,
  175. scale_factor=self.scale_factor,
  176. )
  177. segments = [df]
  178. segments = res_calc.calculate_for_segments(
  179. segments,
  180. temp_col=self.uf_cfg["temp_col"],
  181. flow_col=self.uf_cfg["flow_col_template"].format(unit=unit_id),
  182. )
  183. df = segments[-1]
  184. flow_col = self.uf_cfg["flow_col_template"].format(unit=unit_id)
  185. temp_col = self.uf_cfg["temp_col"]
  186. press_col = f"C.M.{unit_id}_DB@press_PV"
  187. res_col = f"{unit_id}_R_scaled"
  188. return {
  189. "q_mean": float(df[flow_col].mean()),
  190. "tmp_mean": float(df[press_col].mean()),
  191. "temp_mean": float(df[temp_col].mean()),
  192. "R_start": float(df[res_col].mean()),
  193. }
  194. # ======================================================================
  195. # 工具函数
  196. # ======================================================================
  197. def _infer_unit_id(self, df: pd.DataFrame) -> str:
  198. """
  199. 根据列名自动识别 UF 单元编号
  200. """
  201. for unit in self.units:
  202. key = f"C.M.{unit}_FT_JS@out"
  203. if key in df.columns:
  204. return unit
  205. raise ValueError("无法从 CSV 列名识别 UF 单元编号")
  206. def _label_events(self, df: pd.DataFrame, unit_id: str) -> pd.DataFrame:
  207. """
  208. 为 DataFrame 标注 event_type
  209. """
  210. clf = UFEventClassifier(
  211. unit_name=unit_id,
  212. inlet_codes=self.uf_cfg["inlet_codes"],
  213. physical_code=self.uf_cfg["physical_bw_code"],
  214. chemical_code=self.uf_cfg["chemical_bw_code"],
  215. )
  216. df = clf.classify(df)
  217. df = clf.segment(df)
  218. return df