uf_resistance_models_load.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import torch
  2. from pathlib import Path
  3. from env.uf_resistance_models_define import ResistanceDecreaseModel, ResistanceIncreaseModel
  4. # ==================== 膜阻力模型加载函数 ====================
  5. def load_resistance_models(phys):
  6. """
  7. 加载膜阻力预测模型(单例模式)
  8. 功能:
  9. - 加载预训练的膜阻力上升模型和下降模型
  10. - 使用全局变量实现单例模式,避免重复加载
  11. - 仅在首次调用时执行加载操作
  12. 返回:
  13. tuple: (resistance_model_fp, resistance_model_bw)
  14. - resistance_model_fp: 过滤阶段阻力上升模型
  15. - resistance_model_bw: 反洗阶段阻力下降模型
  16. 注意:
  17. - 模型文件必须与本脚本位于同一目录
  18. - 模型已设置为推理模式(eval),不会更新参数
  19. """
  20. # 声明全局变量(实现单例模式)
  21. global resistance_model_fp, resistance_model_bw
  22. # 检查模型是否已加载(避免重复加载)
  23. if "resistance_model_fp" in globals() and resistance_model_fp is not None:
  24. return resistance_model_fp, resistance_model_bw
  25. print("🔄 正在加载膜阻力模型...")
  26. # 初始化模型对象
  27. resistance_model_fp = ResistanceIncreaseModel(phys)
  28. resistance_model_bw = ResistanceDecreaseModel(phys)
  29. # 获取当前脚本所在目录
  30. base_dir = Path(__file__).resolve().parent
  31. # 构造模型文件路径
  32. fp_path = base_dir / "resistance_model_fp.pth" # 过滤阶段模型
  33. bw_path = base_dir / "resistance_model_bw.pth" # 反洗阶段模型
  34. # 检查模型文件是否存在
  35. assert fp_path.exists(), f"缺少膜阻力上升模型文件: {fp_path.name}"
  36. assert bw_path.exists(), f"缺少膜阻力下降模型文件: {bw_path.name}"
  37. # 加载模型权重(map_location="cpu" 确保在没有GPU的环境也能运行)
  38. resistance_model_fp.load_state_dict(torch.load(fp_path, map_location="cpu"))
  39. resistance_model_bw.load_state_dict(torch.load(bw_path, map_location="cpu"))
  40. # 设置为推理模式(禁用 dropout、batchnorm 等训练特性)
  41. resistance_model_fp.eval()
  42. resistance_model_bw.eval()
  43. print("✅ 膜阻力模型加载成功!")
  44. return resistance_model_fp, resistance_model_bw