import torch
import numpy as np


class TMPIncreaseModel(torch.nn.Module):
    """TMP rise model for a filtration period.

    Evaluates ``alpha * q_UF ** belta * L_h`` using attributes read from the
    parameter object ``p``. The module registers no parameters or buffers; it
    only wraps the formula in a ``torch.nn.Module`` interface.

    NOTE(review): "TMP" presumably means transmembrane pressure and ``q_UF``
    an ultrafiltration flux — confirm against the caller.
    """

    def forward(self, p, L_h) -> float:
        """Return the TMP increase as a Python float.

        Args:
            p: Parameter object exposing ``alpha``, ``q_UF`` and ``belta``
                (attribute name spelled ``belta`` on the object — presumably
                "beta"; kept as-is because it is read from outside).
            L_h: Filtration length; the ``_h`` suffix suggests hours —
                TODO confirm.
        """
        flux_term = p.q_UF ** p.belta
        rise = p.alpha * flux_term * L_h
        return float(rise)


class TMPDecreaseModel(torch.nn.Module):
    """Backwash TMP removal model.

    Produces a removal fraction ``phi`` as the product of
    * a length-dependent ceiling,
      ``phi_bw_min + (phi_bw_max - phi_bw_min) * exp(-L / L_ref_s)``, and
    * a backwash-time gain, ``1 - exp(-(t / tau_bw_s) ** gamma_t)``,
    then clips the result to ``[0.0, 0.999]`` and returns it as a float.
    The module registers no parameters or buffers.
    """

    def forward(self, p, L_s, t_bw_s) -> float:
        """Return the clipped backwash removal fraction.

        Args:
            p: Parameter object exposing ``phi_bw_min``, ``phi_bw_max``,
                ``L_ref_s``, ``tau_bw_s`` and ``gamma_t``.
            L_s: Preceding filtration length (``_s`` suffix suggests
                seconds — TODO confirm); floored at 1.0.
            t_bw_s: Backwash duration; floored at 1e-6 so the time term
                stays strictly positive.
        """
        length = max(float(L_s), 1.0)
        duration = max(float(t_bw_s), 1e-6)

        phi_lo = p.phi_bw_min
        phi_hi = p.phi_bw_max
        # For positive L_ref_s the ceiling moves from phi_bw_max toward
        # phi_bw_min as the filtration length grows.
        ceiling = phi_lo + (phi_hi - phi_lo) * np.exp(-length / p.L_ref_s)

        # Stretched-exponential saturation in backwash time; for positive
        # tau_bw_s it lies strictly between 0 and 1.
        saturation = 1.0 - np.exp(-((duration / p.tau_bw_s) ** p.gamma_t))

        # Upper bound 0.999 keeps the removal fraction just below 1.
        return float(np.clip(ceiling * saturation, 0.0, 0.999))


if __name__ == "__main__":
    # NOTE(review): neither model registers parameters or buffers, so both
    # state dicts written here are empty — the files carry no model data.
    targets = ((TMPIncreaseModel(), "uf_fp.pth"), (TMPDecreaseModel(), "uf_bw.pth"))
    for module, path in targets:
        torch.save(module.state_dict(), path)
    print("模型已安全保存为 uf_fp.pth、uf_bw.pth")