env_config_loader.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. """
  2. env_config_loader.py
  3. 配置加载器,负责从YAML文件加载配置并实例化env_params中的参数类。
  4. """
  5. import yaml
  6. from pathlib import Path
  7. from typing import Dict, Any, Union, Optional, Type, TypeVar
  8. # 改为绝对导入
  9. from uf_train.env.env_params import (
  10. UFState,
  11. UFPhysicsParams,
  12. UFActionSpec,
  13. UFRewardParams,
  14. UFStateBounds
  15. )
  16. T = TypeVar('T')
  17. class EnvConfigLoader:
  18. """环境配置加载器,从YAML加载并实例化参数类"""
  19. # 参数类与配置节名的映射
  20. SECTION_TO_CLASS = {
  21. 'UFState': UFState,
  22. 'UFPhysicsParams': UFPhysicsParams,
  23. 'UFActionSpec': UFActionSpec,
  24. 'UFRewardParams': UFRewardParams,
  25. 'UFStateBounds': UFStateBounds,
  26. }
  27. def __init__(self, config_path: Union[str, Path]):
  28. """
  29. 初始化配置加载器
  30. Args:
  31. config_path: YAML配置文件路径
  32. """
  33. self.config_path = Path(config_path)
  34. self._config = None
  35. @property
  36. def config(self) -> Dict[str, Any]:
  37. """懒加载配置"""
  38. if self._config is None:
  39. self._config = self._load_yaml()
  40. return self._config
  41. def _load_yaml(self) -> Dict[str, Any]:
  42. """加载YAML文件"""
  43. if not self.config_path.exists():
  44. raise FileNotFoundError(f"Config file not found: {self.config_path}")
  45. with open(self.config_path, 'r', encoding='utf-8') as f:
  46. return yaml.safe_load(f)
  47. def _create_instance(self, cls: Type[T], section: str) -> T:
  48. """
  49. 从配置的指定部分创建参数类实例
  50. """
  51. if section not in self.config:
  52. raise KeyError(f"Section '{section}' not found in config. "
  53. f"Available sections: {list(self.config.keys())}")
  54. section_config = self.config[section]
  55. # 特殊处理:如果配置中有null值,需要转换为None
  56. # YAML中的null会自动转换为Python的None,所以不需要额外处理
  57. return cls(**section_config)
  58. # ========== 加载各个参数类 ==========
  59. def load_uf_state(self) -> UFState:
  60. """加载UFState(环境动态状态)"""
  61. return self._create_instance(UFState, 'UFState')
  62. def load_physics_params(self) -> UFPhysicsParams:
  63. """加载物理参数 UFPhysicsParams"""
  64. return self._create_instance(UFPhysicsParams, 'UFPhysicsParams')
  65. def load_action_spec(self) -> UFActionSpec:
  66. """加载动作规范 UFActionSpec"""
  67. return self._create_instance(UFActionSpec, 'UFActionSpec')
  68. def load_reward_params(self) -> UFRewardParams:
  69. """加载奖励参数 UFRewardParams"""
  70. return self._create_instance(UFRewardParams, 'UFRewardParams')
  71. def load_state_bounds(self) -> UFStateBounds:
  72. """加载状态边界 UFStateBounds"""
  73. return self._create_instance(UFStateBounds, 'UFStateBounds')
  74. def load_all(self) -> Dict[str, Any]:
  75. """加载所有参数类"""
  76. return {
  77. 'uf_state': self.load_uf_state(),
  78. 'physics': self.load_physics_params(),
  79. 'action': self.load_action_spec(),
  80. 'reward': self.load_reward_params(),
  81. 'state_bounds': self.load_state_bounds(),
  82. }
  83. # ========== 工具方法 ==========
  84. def get_raw_section(self, section: str) -> Dict[str, Any]:
  85. """获取配置的原始字典部分"""
  86. return self.config.get(section, {}).copy()
  87. def validate_config(self) -> bool:
  88. """
  89. 验证配置文件是否包含所有必需的配置节
  90. Returns:
  91. bool: 验证通过返回True,否则抛出异常
  92. """
  93. missing_sections = []
  94. for section in self.SECTION_TO_CLASS.keys():
  95. if section not in self.config:
  96. missing_sections.append(section)
  97. if missing_sections:
  98. raise ValueError(f"Missing required sections in config: {missing_sections}")
  99. print("✅ Config validation passed. All required sections present.")
  100. return True
  101. def print_config_summary(self):
  102. """打印配置摘要信息"""
  103. print("\n" + "=" * 50)
  104. print("📋 配置加载摘要")
  105. print("=" * 50)
  106. for section in self.SECTION_TO_CLASS.keys():
  107. if section in self.config:
  108. print(f"✅ {section}: {len(self.config[section])} parameters")
  109. else:
  110. print(f"❌ {section}: MISSING")
  111. print("=" * 50 + "\n")
  112. # ========== 便捷函数 ==========
  113. def load_env_config(config_path: Union[str, Path]) -> Dict[str, Any]:
  114. """
  115. 便捷函数:一次性加载所有环境配置
  116. Args:
  117. config_path: 配置文件路径
  118. Returns:
  119. 包含所有参数类实例的字典
  120. """
  121. loader = EnvConfigLoader(config_path)
  122. return loader.load_all()
  123. def load_single_config(config_path: Union[str, Path], section: str):
  124. """
  125. 便捷函数:加载单个配置节
  126. Args:
  127. config_path: 配置文件路径
  128. section: 配置部分名称 (UFState, UFPhysicsParams, UFActionSpec, UFRewardParams, UFStateBounds)
  129. Returns:
  130. 对应的参数类实例
  131. """
  132. loader = EnvConfigLoader(config_path)
  133. mapping = {
  134. 'UFState': loader.load_uf_state,
  135. 'UFPhysicsParams': loader.load_physics_params,
  136. 'UFActionSpec': loader.load_action_spec,
  137. 'UFRewardParams': loader.load_reward_params,
  138. 'UFStateBounds': loader.load_state_bounds,
  139. }
  140. if section not in mapping:
  141. raise ValueError(f"Unknown section: {section}. "
  142. f"Available: {list(mapping.keys())}")
  143. return mapping[section]()
  144. def create_env_params_from_yaml(config_path: Union[str, Path]) -> tuple:
  145. """
  146. 从YAML文件创建所有环境参数,返回元组方便解包
  147. Args:
  148. config_path: 配置文件路径
  149. Returns:
  150. (uf_state, physics_params, action_spec, reward_params, state_bounds)
  151. """
  152. loader = EnvConfigLoader(config_path)
  153. return (
  154. loader.load_uf_state(),
  155. loader.load_physics_params(),
  156. loader.load_action_spec(),
  157. loader.load_reward_params(),
  158. loader.load_state_bounds(),
  159. )
  160. # ========== 测试 ==========
  161. if __name__ == "__main__":
  162. # 测试配置加载
  163. import sys
  164. from pathlib import Path
  165. # 假设配置文件在当前目录的上一级config文件夹中
  166. default_config = Path(__file__).parent.parent.parent / "config" / "xishan_env_config.yaml"
  167. if default_config.exists():
  168. print(f"Testing config loading from: {default_config}")
  169. # 测试加载器
  170. loader = EnvConfigLoader(default_config)
  171. # 验证配置
  172. loader.validate_config()
  173. loader.print_config_summary()
  174. # 测试加载各个参数
  175. try:
  176. uf_state = loader.load_uf_state()
  177. print(f"UFState: q_UF={uf_state.q_UF}, TMP={uf_state.TMP}")
  178. physics = loader.load_physics_params()
  179. print(f"UFPhysicsParams: A={physics.A}, T_hard_limit={physics.global_TMP_hard_limit}")
  180. action = loader.load_action_spec()
  181. print(f"UFActionSpec: L_range=[{action.L_min_s}, {action.L_max_s}]")
  182. reward = loader.load_reward_params()
  183. print(f"UFRewardParams: w_tmp={reward.w_tmp}, residual_ref_ratio={reward.residual_ref_ratio}")
  184. bounds = loader.load_state_bounds()
  185. print(f"UFStateBounds: TMP0_range=[{bounds.TMP0_min}, {bounds.TMP0_max}]")
  186. # 测试一次性加载所有
  187. all_params = loader.load_all()
  188. print(f"\nAll {len(all_params)} parameter classes loaded successfully")
  189. except Exception as e:
  190. print(f"Error loading config: {e}")
  191. import traceback
  192. traceback.print_exc()
  193. else:
  194. print(f"Config file not found: {default_config}")
  195. print("Please specify config path manually for testing.")