dqn_params.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from dataclasses import dataclass
  2. # ==================== DQN超参数配置类 ====================
  3. @dataclass
  4. class DQNParams:
  5. """
  6. DQN 超参数配置类
  7. 功能:统一管理DQN算法的所有超参数
  8. 超参数说明:
  9. - learning_rate: 神经网络学习率,控制梯度下降的步长
  10. - buffer_size: 经验回放缓冲区大小,存储历史经验
  11. - learning_starts: 开始训练前先收集的经验数量(warm-up)
  12. - batch_size: 每次训练采样的batch大小
  13. - gamma: 折扣因子,权衡即时奖励和长期奖励
  14. - train_freq: 训练频率,每隔多少步训练一次
  15. - target_update_interval: 目标网络更新频率
  16. - tau: 软更新系数(soft update)
  17. - exploration_*: ε-贪心策略的探索率参数
  18. """
  19. # ========== 神经网络参数 ==========
  20. learning_rate: float = 1e-4
  21. # 学习率,控制神经网络权重更新的步长
  22. # 典型范围:1e-5 ~ 1e-3
  23. # 过大:训练不稳定;过小:收敛慢
  24. # ========== 经验回放参数 ==========
  25. buffer_size: int = 100000
  26. # 经验回放缓冲区大小(可存储的transition数量)
  27. # 作用:打破样本间的时间相关性,提高训练稳定性
  28. # 建议:至少存储几个完整episode的经验
  29. learning_starts: int = 10000
  30. # 开始训练前先收集的步数(预填充缓冲区)
  31. # 作用:确保缓冲区有足够的多样性样本再开始训练
  32. # 建议:设为buffer_size的10%-20%
  33. batch_size: int = 32
  34. # 每次训练从缓冲区采样的样本数量
  35. # 典型值:32, 64, 128, 256
  36. # 过大:显存占用高,训练慢;过小:梯度估计不准确
  37. # ========== 强化学习参数 ==========
  38. gamma: float = 0.95
  39. # 折扣因子(discount factor),γ ∈ [0, 1]
  40. # 作用:权衡即时奖励和长期奖励
  41. # γ=0:只考虑当前奖励(短视)
  42. # γ=1:完全考虑未来奖励(长视)
  43. # 通常设为0.9-0.99
  44. train_freq: int = 4
  45. # 训练频率:每收集多少步执行一次训练
  46. # 作用:平衡数据收集和网络更新
  47. # 典型值:1(每步训练)或4-16(批量训练)
  48. # ========== 目标网络参数 ==========
  49. target_update_interval: int = 1
  50. # 目标网络更新间隔(硬更新)
  51. # 作用:目标网络每隔多少次训练更新一次
  52. # 注:使用软更新(tau)时此参数通常设为1
  53. tau: float = 0.005
  54. # 软更新系数(soft update)
  55. # θ_target = τ×θ + (1-τ)×θ_target
  56. # τ=1:硬更新(完全复制)
  57. # τ<<1:软更新(平滑过渡,更稳定)
  58. # 典型值:0.001 - 0.01
  59. # ========== 探索策略参数(ε-greedy) ==========
  60. exploration_initial_eps: float = 1.0
  61. # 初始探索率 ε_0
  62. # ε=1:完全随机探索
  63. # ε=0:完全利用已学知识
  64. exploration_fraction: float = 0.3
  65. # 探索率衰减比例
  66. # 表示训练总步数的前30%进行ε衰减
  67. # 例:总共10万步,前3万步ε从1.0衰减到0.02
  68. exploration_final_eps: float = 0.02
  69. # 最终探索率 ε_final
  70. # 衰减结束后保持此值(保留小概率探索)
  71. # 典型值:0.01 - 0.05
  72. # ========== 日志参数 ==========
  73. remark: str = "default"
  74. # 实验备注,用于区分不同训练实验
  75. # 会自动添加到TensorBoard日志目录名中