| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import numpy as np
- import matplotlib.pyplot as plt
- # 训练进度
- progress = np.linspace(0, 1, 200)
- # -----------------------------
- # 参数设置(可调)
- # -----------------------------
- alpha = 0.5 # 虚拟工况最终最大占比
- k = 10.0 # Sigmoid 陡峭程度
- p0 = 0.7 # 虚拟工况启动拐点
- beta = 0.5 # 扰动工况线性增长系数
- # -----------------------------
- # 权重定义
- # -----------------------------
- # 虚拟工况(非线性,后期快速增长)
- w_virtual = alpha / (1.0 + np.exp(-k * (progress - p0)))
- # 扰动工况(线性增长)
- w_perturb = beta * progress
- # 真实工况(剩余比例)
- w_real = 1.0 - w_virtual - w_perturb
- w_real = np.clip(w_real, 0.0, 1.0) # 数值安全
- # -----------------------------
- # 扰动幅度
- # -----------------------------
- perturb_scale = 0.02 + 0.04 * progress
- # -----------------------------
- # 绘图
- # -----------------------------
- fig, ax1 = plt.subplots(figsize=(8, 5))
- ax1.plot(progress, w_real, label="w_real", linewidth=2)
- ax1.plot(progress, w_perturb, label="w_perturb", linewidth=2)
- ax1.plot(progress, w_virtual, label="w_virtual", linewidth=2)
- ax1.set_xlabel("Training Progress")
- ax1.set_ylabel("Sampling Weights")
- ax1.set_ylim(0, 1.05)
- ax1.grid(True, linestyle="--", alpha=0.5)
- ax1.legend(loc="upper left")
- # 第二纵轴:扰动幅度
- ax2 = ax1.twinx()
- ax2.plot(progress, perturb_scale, label="perturb_scale",
- linestyle="--", linewidth=2)
- ax2.set_ylabel("Perturb Scale")
- ax2.set_ylim(0, 0.07)
- ax2.legend(loc="upper right")
- plt.title("Progressive Reset Sampling Strategy")
- plt.tight_layout()
- plt.show()
|