dqn_config_loader.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. """
  2. dqn_config_loader.py
  3. DQN 配置加载器,负责从YAML文件加载DQN超参数。
  4. 与dqn_params.py同级,保持参数类的纯净性。
  5. """
  6. import yaml
  7. from pathlib import Path
  8. from typing import Dict, Any, Union, Optional
  9. from uf_train.rl_model.DQN.dqn_params import DQNParams
  10. class DQNConfigLoader:
  11. """DQN配置加载器,从YAML加载并创建DQNParams实例"""
  12. # DQNParams的字段列表(用于过滤)
  13. DQN_FIELDS = [
  14. 'learning_rate',
  15. 'buffer_size',
  16. 'learning_starts',
  17. 'batch_size',
  18. 'gamma',
  19. 'train_freq',
  20. 'target_update_interval',
  21. 'tau',
  22. 'exploration_initial_eps',
  23. 'exploration_fraction',
  24. 'exploration_final_eps',
  25. 'remark'
  26. ]
  27. def __init__(self, config_path: Union[str, Path]):
  28. """
  29. 初始化DQN配置加载器
  30. Args:
  31. config_path: YAML配置文件路径
  32. """
  33. self.config_path = Path(config_path)
  34. self._config = None
  35. @property
  36. def config(self) -> Dict[str, Any]:
  37. """懒加载配置"""
  38. if self._config is None:
  39. self._config = self._load_yaml()
  40. return self._config
  41. def _load_yaml(self) -> Dict[str, Any]:
  42. """加载YAML文件"""
  43. if not self.config_path.exists():
  44. raise FileNotFoundError(f"DQN config file not found: {self.config_path}")
  45. with open(self.config_path, 'r', encoding='utf-8') as f:
  46. return yaml.safe_load(f)
  47. def _filter_dqn_params(self, raw_config: Dict[str, Any]) -> Dict[str, Any]:
  48. """
  49. 过滤出DQNParams需要的参数
  50. Args:
  51. raw_config: 原始配置字典
  52. Returns:
  53. 只包含DQNParams字段的字典
  54. """
  55. return {k: v for k, v in raw_config.items() if k in self.DQN_FIELDS}
  56. def load_params(self) -> DQNParams:
  57. """
  58. 从YAML配置加载DQN参数
  59. Returns:
  60. DQNParams实例
  61. """
  62. filtered_config = self._filter_dqn_params(self.config)
  63. return DQNParams(**filtered_config)
  64. def validate_config(self) -> bool:
  65. """
  66. 验证配置文件
  67. Returns:
  68. bool: 验证通过返回True
  69. """
  70. # 检查是否有未知参数(可选)
  71. unknown_params = set(self.config.keys()) - set(self.DQN_FIELDS)
  72. if unknown_params:
  73. print(f"⚠️ Warning: Unknown parameters in DQN config: {unknown_params}")
  74. # 检查必需参数(所有参数都有默认值,所以都是可选的)
  75. print("✅ DQN config loaded successfully")
  76. return True
  77. def print_config_summary(self):
  78. """打印配置摘要"""
  79. params = self.load_params()
  80. print("\n" + "=" * 50)
  81. print("🤖 DQN 超参数配置(从YAML加载)")
  82. print("=" * 50)
  83. print(f"学习率 (learning_rate): {params.learning_rate}")
  84. print(f"缓冲区大小 (buffer_size): {params.buffer_size}")
  85. print(f"预热步数 (learning_starts): {params.learning_starts}")
  86. print(f"批次大小 (batch_size): {params.batch_size}")
  87. print(f"折扣因子 (gamma): {params.gamma}")
  88. print(f"训练频率 (train_freq): {params.train_freq}")
  89. print(f"目标网络更新间隔: {params.target_update_interval}")
  90. print(f"软更新系数 (tau): {params.tau}")
  91. print(f"初始探索率: {params.exploration_initial_eps}")
  92. print(f"探索率衰减比例: {params.exploration_fraction}")
  93. print(f"最终探索率: {params.exploration_final_eps}")
  94. print(f"实验备注: {params.remark}")
  95. print("=" * 50)
# ========== Convenience functions ==========
  97. def load_dqn_config(config_path: Union[str, Path]) -> DQNParams:
  98. """
  99. 便捷函数:从YAML文件加载DQN配置
  100. Args:
  101. config_path: 配置文件路径
  102. Returns:
  103. DQNParams实例
  104. """
  105. loader = DQNConfigLoader(config_path)
  106. return loader.load_params()
  107. def load_dqn_config_with_validation(config_path: Union[str, Path]) -> DQNParams:
  108. """
  109. 加载并验证DQN配置
  110. Args:
  111. config_path: 配置文件路径
  112. Returns:
  113. DQNParams实例
  114. """
  115. loader = DQNConfigLoader(config_path)
  116. loader.validate_config()
  117. return loader.load_params()
# ========== Test code ==========
  119. if __name__ == "__main__":
  120. # 测试配置加载
  121. import sys
  122. from pathlib import Path
  123. # 假设配置文件在项目根目录的config文件夹中
  124. current_dir = Path(__file__).parent
  125. project_root = current_dir.parent.parent # uf-rl
  126. default_config = Path(__file__).parent.parent.parent.parent / "config" / "xishan_dqn_config.yaml"
  127. if default_config.exists():
  128. print(f"Testing DQN config loading from: {default_config}")
  129. # 测试加载器
  130. loader = DQNConfigLoader(default_config)
  131. # 验证配置
  132. loader.validate_config()
  133. # 加载参数
  134. dqn_params = loader.load_params()
  135. # 打印摘要
  136. loader.print_config_summary()
  137. print("\n✅ DQN config loaded successfully!")
  138. print(f" Learning rate: {dqn_params.learning_rate}")
  139. print(f" Buffer size: {dqn_params.buffer_size}")
  140. print(f" Remark: {dqn_params.remark}")
  141. else:
  142. print(f"Config file not found: {default_config}")
  143. print("Please create a DQN config file first.")