| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- """
- env_config_loader.py
- 配置加载器,负责从YAML文件加载配置并实例化env_params中的参数类。
- """
- import yaml
- from pathlib import Path
- from typing import Dict, Any, Union, Optional, Type, TypeVar
- # 改为绝对导入
- from uf_train.env.env_params import (
- UFState,
- UFPhysicsParams,
- UFActionSpec,
- UFRewardParams,
- UFStateBounds
- )
# Generic type variable so `_create_instance` is typed as returning an
# instance of whichever class it is given.
T = TypeVar('T')


class EnvConfigLoader:
    """Load environment parameter classes from a YAML configuration file.

    Each top-level section of the YAML file (for example ``UFState``) is
    expected to be a mapping; its keys are passed as keyword arguments to the
    parameter class of the same name listed in :attr:`SECTION_TO_CLASS`.

    The YAML file is read lazily on first access of :attr:`config`, and the
    parsed result is cached on the instance.
    """

    # Mapping from YAML section name to the parameter class it instantiates.
    SECTION_TO_CLASS = {
        'UFState': UFState,
        'UFPhysicsParams': UFPhysicsParams,
        'UFActionSpec': UFActionSpec,
        'UFRewardParams': UFRewardParams,
        'UFStateBounds': UFStateBounds,
    }

    def __init__(self, config_path: Union[str, Path]):
        """Create a loader for ``config_path``.

        The file is not read here; it is loaded on first access of
        :attr:`config`.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config_path = Path(config_path)
        self._config: Optional[Dict[str, Any]] = None

    @property
    def config(self) -> Dict[str, Any]:
        """Parsed YAML configuration, loaded on first access and then cached."""
        if self._config is None:
            self._config = self._load_yaml()
        return self._config

    def _load_yaml(self) -> Dict[str, Any]:
        """Read and parse the YAML file.

        Returns:
            The top-level mapping of the file. An empty file yields ``{}``.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the top level of the file is not a mapping.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document. Normalise it so the
        # cache in `config` is populated and `in` checks keep working.
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Top level of config file {self.config_path} must be a mapping, "
                f"got {type(data).__name__}"
            )
        return data

    def _create_instance(self, cls: Type[T], section: str) -> T:
        """Instantiate ``cls`` from the keyword arguments stored in ``section``.

        A section that is present but empty (``null`` in YAML) is treated as
        having no arguments, so ``cls`` is called with no keyword arguments.

        Args:
            cls: Parameter class to instantiate.
            section: Name of the top-level YAML section holding its arguments.

        Returns:
            The new ``cls`` instance.

        Raises:
            KeyError: If ``section`` is not present in the config.
            TypeError: If the section is not a mapping, or ``cls`` rejects the
                supplied keyword arguments.
        """
        if section not in self.config:
            raise KeyError(f"Section '{section}' not found in config. "
                           f"Available sections: {list(self.config.keys())}")
        section_config = self.config[section]
        # `Section:` with no body parses to None; treat it as "no arguments".
        if section_config is None:
            section_config = {}
        if not isinstance(section_config, dict):
            raise TypeError(
                f"Section '{section}' must be a mapping of parameter names to values, "
                f"got {type(section_config).__name__}"
            )
        # Individual `null` values are already Python None after safe_load and
        # are passed through unchanged.
        try:
            return cls(**section_config)
        except TypeError as err:
            # Usually an unexpected or missing keyword; name the section so the
            # offending part of the YAML file is easy to find.
            raise TypeError(f"Invalid parameters in section '{section}': {err}") from err

    # ========== Parameter-class loaders ==========

    def load_uf_state(self) -> UFState:
        """Load the ``UFState`` section (dynamic environment state)."""
        return self._create_instance(UFState, 'UFState')

    def load_physics_params(self) -> UFPhysicsParams:
        """Load the ``UFPhysicsParams`` section (physical parameters)."""
        return self._create_instance(UFPhysicsParams, 'UFPhysicsParams')

    def load_action_spec(self) -> UFActionSpec:
        """Load the ``UFActionSpec`` section (action specification)."""
        return self._create_instance(UFActionSpec, 'UFActionSpec')

    def load_reward_params(self) -> UFRewardParams:
        """Load the ``UFRewardParams`` section (reward parameters)."""
        return self._create_instance(UFRewardParams, 'UFRewardParams')

    def load_state_bounds(self) -> UFStateBounds:
        """Load the ``UFStateBounds`` section (state bounds)."""
        return self._create_instance(UFStateBounds, 'UFStateBounds')

    def load_all(self) -> Dict[str, Any]:
        """Load every parameter class.

        Returns:
            Dict with keys ``uf_state``, ``physics``, ``action``, ``reward``
            and ``state_bounds`` mapping to the loaded instances.
        """
        return {
            'uf_state': self.load_uf_state(),
            'physics': self.load_physics_params(),
            'action': self.load_action_spec(),
            'reward': self.load_reward_params(),
            'state_bounds': self.load_state_bounds(),
        }

    # ========== Utility methods ==========

    def get_raw_section(self, section: str) -> Dict[str, Any]:
        """Return an independent copy of the raw YAML data for ``section``.

        A deep copy is returned, so mutating nested values in the result does
        not alter the cached configuration used by the loaders. Missing or
        empty sections yield ``{}``.
        """
        import copy

        raw = self.config.get(section)
        return {} if raw is None else copy.deepcopy(raw)

    def validate_config(self) -> bool:
        """Check that every section in :attr:`SECTION_TO_CLASS` is present.

        Prints a confirmation message on success.

        Returns:
            True when all required sections are present.

        Raises:
            ValueError: Listing the missing sections, if any.
        """
        missing_sections = [s for s in self.SECTION_TO_CLASS if s not in self.config]
        if missing_sections:
            raise ValueError(f"Missing required sections in config: {missing_sections}")
        print("✅ Config validation passed. All required sections present.")
        return True

    def print_config_summary(self):
        """Print, for each required section, its parameter count or MISSING."""
        print("\n" + "=" * 50)
        print("📋 配置加载摘要")
        print("=" * 50)
        for section in self.SECTION_TO_CLASS:
            if section in self.config:
                params = self.config[section]
                # An empty YAML section parses to None; report it as 0 parameters.
                count = 0 if params is None else len(params)
                print(f"✅ {section}: {count} parameters")
            else:
                print(f"❌ {section}: MISSING")
        print("=" * 50 + "\n")
