""" dqn_config_loader.py DQN 配置加载器,负责从YAML文件加载DQN超参数。 与dqn_params.py同级,保持参数类的纯净性。 """ import yaml from pathlib import Path from typing import Dict, Any, Union from rl_model.DQN.dqn_model.dqn_params import DQNParams class DQNConfigLoader: """DQN配置加载器,从YAML加载并创建DQNParams实例""" # DQNParams的字段列表(用于过滤) DQN_FIELDS = [ 'learning_rate', 'buffer_size', 'learning_starts', 'batch_size', 'gamma', 'train_freq', 'target_update_interval', 'tau', 'exploration_initial_eps', 'exploration_fraction', 'exploration_final_eps', 'remark' ] def __init__(self, config_path: Union[str, Path]): """ 初始化DQN配置加载器 Args: config_path: YAML配置文件路径 """ self.config_path = Path(config_path) self._config = None @property def config(self) -> Dict[str, Any]: """懒加载配置""" if self._config is None: self._config = self._load_yaml() return self._config def _load_yaml(self) -> Dict[str, Any]: """加载YAML文件""" if not self.config_path.exists(): raise FileNotFoundError(f"DQN config file not found: {self.config_path}") with open(self.config_path, 'r', encoding='utf-8') as f: return yaml.safe_load(f) def _filter_dqn_params(self, raw_config: Dict[str, Any]) -> Dict[str, Any]: """ 过滤出DQNParams需要的参数 Args: raw_config: 原始配置字典 Returns: 只包含DQNParams字段的字典 """ return {k: v for k, v in raw_config.items() if k in self.DQN_FIELDS} def load_params(self) -> DQNParams: """ 从YAML配置加载DQN参数 Returns: DQNParams实例 """ filtered_config = self._filter_dqn_params(self.config) return DQNParams(**filtered_config) def validate_config(self) -> bool: """ 验证配置文件 Returns: bool: 验证通过返回True """ # 检查是否有未知参数(可选) unknown_params = set(self.config.keys()) - set(self.DQN_FIELDS) if unknown_params: print(f"⚠️ Warning: Unknown parameters in DQN config: {unknown_params}") # 检查必需参数(所有参数都有默认值,所以都是可选的) print("✅ DQN config loaded successfully") return True def print_config_summary(self): """打印配置摘要""" params = self.load_params() print("\n" + "=" * 50) print("🤖 DQN 超参数配置(从YAML加载)") print("=" * 50) print(f"学习率 (learning_rate): {params.learning_rate}") print(f"缓冲区大小 (buffer_size): {params.buffer_size}") print(f"预热步数 (learning_starts): {params.learning_starts}") print(f"批次大小 (batch_size): {params.batch_size}") print(f"折扣因子 (gamma): {params.gamma}") print(f"训练频率 (train_freq): {params.train_freq}") print(f"目标网络更新间隔: {params.target_update_interval}") print(f"软更新系数 (tau): {params.tau}") print(f"初始探索率: {params.exploration_initial_eps}") print(f"探索率衰减比例: {params.exploration_fraction}") print(f"最终探索率: {params.exploration_final_eps}") print(f"实验备注: {params.remark}") print("=" * 50) # ========== 便捷函数 ========== def load_dqn_config(config_path: Union[str, Path]) -> DQNParams: """ 便捷函数:从YAML文件加载DQN配置 Args: config_path: 配置文件路径 Returns: DQNParams实例 """ loader = DQNConfigLoader(config_path) return loader.load_params() def load_dqn_config_with_validation(config_path: Union[str, Path]) -> DQNParams: """ 加载并验证DQN配置 Args: config_path: 配置文件路径 Returns: DQNParams实例 """ loader = DQNConfigLoader(config_path) loader.validate_config() return loader.load_params() # ========== 测试代码 ========== if __name__ == "__main__": # 测试配置加载 from pathlib import Path # 假设配置文件在项目根目录的config文件夹中 current_dir = Path(__file__).parent project_root = current_dir.parent.parent # uf-rl default_config = Path(__file__).parent.parent.parent.parent / "config" / "xishan_dqn_config.yaml" if default_config.exists(): print(f"Testing DQN config loading from: {default_config}") # 测试加载器 loader = DQNConfigLoader(default_config) # 验证配置 loader.validate_config() # 加载参数 dqn_params = loader.load_params() # 打印摘要 loader.print_config_summary() print("\n✅ DQN config loaded successfully!") print(f" Learning rate: {dqn_params.learning_rate}") print(f" Buffer size: {dqn_params.buffer_size}") print(f" Remark: {dqn_params.remark}") else: print(f"Config file not found: {default_config}") print("Please create a DQN config file first.")