reset_plot.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. # 训练进度
  4. progress = np.linspace(0, 1, 200)
  5. # -----------------------------
  6. # 参数设置(可调)
  7. # -----------------------------
  8. alpha = 0.5 # 虚拟工况最终最大占比
  9. k = 10.0 # Sigmoid 陡峭程度
  10. p0 = 0.7 # 虚拟工况启动拐点
  11. beta = 0.5 # 扰动工况线性增长系数
  12. # -----------------------------
  13. # 权重定义
  14. # -----------------------------
  15. # 虚拟工况(非线性,后期快速增长)
  16. w_virtual = alpha / (1.0 + np.exp(-k * (progress - p0)))
  17. # 扰动工况(线性增长)
  18. w_perturb = beta * progress
  19. # 真实工况(剩余比例)
  20. w_real = 1.0 - w_virtual - w_perturb
  21. w_real = np.clip(w_real, 0.0, 1.0) # 数值安全
  22. # -----------------------------
  23. # 扰动幅度
  24. # -----------------------------
  25. perturb_scale = 0.02 + 0.04 * progress
  26. # -----------------------------
  27. # 绘图
  28. # -----------------------------
  29. fig, ax1 = plt.subplots(figsize=(8, 5))
  30. ax1.plot(progress, w_real, label="w_real", linewidth=2)
  31. ax1.plot(progress, w_perturb, label="w_perturb", linewidth=2)
  32. ax1.plot(progress, w_virtual, label="w_virtual", linewidth=2)
  33. ax1.set_xlabel("Training Progress")
  34. ax1.set_ylabel("Sampling Weights")
  35. ax1.set_ylim(0, 1.05)
  36. ax1.grid(True, linestyle="--", alpha=0.5)
  37. ax1.legend(loc="upper left")
  38. # 第二纵轴:扰动幅度
  39. ax2 = ax1.twinx()
  40. ax2.plot(progress, perturb_scale, label="perturb_scale",
  41. linestyle="--", linewidth=2)
  42. ax2.set_ylabel("Perturb Scale")
  43. ax2.set_ylim(0, 0.07)
  44. ax2.legend(loc="upper right")
  45. plt.title("Progressive Reset Sampling Strategy")
  46. plt.tight_layout()
  47. plt.show()