|
|
@@ -0,0 +1,257 @@
|
|
|
+"""
|
|
|
+env_config_loader.py
|
|
|
+
|
|
|
+配置加载器,负责从YAML文件加载配置并实例化env_params中的参数类。
|
|
|
+"""
|
|
|
+
|
|
|
+import yaml
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, Any, Union, Optional, Type, TypeVar
|
|
|
+
|
|
|
+# 改为绝对导入
|
|
|
+from uf_train.env.env_params import (
|
|
|
+ UFState,
|
|
|
+ UFPhysicsParams,
|
|
|
+ UFActionSpec,
|
|
|
+ UFRewardParams,
|
|
|
+ UFStateBounds
|
|
|
+)
|
|
|
+
|
|
|
+T = TypeVar('T')
|
|
|
+
|
|
|
+
|
|
|
+class EnvConfigLoader:
|
|
|
+ """环境配置加载器,从YAML加载并实例化参数类"""
|
|
|
+
|
|
|
+ # 参数类与配置节名的映射
|
|
|
+ SECTION_TO_CLASS = {
|
|
|
+ 'UFState': UFState,
|
|
|
+ 'UFPhysicsParams': UFPhysicsParams,
|
|
|
+ 'UFActionSpec': UFActionSpec,
|
|
|
+ 'UFRewardParams': UFRewardParams,
|
|
|
+ 'UFStateBounds': UFStateBounds,
|
|
|
+ }
|
|
|
+
|
|
|
+ def __init__(self, config_path: Union[str, Path]):
|
|
|
+ """
|
|
|
+ 初始化配置加载器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_path: YAML配置文件路径
|
|
|
+ """
|
|
|
+ self.config_path = Path(config_path)
|
|
|
+ self._config = None
|
|
|
+
|
|
|
+ @property
|
|
|
+ def config(self) -> Dict[str, Any]:
|
|
|
+ """懒加载配置"""
|
|
|
+ if self._config is None:
|
|
|
+ self._config = self._load_yaml()
|
|
|
+ return self._config
|
|
|
+
|
|
|
+ def _load_yaml(self) -> Dict[str, Any]:
|
|
|
+ """加载YAML文件"""
|
|
|
+ if not self.config_path.exists():
|
|
|
+ raise FileNotFoundError(f"Config file not found: {self.config_path}")
|
|
|
+
|
|
|
+ with open(self.config_path, 'r', encoding='utf-8') as f:
|
|
|
+ return yaml.safe_load(f)
|
|
|
+
|
|
|
+ def _create_instance(self, cls: Type[T], section: str) -> T:
|
|
|
+ """
|
|
|
+ 从配置的指定部分创建参数类实例
|
|
|
+ """
|
|
|
+ if section not in self.config:
|
|
|
+ raise KeyError(f"Section '{section}' not found in config. "
|
|
|
+ f"Available sections: {list(self.config.keys())}")
|
|
|
+
|
|
|
+ section_config = self.config[section]
|
|
|
+
|
|
|
+ # 特殊处理:如果配置中有null值,需要转换为None
|
|
|
+ # YAML中的null会自动转换为Python的None,所以不需要额外处理
|
|
|
+
|
|
|
+ return cls(**section_config)
|
|
|
+
|
|
|
+ # ========== 加载各个参数类 ==========
|
|
|
+
|
|
|
+ def load_uf_state(self) -> UFState:
|
|
|
+ """加载UFState(环境动态状态)"""
|
|
|
+ return self._create_instance(UFState, 'UFState')
|
|
|
+
|
|
|
+ def load_physics_params(self) -> UFPhysicsParams:
|
|
|
+ """加载物理参数 UFPhysicsParams"""
|
|
|
+ return self._create_instance(UFPhysicsParams, 'UFPhysicsParams')
|
|
|
+
|
|
|
+ def load_action_spec(self) -> UFActionSpec:
|
|
|
+ """加载动作规范 UFActionSpec"""
|
|
|
+ return self._create_instance(UFActionSpec, 'UFActionSpec')
|
|
|
+
|
|
|
+ def load_reward_params(self) -> UFRewardParams:
|
|
|
+ """加载奖励参数 UFRewardParams"""
|
|
|
+ return self._create_instance(UFRewardParams, 'UFRewardParams')
|
|
|
+
|
|
|
+ def load_state_bounds(self) -> UFStateBounds:
|
|
|
+ """加载状态边界 UFStateBounds"""
|
|
|
+ return self._create_instance(UFStateBounds, 'UFStateBounds')
|
|
|
+
|
|
|
+ def load_all(self) -> Dict[str, Any]:
|
|
|
+ """加载所有参数类"""
|
|
|
+ return {
|
|
|
+ 'uf_state': self.load_uf_state(),
|
|
|
+ 'physics': self.load_physics_params(),
|
|
|
+ 'action': self.load_action_spec(),
|
|
|
+ 'reward': self.load_reward_params(),
|
|
|
+ 'state_bounds': self.load_state_bounds(),
|
|
|
+ }
|
|
|
+
|
|
|
+ # ========== 工具方法 ==========
|
|
|
+
|
|
|
+ def get_raw_section(self, section: str) -> Dict[str, Any]:
|
|
|
+ """获取配置的原始字典部分"""
|
|
|
+ return self.config.get(section, {}).copy()
|
|
|
+
|
|
|
+ def validate_config(self) -> bool:
|
|
|
+ """
|
|
|
+ 验证配置文件是否包含所有必需的配置节
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 验证通过返回True,否则抛出异常
|
|
|
+ """
|
|
|
+ missing_sections = []
|
|
|
+ for section in self.SECTION_TO_CLASS.keys():
|
|
|
+ if section not in self.config:
|
|
|
+ missing_sections.append(section)
|
|
|
+
|
|
|
+ if missing_sections:
|
|
|
+ raise ValueError(f"Missing required sections in config: {missing_sections}")
|
|
|
+
|
|
|
+ print("✅ Config validation passed. All required sections present.")
|
|
|
+ return True
|
|
|
+
|
|
|
+ def print_config_summary(self):
|
|
|
+ """打印配置摘要信息"""
|
|
|
+ print("\n" + "=" * 50)
|
|
|
+ print("📋 配置加载摘要")
|
|
|
+ print("=" * 50)
|
|
|
+
|
|
|
+ for section in self.SECTION_TO_CLASS.keys():
|
|
|
+ if section in self.config:
|
|
|
+ print(f"✅ {section}: {len(self.config[section])} parameters")
|
|
|
+ else:
|
|
|
+ print(f"❌ {section}: MISSING")
|
|
|
+
|
|
|
+ print("=" * 50 + "\n")
|
|
|
+
|
|
|
+
|
|
|
+# ========== 便捷函数 ==========
|
|
|
+
|
|
|
+def load_env_config(config_path: Union[str, Path]) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 便捷函数:一次性加载所有环境配置
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_path: 配置文件路径
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 包含所有参数类实例的字典
|
|
|
+ """
|
|
|
+ loader = EnvConfigLoader(config_path)
|
|
|
+ return loader.load_all()
|
|
|
+
|
|
|
+
|
|
|
+def load_single_config(config_path: Union[str, Path], section: str):
|
|
|
+ """
|
|
|
+ 便捷函数:加载单个配置节
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_path: 配置文件路径
|
|
|
+ section: 配置部分名称 (UFState, UFPhysicsParams, UFActionSpec, UFRewardParams, UFStateBounds)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 对应的参数类实例
|
|
|
+ """
|
|
|
+ loader = EnvConfigLoader(config_path)
|
|
|
+
|
|
|
+ mapping = {
|
|
|
+ 'UFState': loader.load_uf_state,
|
|
|
+ 'UFPhysicsParams': loader.load_physics_params,
|
|
|
+ 'UFActionSpec': loader.load_action_spec,
|
|
|
+ 'UFRewardParams': loader.load_reward_params,
|
|
|
+ 'UFStateBounds': loader.load_state_bounds,
|
|
|
+ }
|
|
|
+
|
|
|
+ if section not in mapping:
|
|
|
+ raise ValueError(f"Unknown section: {section}. "
|
|
|
+ f"Available: {list(mapping.keys())}")
|
|
|
+
|
|
|
+ return mapping[section]()
|
|
|
+
|
|
|
+
|
|
|
+def create_env_params_from_yaml(config_path: Union[str, Path]) -> tuple:
|
|
|
+ """
|
|
|
+ 从YAML文件创建所有环境参数,返回元组方便解包
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_path: 配置文件路径
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (uf_state, physics_params, action_spec, reward_params, state_bounds)
|
|
|
+ """
|
|
|
+ loader = EnvConfigLoader(config_path)
|
|
|
+ return (
|
|
|
+ loader.load_uf_state(),
|
|
|
+ loader.load_physics_params(),
|
|
|
+ loader.load_action_spec(),
|
|
|
+ loader.load_reward_params(),
|
|
|
+ loader.load_state_bounds(),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+# ========== 测试 ==========
|
|
|
+if __name__ == "__main__":
|
|
|
+ # 测试配置加载
|
|
|
+ import sys
|
|
|
+ from pathlib import Path
|
|
|
+
|
|
|
+ # 假设配置文件在当前目录的上一级config文件夹中
|
|
|
+ default_config = Path(__file__).parent.parent.parent / "config" / "xishan_env_config.yaml"
|
|
|
+
|
|
|
+ if default_config.exists():
|
|
|
+ print(f"Testing config loading from: {default_config}")
|
|
|
+
|
|
|
+ # 测试加载器
|
|
|
+ loader = EnvConfigLoader(default_config)
|
|
|
+
|
|
|
+ # 验证配置
|
|
|
+ loader.validate_config()
|
|
|
+ loader.print_config_summary()
|
|
|
+
|
|
|
+ # 测试加载各个参数
|
|
|
+ try:
|
|
|
+ uf_state = loader.load_uf_state()
|
|
|
+ print(f"UFState: q_UF={uf_state.q_UF}, TMP={uf_state.TMP}")
|
|
|
+
|
|
|
+ physics = loader.load_physics_params()
|
|
|
+ print(f"UFPhysicsParams: A={physics.A}, T_hard_limit={physics.global_TMP_hard_limit}")
|
|
|
+
|
|
|
+ action = loader.load_action_spec()
|
|
|
+ print(f"UFActionSpec: L_range=[{action.L_min_s}, {action.L_max_s}]")
|
|
|
+
|
|
|
+ reward = loader.load_reward_params()
|
|
|
+ print(f"UFRewardParams: w_tmp={reward.w_tmp}, residual_ref_ratio={reward.residual_ref_ratio}")
|
|
|
+
|
|
|
+ bounds = loader.load_state_bounds()
|
|
|
+ print(f"UFStateBounds: TMP0_range=[{bounds.TMP0_min}, {bounds.TMP0_max}]")
|
|
|
+
|
|
|
+ # 测试一次性加载所有
|
|
|
+ all_params = loader.load_all()
|
|
|
+ print(f"\nAll {len(all_params)} parameter classes loaded successfully")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error loading config: {e}")
|
|
|
+ import traceback
|
|
|
+
|
|
|
+ traceback.print_exc()
|
|
|
+ else:
|
|
|
+ print(f"Config file not found: {default_config}")
|
|
|
+ print("Please specify config path manually for testing.")
|