| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import torch
- from pathlib import Path
- from env.uf_resistance_models_define import ResistanceDecreaseModel, ResistanceIncreaseModel
- # ==================== 膜阻力模型加载函数 ====================
- def load_resistance_models(phys):
- """
- 加载膜阻力预测模型(单例模式)
- 功能:
- - 加载预训练的膜阻力上升模型和下降模型
- - 使用全局变量实现单例模式,避免重复加载
- - 仅在首次调用时执行加载操作
- 返回:
- tuple: (resistance_model_fp, resistance_model_bw)
- - resistance_model_fp: 过滤阶段阻力上升模型
- - resistance_model_bw: 反洗阶段阻力下降模型
- 注意:
- - 模型文件必须与本脚本位于同一目录
- - 模型已设置为推理模式(eval),不会更新参数
- """
- # 声明全局变量(实现单例模式)
- global resistance_model_fp, resistance_model_bw
- # 检查模型是否已加载(避免重复加载)
- if "resistance_model_fp" in globals() and resistance_model_fp is not None:
- return resistance_model_fp, resistance_model_bw
- print("🔄 正在加载膜阻力模型...")
- # 初始化模型对象
- resistance_model_fp = ResistanceIncreaseModel(phys)
- resistance_model_bw = ResistanceDecreaseModel(phys)
- # 获取当前脚本所在目录
- base_dir = Path(__file__).resolve().parent
- # 构造模型文件路径
- fp_path = base_dir / "resistance_model_fp.pth" # 过滤阶段模型
- bw_path = base_dir / "resistance_model_bw.pth" # 反洗阶段模型
- # 检查模型文件是否存在
- assert fp_path.exists(), f"缺少膜阻力上升模型文件: {fp_path.name}"
- assert bw_path.exists(), f"缺少膜阻力下降模型文件: {bw_path.name}"
- # 加载模型权重(map_location="cpu" 确保在没有GPU的环境也能运行)
- resistance_model_fp.load_state_dict(torch.load(fp_path, map_location="cpu"))
- resistance_model_bw.load_state_dict(torch.load(bw_path, map_location="cpu"))
- # 设置为推理模式(禁用 dropout、batchnorm 等训练特性)
- resistance_model_fp.eval()
- resistance_model_bw.eval()
- print("✅ 膜阻力模型加载成功!")
- return resistance_model_fp, resistance_model_bw
|