| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- """
- dqn_config_loader.py
- DQN 配置加载器,负责从YAML文件加载DQN超参数。
- 与dqn_params.py同级,保持参数类的纯净性。
- """
- import yaml
- from pathlib import Path
- from typing import Dict, Any, Union, Optional
- from uf_train.rl_model.DQN.dqn_params import DQNParams
class DQNConfigLoader:
    """Load DQN hyperparameters from a YAML file and build a ``DQNParams``.

    The YAML document is read lazily on first access to :attr:`config` and
    cached on the instance. Only keys listed in :attr:`DQN_FIELDS` are passed
    on to ``DQNParams``; every other key in the file is ignored.
    """

    # Keys that are forwarded to DQNParams (used to filter the raw config).
    DQN_FIELDS = [
        'learning_rate',
        'buffer_size',
        'learning_starts',
        'batch_size',
        'gamma',
        'train_freq',
        'target_update_interval',
        'tau',
        'exploration_initial_eps',
        'exploration_fraction',
        'exploration_final_eps',
        'remark'
    ]

    def __init__(self, config_path: Union[str, Path]):
        """Remember where the config lives; nothing is read from disk yet.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config_path = Path(config_path)
        # None means "not loaded yet". An empty file is cached as {} so it is
        # distinguishable from the unloaded state and is not re-read.
        self._config: Optional[Dict[str, Any]] = None

    @property
    def config(self) -> Dict[str, Any]:
        """Return the parsed configuration, loading it on first access."""
        if self._config is None:
            self._config = self._load_yaml()
        return self._config

    @staticmethod
    def _ensure_mapping(loaded: Any, source: object = "<config>") -> Dict[str, Any]:
        """Normalise a parsed YAML document into a dictionary.

        Args:
            loaded: The object produced by the YAML parser.
            source: Where the document came from; only used in the error text.

        Returns:
            ``{}`` when ``loaded`` is ``None`` (an empty document), otherwise
            ``loaded`` itself.

        Raises:
            ValueError: If the document's top level is not a mapping.
        """
        if loaded is None:
            return {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"DQN config must be a mapping at the top level, "
                f"got {type(loaded).__name__}: {source}"
            )
        return loaded

    def _load_yaml(self) -> Dict[str, Any]:
        """Read and parse the YAML file at ``self.config_path``.

        Returns:
            The parsed top-level mapping (``{}`` for an empty file).

        Raises:
            FileNotFoundError: If the config file does not exist.
            ValueError: If the file's top level is not a mapping.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"DQN config file not found: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
        # safe_load gives None for an empty document; without normalising,
        # later .items()/.keys() calls would fail with an AttributeError.
        return self._ensure_mapping(loaded, self.config_path)

    def _filter_dqn_params(self, raw_config: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only the entries whose key is listed in ``DQN_FIELDS``.

        Args:
            raw_config: The unfiltered configuration dictionary.

        Returns:
            A new dictionary restricted to the ``DQN_FIELDS`` keys.
        """
        return {k: v for k, v in raw_config.items() if k in self.DQN_FIELDS}

    def load_params(self) -> "DQNParams":
        """Build a ``DQNParams`` from the recognised keys of the YAML config.

        Returns:
            A ``DQNParams`` constructed with the filtered config as keyword
            arguments.
        """
        filtered_config = self._filter_dqn_params(self.config)
        return DQNParams(**filtered_config)

    def validate_config(self) -> bool:
        """Report unknown keys in the config file.

        Unknown keys only produce a printed warning; they never fail the
        validation.

        Returns:
            bool: Always ``True``.
        """
        # Warn about keys that will be ignored by load_params().
        unknown_params = set(self.config) - set(self.DQN_FIELDS)
        if unknown_params:
            print(f"⚠️ Warning: Unknown parameters in DQN config: {unknown_params}")
        # No required-key check is performed here.
        print("✅ DQN config loaded successfully")
        return True

    def print_config_summary(self):
        """Print every loaded hyperparameter as a human-readable summary."""
        params = self.load_params()
        print("\n" + "=" * 50)
        print("🤖 DQN 超参数配置(从YAML加载)")
        print("=" * 50)
        print(f"学习率 (learning_rate): {params.learning_rate}")
        print(f"缓冲区大小 (buffer_size): {params.buffer_size}")
        print(f"预热步数 (learning_starts): {params.learning_starts}")
        print(f"批次大小 (batch_size): {params.batch_size}")
        print(f"折扣因子 (gamma): {params.gamma}")
        print(f"训练频率 (train_freq): {params.train_freq}")
        print(f"目标网络更新间隔: {params.target_update_interval}")
        print(f"软更新系数 (tau): {params.tau}")
        print(f"初始探索率: {params.exploration_initial_eps}")
        print(f"探索率衰减比例: {params.exploration_fraction}")
        print(f"最终探索率: {params.exploration_final_eps}")
        print(f"实验备注: {params.remark}")
        print("=" * 50)
- # ========== 便捷函数 ==========
def load_dqn_config(config_path: Union[str, Path]) -> DQNParams:
    """Load DQN hyperparameters from a YAML file in a single call.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The ``DQNParams`` instance built by ``DQNConfigLoader.load_params``.
    """
    return DQNConfigLoader(config_path).load_params()
def load_dqn_config_with_validation(config_path: Union[str, Path]) -> DQNParams:
    """Validate a DQN YAML config, then load its hyperparameters.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The ``DQNParams`` instance built by ``DQNConfigLoader.load_params``.
    """
    config_loader = DQNConfigLoader(config_path)
    # Validation runs first; its boolean result is not inspected here.
    config_loader.validate_config()
    params = config_loader.load_params()
    return params
- # ========== 测试代码 ==========
if __name__ == "__main__":
    # Manual smoke test: load the default config file and print what was read.
    # The file is looked up in a "config" folder four directory levels above
    # this module. Path comes from the module-level import.
    default_config = (
        Path(__file__).parent.parent.parent.parent
        / "config"
        / "xishan_dqn_config.yaml"
    )

    if default_config.exists():
        print(f"Testing DQN config loading from: {default_config}")

        loader = DQNConfigLoader(default_config)

        # Report unknown keys before building the parameters.
        loader.validate_config()

        dqn_params = loader.load_params()

        loader.print_config_summary()

        print("\n✅ DQN config loaded successfully!")
        print(f" Learning rate: {dqn_params.learning_rate}")
        print(f" Buffer size: {dqn_params.buffer_size}")
        print(f" Remark: {dqn_params.remark}")
    else:
        print(f"Config file not found: {default_config}")
        print("Please create a DQN config file first.")
|