import numpy as np import matplotlib.pyplot as plt # 训练进度 progress = np.linspace(0, 1, 200) # ----------------------------- # 参数设置(可调) # ----------------------------- alpha = 0.5 # 虚拟工况最终最大占比 k = 10.0 # Sigmoid 陡峭程度 p0 = 0.7 # 虚拟工况启动拐点 beta = 0.5 # 扰动工况线性增长系数 # ----------------------------- # 权重定义 # ----------------------------- # 虚拟工况(非线性,后期快速增长) w_virtual = alpha / (1.0 + np.exp(-k * (progress - p0))) # 扰动工况(线性增长) w_perturb = beta * progress # 真实工况(剩余比例) w_real = 1.0 - w_virtual - w_perturb w_real = np.clip(w_real, 0.0, 1.0) # 数值安全 # ----------------------------- # 扰动幅度 # ----------------------------- perturb_scale = 0.02 + 0.04 * progress # ----------------------------- # 绘图 # ----------------------------- fig, ax1 = plt.subplots(figsize=(8, 5)) ax1.plot(progress, w_real, label="w_real", linewidth=2) ax1.plot(progress, w_perturb, label="w_perturb", linewidth=2) ax1.plot(progress, w_virtual, label="w_virtual", linewidth=2) ax1.set_xlabel("Training Progress") ax1.set_ylabel("Sampling Weights") ax1.set_ylim(0, 1.05) ax1.grid(True, linestyle="--", alpha=0.5) ax1.legend(loc="upper left") # 第二纵轴:扰动幅度 ax2 = ax1.twinx() ax2.plot(progress, perturb_scale, label="perturb_scale", linestyle="--", linewidth=2) ax2.set_ylabel("Perturb Scale") ax2.set_ylim(0, 0.07) ax2.legend(loc="upper right") plt.title("Progressive Reset Sampling Strategy") plt.tight_layout() plt.show()