import os import torch from pathlib import Path from uf_train.env.uf_resistance_models_define import ResistanceDecreaseModel, ResistanceIncreaseModel # ==================== 膜阻力模型加载函数 ==================== def load_resistance_models(phys): """ 加载膜阻力预测模型(单例模式) 功能: - 加载预训练的膜阻力上升模型和下降模型 - 使用全局变量实现单例模式,避免重复加载 - 仅在首次调用时执行加载操作 返回: tuple: (resistance_model_fp, resistance_model_bw) - resistance_model_fp: 过滤阶段阻力上升模型 - resistance_model_bw: 反洗阶段阻力下降模型 注意: - 模型文件必须与本脚本位于同一目录 - 模型已设置为推理模式(eval),不会更新参数 """ # 声明全局变量(实现单例模式) global resistance_model_fp, resistance_model_bw # 检查模型是否已加载(避免重复加载) if "resistance_model_fp" in globals() and resistance_model_fp is not None: return resistance_model_fp, resistance_model_bw print("🔄 正在加载膜阻力模型...") # 初始化模型对象 resistance_model_fp = ResistanceIncreaseModel(phys) resistance_model_bw = ResistanceDecreaseModel(phys) # 获取当前脚本所在目录 base_dir = Path(__file__).resolve().parent # 构造模型文件路径 fp_path = base_dir / "resistance_model_fp.pth" # 过滤阶段模型 bw_path = base_dir / "resistance_model_bw.pth" # 反洗阶段模型 # 检查模型文件是否存在 assert fp_path.exists(), f"缺少膜阻力上升模型文件: {fp_path.name}" assert bw_path.exists(), f"缺少膜阻力下降模型文件: {bw_path.name}" # 加载模型权重(map_location="cpu" 确保在没有GPU的环境也能运行) resistance_model_fp.load_state_dict(torch.load(fp_path, map_location="cpu")) resistance_model_bw.load_state_dict(torch.load(bw_path, map_location="cpu")) # 设置为推理模式(禁用 dropout、batchnorm 等训练特性) resistance_model_fp.eval() resistance_model_bw.eval() print("✅ 膜阻力模型加载成功!") return resistance_model_fp, resistance_model_bw