uf_resistance_models_load.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import os
  2. import torch
  3. from pathlib import Path
  4. from uf_train.env.uf_resistance_models_define import ResistanceDecreaseModel, ResistanceIncreaseModel
  5. # ==================== 膜阻力模型加载函数 ====================
  6. def load_resistance_models(phys):
  7. """
  8. 加载膜阻力预测模型(单例模式)
  9. 功能:
  10. - 加载预训练的膜阻力上升模型和下降模型
  11. - 使用全局变量实现单例模式,避免重复加载
  12. - 仅在首次调用时执行加载操作
  13. 返回:
  14. tuple: (resistance_model_fp, resistance_model_bw)
  15. - resistance_model_fp: 过滤阶段阻力上升模型
  16. - resistance_model_bw: 反洗阶段阻力下降模型
  17. 注意:
  18. - 模型文件必须与本脚本位于同一目录
  19. - 模型已设置为推理模式(eval),不会更新参数
  20. """
  21. # 声明全局变量(实现单例模式)
  22. global resistance_model_fp, resistance_model_bw
  23. # 检查模型是否已加载(避免重复加载)
  24. if "resistance_model_fp" in globals() and resistance_model_fp is not None:
  25. return resistance_model_fp, resistance_model_bw
  26. print("🔄 正在加载膜阻力模型...")
  27. # 初始化模型对象
  28. resistance_model_fp = ResistanceIncreaseModel(phys)
  29. resistance_model_bw = ResistanceDecreaseModel(phys)
  30. # 获取当前脚本所在目录
  31. base_dir = Path(__file__).resolve().parent
  32. # 构造模型文件路径
  33. fp_path = base_dir / "resistance_model_fp.pth" # 过滤阶段模型
  34. bw_path = base_dir / "resistance_model_bw.pth" # 反洗阶段模型
  35. # 检查模型文件是否存在
  36. assert fp_path.exists(), f"缺少膜阻力上升模型文件: {fp_path.name}"
  37. assert bw_path.exists(), f"缺少膜阻力下降模型文件: {bw_path.name}"
  38. # 加载模型权重(map_location="cpu" 确保在没有GPU的环境也能运行)
  39. resistance_model_fp.load_state_dict(torch.load(fp_path, map_location="cpu"))
  40. resistance_model_bw.load_state_dict(torch.load(bw_path, map_location="cpu"))
  41. # 设置为推理模式(禁用 dropout、batchnorm 等训练特性)
  42. resistance_model_fp.eval()
  43. resistance_model_bw.eval()
  44. print("✅ 膜阻力模型加载成功!")
  45. return resistance_model_fp, resistance_model_bw