DQN_env.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. import os
  2. import time
  3. import random
  4. import numpy as np
  5. import gymnasium as gym
  6. from gymnasium import spaces
  7. from stable_baselines3 import DQN
  8. from stable_baselines3.common.monitor import Monitor
  9. from stable_baselines3.common.vec_env import DummyVecEnv
  10. from stable_baselines3.common.callbacks import BaseCallback
  11. from typing import Dict, Tuple, Optional
  12. import torch
  13. import torch.nn as nn
  14. from dataclasses import dataclass, asdict
  15. from save_uf_models import TMPIncreaseModel, TMPDecreaseModel # 导入模型类
  16. import copy
  17. # ==== 定义膜的基础运行参数 ====
  18. @dataclass
  19. class UFParams:
  20. # —— 膜与运行参数 ——
  21. q_UF: float = 360.0 # 过滤进水流量(m^3/h)
  22. TMP0: float = 0.03 # 初始TMP(MPa)
  23. TMP_max: float = 0.06 # TMP硬上限(MPa)
  24. # —— 膜污染动力学 ——
  25. alpha: float = 1e-6 # TMP增长系数
  26. belta: float = 1.1 # 幂指数
  27. # —— 反洗参数(固定) ——
  28. q_bw_m3ph: float = 1000.0 # 物理反洗流量(m^3/h)
  29. # —— CEB参数(固定) ——
  30. T_ceb_interval_h: float = 48.0 # 固定每 k 小时做一次CEB
  31. v_ceb_m3: float = 30.0 # CEB用水体积(m^3)
  32. t_ceb_s: float = 40 * 60.0 # CEB时长(s)
  33. phi_ceb: float = 1.0 # CEB去除比例(简化:完全恢复到TMP0)
  34. # —— 约束与收敛 ——
  35. dTMP: float = 0.001 # 单次产水结束时,相对TMP0最大升幅(MPa)
  36. # —— 搜索范围(秒) ——
  37. L_min_s: float = 3800.0 # 过滤时长下限(s)
  38. L_max_s: float = 4200.0 # 过滤时长上限(s)
  39. t_bw_min_s: float = 90.0 # 物洗时长下限(s)
  40. t_bw_max_s: float = 100.0 # 物洗时长上限(s)
  41. # —— 物理反洗恢复函数参数 ——
  42. phi_bw_min: float = 0.7 # 物洗去除比例最小值
  43. phi_bw_max: float = 1.0 # 物洗去除比例最大值
  44. L_ref_s: float = 4000.0 # 过滤时长影响时间尺度
  45. tau_bw_s: float = 30.0 # 物洗时长影响时间尺度
  46. gamma_t: float = 1.0 # 物洗时长作用指数
  47. # —— 网格 ——
  48. L_step_s: float = 60.0 # 过滤时长步长(s)
  49. t_bw_step_s: float = 2.0 # 物洗时长步长(s)
  50. # 多目标加权及高TMP惩罚
  51. w_rec: float = 0.8 # 回收率权重
  52. w_rate: float = 0.2 # 净供水率权重
  53. w_headroom: float = 0.3 # 贴边惩罚权重
  54. r_headroom: float = 2.0 # 贴边惩罚幂次
  55. headroom_hardcap: float = 0.98 # 超过此比例直接视为不可取
  56. # ==== 定义强化学习超参数 ====
  57. @dataclass
  58. class DQNParams:
  59. """
  60. DQN 超参数定义类
  61. 用于统一管理模型训练参数
  62. """
  63. # 学习率,控制神经网络更新步长
  64. learning_rate: float = 1e-4
  65. # 经验回放缓冲区大小(步数)
  66. buffer_size: int = 2000
  67. # 学习开始前需要收集的步数
  68. learning_starts: int = 200
  69. # 每次从经验池中采样的样本数量
  70. batch_size: int = 16
  71. # 折扣因子,越接近1越重视长期奖励
  72. gamma: float = 0.95
  73. # 每隔多少步训练一次
  74. train_freq: int = 1
  75. # 目标网络更新间隔
  76. target_update_interval: int = 1000
  77. # 初始探索率 ε
  78. exploration_initial_eps: float = 1.0
  79. # 从初始ε衰减到最终ε所占的训练比例
  80. exploration_fraction: float = 0.6
  81. # 最终探索率 ε
  82. exploration_final_eps: float = 0.02
  83. # 日志备注(用于区分不同实验)
  84. remark: str = "default"
  85. # ==== 加载模拟环境模型 ====
  86. # 初始化模型
  87. model_fp = TMPIncreaseModel()
  88. model_bw = TMPDecreaseModel()
  89. # 加载参数
  90. model_fp.load_state_dict(torch.load("uf_fp.pth"))
  91. model_bw.load_state_dict(torch.load("uf_bw.pth"))
  92. # 切换到推理模式
  93. model_fp.eval()
  94. model_bw.eval()
  95. def _delta_tmp(p, L_h: float) -> float:
  96. """
  97. 过滤时段TMP上升量:调用 uf_fp.pth 模型
  98. """
  99. return model_fp(p, L_h)
  100. def phi_bw_of(p, L_s: float, t_bw_s: float) -> float:
  101. """
  102. 物洗去除比例:调用 uf_bw.pth 模型
  103. """
  104. return model_bw(p, L_s, t_bw_s)
  105. def _tmp_after_ceb(p, L_s: float, t_bw_s: float) -> float:
  106. """
  107. 计算化学清洗(CEB)后的TMP,当前为恢复初始跨膜压差
  108. """
  109. return p.TMP0
  110. def _v_bw_m3(p, t_bw_s: float) -> float:
  111. """
  112. 物理反洗水耗
  113. """
  114. return float(p.q_bw_m3ph * (float(t_bw_s) / 3600.0))
  115. def simulate_one_supercycle(p: UFParams, L_s: float, t_bw_s: float):
  116. """
  117. 返回 (是否可行, 指标字典)
  118. - 支持动态CEB次数:48h固定间隔
  119. - 增加日均产水时间和吨水电耗
  120. """
  121. L_h = float(L_s) / 3600.0 # 小周期过滤时间(h)
  122. tmp = p.TMP0
  123. max_tmp_during_filtration = tmp
  124. max_residual_increase = 0.0
  125. # 小周期总时长(h) 小周期总时长 = 过滤时长 + 物洗时长
  126. t_small_cycle_h = (L_s + t_bw_s) / 3600.0
  127. # 计算超级周期内CEB次数 超级周期内CEB次数 = 48h / 小周期总时长
  128. k_bw_per_ceb = int(np.floor(p.T_ceb_interval_h / t_small_cycle_h))
  129. if k_bw_per_ceb < 1:
  130. k_bw_per_ceb = 1 # 至少一个小周期
  131. # ton水电耗查表
  132. energy_lookup = {
  133. 3600: 0.1034, 3660: 0.1031, 3720: 0.1029, 3780: 0.1026,
  134. 3840: 0.1023, 3900: 0.1021, 3960: 0.1019, 4020: 0.1017,
  135. 4080: 0.1015, 4140: 0.1012, 4200: 0.1011
  136. }
  137. for _ in range(k_bw_per_ceb):
  138. tmp_run_start = tmp
  139. # 过滤阶段TMP增长
  140. dtmp = _delta_tmp(p, L_h)
  141. tmp_peak = tmp_run_start + dtmp # 过滤阶段TMP峰值 = 过滤阶段TMP开始值 + 过滤阶段TMP上升量
  142. # 约束1:峰值不得超过硬上限
  143. if tmp_peak > p.TMP_max + 1e-12:
  144. return False, {"reason": "TMP_max violated during filtration", "TMP_peak": tmp_peak}
  145. if tmp_peak > max_tmp_during_filtration: # 如果过滤阶段TMP峰值超过当前最大值
  146. max_tmp_during_filtration = tmp_peak # 更新最大值
  147. # 物理反洗
  148. phi = phi_bw_of(p, L_s, t_bw_s) # 物洗去除比例
  149. tmp_after_bw = tmp_peak - phi * (tmp_peak - tmp_run_start) # 物理反洗后TMP = 过滤阶段TMP峰值 - 物洗去除比例 * (过滤阶段TMP峰值 - 过滤阶段TMP开始值)
  150. # 约束2:单次残余增量控制
  151. residual_inc = tmp_after_bw - tmp_run_start # 单次残余增量 = 物理反洗后TMP - 过滤阶段TMP开始值
  152. if residual_inc > p.dTMP + 1e-12: # 如果单次残余增量超过单次残余增量上限
  153. return False, {
  154. "reason": "residual TMP increase after BW exceeded dTMP", # 返回不可行
  155. "residual_increase": residual_inc, # 单次残余增量
  156. "limit_dTMP": p.dTMP
  157. }
  158. if residual_inc > max_residual_increase: # 如果单次残余增量超过当前最大值
  159. max_residual_increase = residual_inc # 更新最大值
  160. tmp = tmp_after_bw # 更新TMP
  161. # CEB
  162. tmp_after_ceb = p.TMP0 # 化学反洗后TMP
  163. # 体积与回收率
  164. V_feed_super = k_bw_per_ceb * p.q_UF * L_h # 进水体积 进水体积 = 超级周期内CEB次数 * 过滤流量 * 过滤时长
  165. V_loss_super = k_bw_per_ceb * _v_bw_m3(p, t_bw_s) + p.v_ceb_m3 # 损失体积 损失体积 = 物洗体积 + CEB用水体积
  166. V_net = max(0.0, V_feed_super - V_loss_super) # 净产水体积 净产水体积 = 进水体积 - 损失体积
  167. recovery = max(0.0, V_net / max(V_feed_super, 1e-12)) # 回收率 净产水体积 / 进水体积
  168. # 时间与净供水率
  169. T_super_h = k_bw_per_ceb * (L_s + t_bw_s) / 3600.0 + p.t_ceb_s / 3600.0 # 超循环时间 超循环时间 = 超级周期内CEB次数 * (过滤时长 + 物洗时长) / 3600 + CEB时长 / 3600
  170. net_delivery_rate_m3ph = V_net / max(T_super_h, 1e-12) # 净供水率 净产水体积 / 超循环时间
  171. # 贴边比例与硬限
  172. headroom_ratio = max_tmp_during_filtration / max(p.TMP_max, 1e-12) # 贴边比例 过滤时段TMP峰值 / 硬上限
  173. if headroom_ratio > p.headroom_hardcap + 1e-12: # 如果贴边比例超过硬上限
  174. return False, {"reason": "headroom hardcap exceeded", "headroom_ratio": headroom_ratio} # 返回不可行
  175. # —— 新增指标 1:日均产水时间(h/d) ——
  176. daily_prod_time_h = k_bw_per_ceb * L_h / T_super_h * 24.0 # 日均产水时间 日均产水时间 = 超级周期内CEB次数 * 过滤时长 / 超循环时间 * 24
  177. # —— 新增指标 2:吨水电耗(kWh/m³) ——
  178. closest_L = min(energy_lookup.keys(), key=lambda x: abs(x - L_s)) # 最接近的过滤时长
  179. ton_water_energy = energy_lookup[closest_L] # 吨水电耗 最接近的过滤时长对应的吨水电耗
  180. info = {
  181. "recovery": recovery, # 回收率
  182. "V_feed_super_m3": V_feed_super, # 进水体积
  183. "V_loss_super_m3": V_loss_super, # 损失体积
  184. "V_net_super_m3": V_net, # 净产水体积
  185. "supercycle_time_h": T_super_h, # 超循环时间
  186. "net_delivery_rate_m3ph": net_delivery_rate_m3ph, # 净供水率
  187. "max_TMP_during_filtration": max_tmp_during_filtration, # 过滤时段TMP峰值
  188. "max_residual_increase_per_run": max_residual_increase, # 单次残余增量最大值
  189. "phi_bw_effective": phi, # 物洗去除比例
  190. "TMP_after_ceb": tmp_after_ceb, # 物理反洗后TMP
  191. "headroom_ratio": headroom_ratio, # 贴边比例
  192. "daily_prod_time_h": daily_prod_time_h, # 日均产水时间
  193. "ton_water_energy_kWh_per_m3": ton_water_energy, # 吨水电耗
  194. "k_bw_per_ceb": k_bw_per_ceb # 超级周期内CEB次数
  195. }
  196. return True, info
  197. def _score(p: UFParams, rec: dict) -> float:
  198. """综合评分:越大越好。不同TMP0会改变max_TMP→改变惩罚→得到不同解。"""
  199. # 无量纲化净供水率
  200. rate_norm = rec["net_delivery_rate_m3ph"] / max(p.q_UF, 1e-12) # 无量纲化净供水率 净供水率 / 过滤流量 1000m3/h / 360m3/h = 2.7778
  201. headroom_penalty = (rec["max_TMP_during_filtration"] / max(p.TMP_max, 1e-12)) ** p.r_headroom # 贴边惩罚
  202. reward = (p.w_rec * rec["recovery"] + p.w_rate * rate_norm - p.w_headroom * headroom_penalty) # 奖励
  203. return reward
  204. def set_global_seed(seed: int):
  205. """固定全局随机种子,保证训练可复现"""
  206. random.seed(seed) # 随机种子
  207. np.random.seed(seed) # 随机种子
  208. torch.manual_seed(seed) # 随机种子
  209. torch.cuda.manual_seed_all(seed) # 如果使用GPU
  210. torch.backends.cudnn.deterministic = True # 确定性
  211. torch.backends.cudnn.benchmark = False # 不使用GPU
  212. class UFSuperCycleEnv(gym.Env):
  213. """超滤系统环境(超级周期级别决策)"""
  214. metadata = {"render_modes": ["human"]}
  215. def __init__(self, base_params, max_episode_steps: int = 10):
  216. super(UFSuperCycleEnv, self).__init__() # 初始化环境
  217. self.base_params = base_params # UFParams 实例
  218. self.current_params = copy.deepcopy(base_params) # UFParams 实例
  219. self.max_episode_steps = max_episode_steps # 最大步数
  220. self.current_step = 0 # 当前步数
  221. # 计算离散动作空间
  222. self.L_values = np.arange(
  223. self.base_params.L_min_s, # 过滤时长下限
  224. self.base_params.L_max_s + self.base_params.L_step_s, # 过滤时长上限
  225. self.base_params.L_step_s # 过滤时长步长
  226. )
  227. self.t_bw_values = np.arange(
  228. self.base_params.t_bw_min_s, # 物洗时长下限
  229. self.base_params.t_bw_max_s + self.base_params.t_bw_step_s, # 物洗时长上限
  230. self.base_params.t_bw_step_s # 物洗时长步长
  231. )
  232. self.num_L = len(self.L_values) # 过滤时长步数
  233. self.num_bw = len(self.t_bw_values) # 物洗时长步数
  234. # 单一离散动作空间,spaces.Discrete(n) 定义了一个包含 n 个离散动作或观测值的空间。这个空间包含从 0 到 n-1 的整数值
  235. self.action_space = spaces.Discrete(self.num_L * self.num_bw) # 动作空间,离散动作空间
  236. # 状态空间:归一化的[TMP0], 用于定义 连续的空间,通常用于表示那些具有连续值的观测空间或动作空间
  237. self.observation_space = spaces.Box(
  238. low=np.array([0.0], dtype=np.float32), # 单一维度,只有TMP0
  239. high=np.array([1.0], dtype=np.float32), # 单一维度,只有TMP0
  240. dtype=np.float32,
  241. shape=(1,) # 明确指定形状为1维
  242. )
  243. # 初始化状态
  244. self.reset(seed=None) # 重置环境
  245. def _get_obs(self):
  246. # 原始状态
  247. TMP0 = self.current_params.TMP0
  248. # 状态归一化
  249. TMP0_norm = (TMP0 - 0.01) / (0.05 - 0.01)
  250. return np.array([TMP0_norm], dtype=np.float32) # 状态
  251. def _get_action_values(self, action):
  252. """解析离散动作"""
  253. L_idx = action // self.num_bw # 过滤时长索引
  254. t_bw_idx = action % self.num_bw # 物洗时长索引
  255. return self.L_values[L_idx], self.t_bw_values[t_bw_idx] # 动作
  256. def reset(self, seed=None, options=None):
  257. """重置环境"""
  258. super().reset(seed=seed)
  259. # 随机初始化 TMP0
  260. self.current_params.TMP0 = np.random.uniform(0.01, 0.05)
  261. # 初始化步数
  262. self.current_step = 0
  263. return self._get_obs(), {} # Gymnasium要求返回(obs, info)
  264. def step(self, action):
  265. """执行一个超级周期"""
  266. self.current_step += 1
  267. # 解析动作 对应过滤时长和物洗时长
  268. L_s, t_bw_s = self._get_action_values(action)
  269. # 确保过滤时长和物洗时长在范围内 np.clip:限制在范围内
  270. L_s = np.clip(L_s, self.base_params.L_min_s, self.base_params.L_max_s)
  271. t_bw_s = np.clip(t_bw_s, self.base_params.t_bw_min_s, self.base_params.t_bw_max_s)
  272. # 记录当前状态 归一化状态
  273. current_obs = self._get_obs()
  274. # 模拟超级周期
  275. feasible, info = simulate_one_supercycle(self.current_params, L_s, t_bw_s)
  276. # 计算奖励
  277. if feasible:
  278. reward = _score(self.current_params, info)
  279. self.current_params.TMP0 = info["TMP_after_ceb"]
  280. terminated = False
  281. else:
  282. reward = -20
  283. terminated = True
  284. # 检查是否达到最大步数
  285. truncated = self.current_step >= self.max_episode_steps
  286. # 获取新状态
  287. next_obs = self._get_obs()
  288. info["feasible"] = feasible
  289. info["step"] = self.current_step
  290. return next_obs, reward, terminated, truncated, info
  291. class UFEpisodeRecorder:
  292. """记录episode中的决策和结果"""
  293. def __init__(self):
  294. self.episode_data = [] # 记录episode中的决策和结果
  295. self.current_episode = []
  296. def record_step(self, obs, action, reward, done, info):
  297. """记录一步"""
  298. step_data = {
  299. "obs": obs.copy(), # 新状态
  300. "action": action.copy(), # 动作
  301. "reward": reward, # 奖励
  302. "done": done, # 是否终止
  303. "info": info.copy() if info else {} # 信息
  304. }
  305. self.current_episode.append(step_data) # 记录episode中的决策和结果
  306. if done:
  307. self.episode_data.append(self.current_episode) # 记录episode中的决策和结果
  308. self.current_episode = []
  309. def get_episode_stats(self, episode_idx=-1):
  310. """获取episode统计信息"""
  311. if not self.episode_data:
  312. return {}
  313. episode = self.episode_data[episode_idx] # 记录episode中的决策和结果
  314. total_reward = sum(step["reward"] for step in episode) # 总奖励
  315. avg_recovery = np.mean([step["info"].get("recovery", 0) for step in episode if "recovery" in step["info"]]) # 平均回收率
  316. feasible_steps = sum(1 for step in episode if step["info"].get("feasible", False)) # 可行的步数
  317. return {
  318. "total_reward": total_reward, # 总奖励
  319. "avg_recovery": avg_recovery, # 平均回收率
  320. "feasible_steps": feasible_steps, # 可行的步数
  321. "total_steps": len(episode) # 总步数
  322. }
  323. class UFTrainingCallback(BaseCallback):
  324. """
  325. PPO 训练回调,用于记录每一步的数据到 recorder。
  326. 相比原来的 RecordingCallback,更加合理和健壮:
  327. 1. 不依赖环境内部 last_* 属性
  328. 2. 使用 PPO 提供的 obs、actions、rewards、dones、infos
  329. 3. 自动处理 episode 结束时的统计
  330. """
  331. def __init__(self, recorder, verbose=0):
  332. super(UFTrainingCallback, self).__init__(verbose)
  333. self.recorder = recorder
  334. def _on_step(self) -> bool:
  335. try:
  336. new_obs = self.locals.get("new_obs") # 新状态
  337. actions = self.locals.get("actions") # 动作
  338. rewards = self.locals.get("rewards") # 奖励
  339. dones = self.locals.get("dones") # 是否终止
  340. infos = self.locals.get("infos") # 信息
  341. if len(new_obs) > 0:
  342. step_obs = new_obs[0] # 新状态
  343. step_action = actions[0] if actions is not None else None # 动作
  344. step_reward = rewards[0] if rewards is not None else 0.0 # 奖励
  345. step_done = dones[0] if dones is not None else False # 是否终止
  346. step_info = infos[0] if infos is not None else {} # 信息
  347. # 打印当前 step 的信息
  348. if self.verbose:
  349. print(f"[Step {self.num_timesteps}] 动作={step_action}, 奖励={step_reward:.3f}, Done={step_done}")
  350. # 记录数据
  351. self.recorder.record_step(
  352. obs=step_obs, # 新状态
  353. action=step_action, # 动作
  354. reward=step_reward, # 奖励
  355. done=step_done, # 是否终止
  356. info=step_info, # 信息
  357. )
  358. except Exception as e:
  359. if self.verbose:
  360. print(f"[Callback Error] {e}")
  361. return True
  362. class DQNTrainer:
  363. def __init__(self, env, params, callback=None):
  364. """
  365. 初始化 DQN 训练器
  366. :param env: 强化学习环境
  367. :param params: DQNParams 实例
  368. :param callback: 可选,训练回调器
  369. """
  370. self.env = env # 环境
  371. self.params = params # DQNParams 实例
  372. self.callback = callback # 训练回调器
  373. self.log_dir = self._create_log_dir() # 日志文件夹
  374. self.model = self._create_model() # 模型
  375. def _create_log_dir(self):
  376. """
  377. 自动生成日志文件夹名:包含核心超参数 + 时间戳
  378. """
  379. timestamp = time.strftime("%Y%m%d-%H%M%S") # 时间戳
  380. log_name = (
  381. f"DQN_lr{self.params.learning_rate}_buf{self.params.buffer_size}_bs{self.params.batch_size}" # 日志文件夹名
  382. f"_gamma{self.params.gamma}_exp{self.params.exploration_fraction}" # 日志文件夹名
  383. f"_{self.params.remark}_{timestamp}" # 日志文件夹名
  384. )
  385. log_dir = os.path.join("./uf_dqn_tensorboard", log_name) # 日志文件夹
  386. os.makedirs(log_dir, exist_ok=True) # 创建日志文件夹
  387. return log_dir
  388. def _create_model(self):
  389. """
  390. 根据参数创建 DQN 模型
  391. """
  392. return DQN(
  393. policy="MlpPolicy", # 策略网络
  394. env=self.env, # 环境
  395. learning_rate=self.params.learning_rate, # 学习率
  396. buffer_size=self.params.buffer_size, # 经验回放缓冲区大小
  397. learning_starts=self.params.learning_starts, # 学习开始前需要收集的步数
  398. batch_size=self.params.batch_size, # 每次从经验池中采样的样本数量
  399. gamma=self.params.gamma, # 折扣因子,越接近1越重视长期奖励
  400. train_freq=self.params.train_freq, # 每隔多少步训练一次
  401. target_update_interval=self.params.target_update_interval, # 目标网络更新间隔
  402. exploration_initial_eps=self.params.exploration_initial_eps, # 初始探索率 ε
  403. exploration_fraction=self.params.exploration_fraction, # 从初始ε衰减到最终ε所占的训练比例
  404. exploration_final_eps=self.params.exploration_final_eps, # 最终探索率 ε
  405. verbose=1,
  406. tensorboard_log=self.log_dir
  407. )
  408. def train(self, total_timesteps: int):
  409. """
  410. 训练 DQN 模型,支持自定义回调器
  411. """
  412. if self.callback:
  413. self.model.learn(total_timesteps=total_timesteps, callback=self.callback) # 支持自定义回调器
  414. else:
  415. self.model.learn(total_timesteps=total_timesteps) # 不支持自定义回调器
  416. print(f"模型训练完成,日志保存在:{self.log_dir}")
  417. def save(self, path=None):
  418. """
  419. 保存模型到指定路径
  420. """
  421. if path is None:
  422. path = os.path.join(self.log_dir, "dqn_model.zip") # 模型文件名
  423. self.model.save(path)
  424. print(f"模型已保存到:{path}")
  425. def load(self, path):
  426. """
  427. 从指定路径加载模型
  428. """
  429. self.model = DQN.load(path, env=self.env) # 加载模型
  430. print(f"模型已从 {path} 加载")
  431. def train_uf_rl_agent(params: UFParams, total_timesteps: int = 10000, seed: int = 2025):
  432. """训练超滤系统RL代理(固定随机种子)"""
  433. # === 1. 固定全局随机种子 ===
  434. set_global_seed(seed)
  435. # === 2. 创建回调器 ===
  436. recorder = UFEpisodeRecorder() # 记录每一步的数据
  437. callback = UFTrainingCallback(recorder, verbose=1) # 训练回调器
  438. # === 3. 创建环境并固定种子 ===
  439. def make_env():
  440. env = UFSuperCycleEnv(params) # 创建环境
  441. env = Monitor(env) # 监控环境
  442. return env
  443. env = DummyVecEnv([make_env]) # 创建环境 多进程
  444. # === 4. 定义DQN参数 ===
  445. dqn_params = DQNParams()
  446. # === 5. 创建训练器 ===
  447. trainer = DQNTrainer(env, dqn_params, callback=callback)
  448. # === 6. 训练模型 ===
  449. trainer.train(total_timesteps)
  450. # === 7. 保存模型 ===
  451. trainer.save()
  452. # === 8. 输出训练统计信息 ===
  453. stats = callback.recorder.get_episode_stats()
  454. print(f"训练完成 - 总奖励: {stats.get('total_reward', 0):.2f}, 平均回收率: {stats.get('avg_recovery', 0):.3f}")
  455. return trainer.model
  456. # 训练和测试示例
  457. if __name__ == "__main__":
  458. # 初始化参数
  459. params = UFParams()
  460. # 训练RL代理
  461. print("开始训练RL代理...")
  462. train_uf_rl_agent(params, total_timesteps=8000)