UF_models.py 963 B

123456789101112131415161718192021222324252627282930313233
  1. import torch
  2. import numpy as np
  3. # TMP 上升量模型
  4. class TMPIncreaseModel(torch.nn.Module):
  5. def __init__(self):
  6. super().__init__()
  7. def forward(self, p, L_h):
  8. return float(p.alpha * (p.q_UF ** p.belta) * L_h)
  9. # 反洗 TMP 去除模型
  10. class TMPDecreaseModel(torch.nn.Module):
  11. def __init__(self):
  12. super().__init__()
  13. def forward(self, p, L_s, t_bw_s):
  14. L = max(float(L_s), 1.0)
  15. t = max(float(t_bw_s), 1e-6)
  16. upper_L = p.phi_bw_min + (p.phi_bw_max - p.phi_bw_min) * np.exp(- L / p.L_ref_s)
  17. time_gain = 1.0 - np.exp(- (t / p.tau_bw_s) ** p.gamma_t)
  18. phi = upper_L * time_gain
  19. return float(np.clip(phi, 0.0, 0.999))
  20. if __name__ == "__main__":
  21. model_fp = TMPIncreaseModel()
  22. model_bw = TMPDecreaseModel()
  23. torch.save(model_fp.state_dict(), "uf_fp.pth")
  24. torch.save(model_bw.state_dict(), "uf_bw.pth")
  25. print("模型已安全保存为 uf_fp.pth、uf_bw.pth")