- # ========== 便捷函数 ==========
def load_env_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load every environment parameter object from a YAML file in one call.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The dict produced by ``EnvConfigLoader.load_all`` (keys ``uf_state``,
        ``physics``, ``action``, ``reward``, ``state_bounds``).
    """
    return EnvConfigLoader(config_path).load_all()
def load_single_config(config_path: Union[str, Path], section: str) -> Any:
    """Load one configuration section and instantiate its parameter class.

    Args:
        config_path: Path to the YAML configuration file.
        section: Section name; one of the keys of
            ``EnvConfigLoader.SECTION_TO_CLASS`` (UFState, UFPhysicsParams,
            UFActionSpec, UFRewardParams, UFStateBounds).

    Returns:
        The parameter-class instance built from that section.

    Raises:
        ValueError: If ``section`` is not a known section name. The file is
            not read in this case.
    """
    # Reuse the loader's own section registry instead of a duplicated table,
    # so the accepted sections cannot drift from EnvConfigLoader.
    section_classes = EnvConfigLoader.SECTION_TO_CLASS
    if section not in section_classes:
        raise ValueError(f"Unknown section: {section}. "
                         f"Available: {list(section_classes.keys())}")
    loader = EnvConfigLoader(config_path)
    return loader._create_instance(section_classes[section], section)
def create_env_params_from_yaml(config_path: Union[str, Path]) -> tuple:
    """Build every environment parameter object from a YAML file, as a tuple.

    Returning a tuple lets callers unpack the results directly.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        ``(uf_state, physics_params, action_spec, reward_params, state_bounds)``
    """
    loader = EnvConfigLoader(config_path)
    # Loaders are invoked in this fixed order; the tuple mirrors it.
    builders = (
        loader.load_uf_state,
        loader.load_physics_params,
        loader.load_action_spec,
        loader.load_reward_params,
        loader.load_state_bounds,
    )
    return tuple(build() for build in builders)
- # ========== 测试 ==========
if __name__ == "__main__":
    # Manual smoke test: load the config and print a few values from each
    # parameter class.
    # Usage: python <this file> [path/to/config.yaml]
    import sys
    import traceback

    if len(sys.argv) > 1:
        config_file = Path(sys.argv[1])
    else:
        # Default: a `config` folder two directories above the directory that
        # contains this file.
        config_file = Path(__file__).parent.parent.parent / "config" / "xishan_env_config.yaml"

    if not config_file.exists():
        print(f"Config file not found: {config_file}")
        print(f"Please specify config path manually for testing: "
              f"python {Path(__file__).name} <config.yaml>")
        sys.exit(1)

    print(f"Testing config loading from: {config_file}")
    loader = EnvConfigLoader(config_file)
    try:
        # Validation runs inside the try so a missing section is reported the
        # same way as any other load failure.
        loader.validate_config()
        loader.print_config_summary()

        uf_state = loader.load_uf_state()
        print(f"UFState: q_UF={uf_state.q_UF}, TMP={uf_state.TMP}")
        physics = loader.load_physics_params()
        print(f"UFPhysicsParams: A={physics.A}, T_hard_limit={physics.global_TMP_hard_limit}")
        action = loader.load_action_spec()
        print(f"UFActionSpec: L_range=[{action.L_min_s}, {action.L_max_s}]")
        reward = loader.load_reward_params()
        print(f"UFRewardParams: w_tmp={reward.w_tmp}, residual_ref_ratio={reward.residual_ref_ratio}")
        bounds = loader.load_state_bounds()
        print(f"UFStateBounds: TMP0_range=[{bounds.TMP0_min}, {bounds.TMP0_max}]")

        # Also exercise the all-at-once path.
        all_params = loader.load_all()
        print(f"\nAll {len(all_params)} parameter classes loaded successfully")
    except Exception as e:
        # Top-level boundary of a manual test script: report, then exit non-zero.
        print(f"Error loading config: {e}")
        traceback.print_exc()
        sys.exit(1)
|