"""
env_config_loader.py

Configuration loader: reads a YAML file and instantiates the parameter
classes imported from env.env_params, one per top-level YAML section.
"""

import yaml
from pathlib import Path
from typing import Dict, Any, Union, Type, TypeVar

# Absolute import (changed from a relative one).
from env.env_params import (
    UFState,
    UFPhysicsParams,
    UFActionSpec,
    UFRewardParams,
    UFStateBounds,
)

T = TypeVar('T')


class EnvConfigLoader:
    """Load a YAML config file and build parameter-class instances from it.

    Each top-level key of the YAML file is a "section" whose mapping is
    passed as keyword arguments to the matching parameter class.
    """

    # Section name in the YAML file -> parameter class built from it.
    SECTION_TO_CLASS = {
        'UFState': UFState,
        'UFPhysicsParams': UFPhysicsParams,
        'UFActionSpec': UFActionSpec,
        'UFRewardParams': UFRewardParams,
        'UFStateBounds': UFStateBounds,
    }

    def __init__(self, config_path: Union[str, Path]):
        """Remember the config path; the file is not read until first use.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config_path = Path(config_path)
        self._config = None

    @property
    def config(self) -> Dict[str, Any]:
        """Parsed configuration, loaded lazily on first access and cached."""
        if self._config is None:
            self._config = self._load_yaml()
        return self._config

    def _load_yaml(self) -> Dict[str, Any]:
        """Read and parse the YAML file.

        Returns:
            The top-level mapping of the file. An empty file yields an empty
            dict, so that later section lookups report a missing section
            instead of failing on ``None``.

        Raises:
            FileNotFoundError: If the config file does not exist.
            ValueError: If the top level of the file is not a mapping.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)

        # yaml.safe_load returns None for an empty document.
        if loaded is None:
            return {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Top level of config file {self.config_path} must be a "
                f"mapping of sections, got {type(loaded).__name__}"
            )
        return loaded

    def _create_instance(self, cls: Type[T], section: str) -> T:
        """Instantiate ``cls`` from the named config section.

        Args:
            cls: Parameter class to instantiate.
            section: Name of the top-level config section to read.

        Returns:
            ``cls(**section_mapping)``.

        Raises:
            KeyError: If the section is absent from the config.
            TypeError: If the section is not a mapping, or its keys do not
                match the constructor of ``cls``.
        """
        if section not in self.config:
            raise KeyError(
                f"Section '{section}' not found in config. "
                f"Available sections: {list(self.config.keys())}"
            )

        section_config = self.config[section]

        # A section header with nothing under it parses as None; treat it as
        # "no overrides" so the class is built with its own defaults.
        if section_config is None:
            section_config = {}
        if not isinstance(section_config, dict):
            raise TypeError(
                f"Section '{section}' must be a mapping, "
                f"got {type(section_config).__name__}"
            )

        # YAML null values inside the section are already Python None, so no
        # extra conversion is needed before passing them on.
        return cls(**section_config)

    # ========== Loaders for the individual parameter classes ==========

    def load_uf_state(self) -> UFState:
        """Build a UFState from the 'UFState' section."""
        return self._create_instance(UFState, 'UFState')

    def load_physics_params(self) -> UFPhysicsParams:
        """Build a UFPhysicsParams from the 'UFPhysicsParams' section."""
        return self._create_instance(UFPhysicsParams, 'UFPhysicsParams')

    def load_action_spec(self) -> UFActionSpec:
        """Build a UFActionSpec from the 'UFActionSpec' section."""
        return self._create_instance(UFActionSpec, 'UFActionSpec')

    def load_reward_params(self) -> UFRewardParams:
        """Build a UFRewardParams from the 'UFRewardParams' section."""
        return self._create_instance(UFRewardParams, 'UFRewardParams')

    def load_state_bounds(self) -> UFStateBounds:
        """Build a UFStateBounds from the 'UFStateBounds' section."""
        return self._create_instance(UFStateBounds, 'UFStateBounds')

    def load_all(self) -> Dict[str, Any]:
        """Build every parameter class and return them keyed by short name."""
        return {
            'uf_state': self.load_uf_state(),
            'physics': self.load_physics_params(),
            'action': self.load_action_spec(),
            'reward': self.load_reward_params(),
            'state_bounds': self.load_state_bounds(),
        }

    # ========== Utilities ==========

    def get_raw_section(self, section: str) -> Dict[str, Any]:
        """Return a shallow copy of a section's raw content.

        A missing or empty (null) section yields an empty dict.
        """
        raw = self.config.get(section) or {}
        return raw.copy()

    def validate_config(self) -> bool:
        """Check that every section in SECTION_TO_CLASS is present.

        Returns:
            True when all required sections are present.

        Raises:
            ValueError: If one or more required sections are missing.
        """
        missing_sections = [
            section
            for section in self.SECTION_TO_CLASS
            if section not in self.config
        ]

        if missing_sections:
            raise ValueError(f"Missing required sections in config: {missing_sections}")

        print("✅ Config validation passed. All required sections present.")
        return True

    def print_config_summary(self):
        """Print, per required section, its parameter count or MISSING."""
        print("\n" + "=" * 50)
        print("📋 配置加载摘要")
        print("=" * 50)

        for section in self.SECTION_TO_CLASS:
            if section in self.config:
                # An empty (null) section counts as zero parameters.
                count = len(self.config[section] or {})
                print(f"✅ {section}: {count} parameters")
            else:
                print(f"❌ {section}: MISSING")

        print("=" * 50 + "\n")


