| 123456789101112131415161718192021222324252627282930313233 |
- import torch
- import numpy as np
# TMP increase model
class TMPIncreaseModel(torch.nn.Module):
    """Closed-form model of the TMP increase.

    The module registers no parameters or buffers; every coefficient is
    read from the ``p`` object passed to :meth:`forward`. The default
    ``torch.nn.Module`` constructor is therefore sufficient.
    """

    def forward(self, p, L_h: float) -> float:
        """Return ``p.alpha * p.q_UF ** p.belta * L_h`` as a Python float.

        Args:
            p: Object exposing the numeric attributes ``alpha``, ``q_UF``
                and ``belta``. The attribute is spelled ``belta`` by the
                caller-side object (presumably "beta" -- confirm before
                renaming anything).
            L_h: Multiplier applied to the power-law term (presumably a
                duration in hours, judging by the name -- TODO confirm).

        Returns:
            The computed increase converted with ``float()``.
        """
        power_term = p.q_UF ** p.belta
        return float(p.alpha * power_term * L_h)
# Backwash TMP removal model
class TMPDecreaseModel(torch.nn.Module):
    """Closed-form model of the TMP fraction removed by a backwash.

    The module registers no parameters or buffers; every coefficient is
    read from the ``p`` object passed to :meth:`forward`. The default
    ``torch.nn.Module`` constructor is therefore sufficient.
    """

    def forward(self, p, L_s: float, t_bw_s: float) -> float:
        """Return the removal fraction ``phi``, clipped to ``[0, 0.999]``.

        The result is the product of two factors:

        * an upper bound that decays exponentially with ``L_s`` from
          ``p.phi_bw_max`` towards ``p.phi_bw_min`` on the scale
          ``p.L_ref_s``;
        * a time gain ``1 - exp(-(t / p.tau_bw_s) ** p.gamma_t)`` that
          grows with ``t_bw_s``.

        Args:
            p: Object exposing ``phi_bw_min``, ``phi_bw_max``, ``L_ref_s``,
                ``tau_bw_s`` and ``gamma_t``.
            L_s: Value driving the decay of the upper bound; values below
                1.0 are raised to 1.0 (presumably seconds -- TODO confirm).
            t_bw_s: Backwash time; values below 1e-6 are raised to 1e-6
                (presumably seconds -- TODO confirm).

        Returns:
            ``phi`` as a Python float in ``[0.0, 0.999]``.
        """
        # Floor both inputs so zero or negative values cannot reach the
        # exponentials below.
        length = max(float(L_s), 1.0)
        duration = max(float(t_bw_s), 1e-6)

        span = p.phi_bw_max - p.phi_bw_min
        upper_bound = p.phi_bw_min + span * np.exp(-length / p.L_ref_s)
        time_gain = 1.0 - np.exp(-((duration / p.tau_bw_s) ** p.gamma_t))

        # The cap at 0.999 keeps the returned fraction strictly below 1.
        return float(np.clip(upper_bound * time_gain, 0.0, 0.999))
if __name__ == "__main__":
    # Instantiate both models and write their state dicts to the current
    # working directory.
    # NOTE(review): neither model class visible in this file registers
    # parameters or buffers, so the saved state dicts look empty -- confirm
    # that downstream loaders only need the files to exist.
    model_fp = TMPIncreaseModel()
    model_bw = TMPDecreaseModel()
    torch.save(model_fp.state_dict(), "uf_fp.pth")
    torch.save(model_bw.state_dict(), "uf_bw.pth")
    print("模型已安全保存为 uf_fp.pth、uf_bw.pth")
|