# ========== Convenience functions ==========

def load_env_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load every parameter class from a config file in one call.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Dict with keys 'uf_state', 'physics', 'action', 'reward' and
        'state_bounds' mapped to the corresponding instances.
    """
    loader = EnvConfigLoader(config_path)
    return loader.load_all()


def load_single_config(config_path: Union[str, Path], section: str):
    """Load a single config section as its parameter-class instance.

    Args:
        config_path: Path to the YAML configuration file.
        section: Section name (UFState, UFPhysicsParams, UFActionSpec,
            UFRewardParams, UFStateBounds).

    Returns:
        Instance of the parameter class registered for ``section``.

    Raises:
        ValueError: If ``section`` is not a known section name.
    """
    loader = EnvConfigLoader(config_path)
    # Dispatch through the class-level table so the set of known sections is
    # defined in exactly one place.
    section_classes = loader.SECTION_TO_CLASS

    if section not in section_classes:
        raise ValueError(
            f"Unknown section: {section}. "
            f"Available: {list(section_classes.keys())}"
        )

    return loader._create_instance(section_classes[section], section)


def create_env_params_from_yaml(config_path: Union[str, Path]) -> tuple:
    """Build all environment parameters and return them as a tuple.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        (uf_state, physics_params, action_spec, reward_params, state_bounds)
    """
    loader = EnvConfigLoader(config_path)
    return (
        loader.load_uf_state(),
        loader.load_physics_params(),
        loader.load_action_spec(),
        loader.load_reward_params(),
        loader.load_state_bounds(),
    )


# ========== Self-test ==========

if __name__ == "__main__":
    # The default config is looked up in a "config" folder two directory
    # levels above the directory containing this file.
    default_config = Path(__file__).parent.parent.parent / "config" / "xishan_env_config.yaml"

    if default_config.exists():
        print(f"Testing config loading from: {default_config}")

        loader = EnvConfigLoader(default_config)

        # Check section presence first, then show a summary.
        loader.validate_config()
        loader.print_config_summary()

        # Exercise each loader; this is a top-level diagnostic boundary, so
        # any failure is reported with a traceback rather than propagated.
        try:
            uf_state = loader.load_uf_state()
            print(f"UFState: q_UF={uf_state.q_UF}, TMP={uf_state.TMP}")

            physics = loader.load_physics_params()
            print(f"UFPhysicsParams: A={physics.A}, T_hard_limit={physics.global_TMP_hard_limit}")

            action = loader.load_action_spec()
            print(f"UFActionSpec: L_range=[{action.L_min_s}, {action.L_max_s}]")

            reward = loader.load_reward_params()
            print(f"UFRewardParams: w_tmp={reward.w_tmp}, residual_ref_ratio={reward.residual_ref_ratio}")

            bounds = loader.load_state_bounds()
            print(f"UFStateBounds: TMP0_range=[{bounds.TMP0_min}, {bounds.TMP0_max}]")

            # Load everything in one call as well.
            all_params = loader.load_all()
            print(f"\nAll {len(all_params)} parameter classes loaded successfully")

        except Exception as e:
            print(f"Error loading config: {e}")
            import traceback
            traceback.print_exc()
    else:
        print(f"Config file not found: {default_config}")
        print("Please specify config path manually for testing.")