Browse Source

增加锡山水厂膜阻力48hCEB/48次CEB模型,相关模型及参数为models/uf-rl/xishan,调用入口为models/uf-rl/uf_train/rl_model/DQN/run_dqn_decide.py或models/uf-rl/uf_train/rl_model/DQN/run_dqn_deicde_totalstate.py

junc_WHU 1 tháng trước cách đây
mục cha
commit
fe768a3336

+ 257 - 0
models/uf-rl/uf_train/env/env_config_loader.py

@@ -0,0 +1,257 @@
+"""
+env_config_loader.py
+
+配置加载器,负责从YAML文件加载配置并实例化env_params中的参数类。
+"""
+
+import yaml
+from pathlib import Path
+from typing import Dict, Any, Union, Optional, Type, TypeVar
+
+# 改为绝对导入
+from uf_train.env.env_params import (
+    UFState,
+    UFPhysicsParams,
+    UFActionSpec,
+    UFRewardParams,
+    UFStateBounds
+)
+
+T = TypeVar('T')
+
+
+class EnvConfigLoader:
+    """环境配置加载器,从YAML加载并实例化参数类"""
+
+    # 参数类与配置节名的映射
+    SECTION_TO_CLASS = {
+        'UFState': UFState,
+        'UFPhysicsParams': UFPhysicsParams,
+        'UFActionSpec': UFActionSpec,
+        'UFRewardParams': UFRewardParams,
+        'UFStateBounds': UFStateBounds,
+    }
+
+    def __init__(self, config_path: Union[str, Path]):
+        """
+        初始化配置加载器
+
+        Args:
+            config_path: YAML配置文件路径
+        """
+        self.config_path = Path(config_path)
+        self._config = None
+
+    @property
+    def config(self) -> Dict[str, Any]:
+        """懒加载配置"""
+        if self._config is None:
+            self._config = self._load_yaml()
+        return self._config
+
+    def _load_yaml(self) -> Dict[str, Any]:
+        """加载YAML文件"""
+        if not self.config_path.exists():
+            raise FileNotFoundError(f"Config file not found: {self.config_path}")
+
+        with open(self.config_path, 'r', encoding='utf-8') as f:
+            return yaml.safe_load(f)
+
+    def _create_instance(self, cls: Type[T], section: str) -> T:
+        """
+        从配置的指定部分创建参数类实例
+        """
+        if section not in self.config:
+            raise KeyError(f"Section '{section}' not found in config. "
+                           f"Available sections: {list(self.config.keys())}")
+
+        section_config = self.config[section]
+
+        # 特殊处理:如果配置中有null值,需要转换为None
+        # YAML中的null会自动转换为Python的None,所以不需要额外处理
+
+        return cls(**section_config)
+
+    # ========== 加载各个参数类 ==========
+
+    def load_uf_state(self) -> UFState:
+        """加载UFState(环境动态状态)"""
+        return self._create_instance(UFState, 'UFState')
+
+    def load_physics_params(self) -> UFPhysicsParams:
+        """加载物理参数 UFPhysicsParams"""
+        return self._create_instance(UFPhysicsParams, 'UFPhysicsParams')
+
+    def load_action_spec(self) -> UFActionSpec:
+        """加载动作规范 UFActionSpec"""
+        return self._create_instance(UFActionSpec, 'UFActionSpec')
+
+    def load_reward_params(self) -> UFRewardParams:
+        """加载奖励参数 UFRewardParams"""
+        return self._create_instance(UFRewardParams, 'UFRewardParams')
+
+    def load_state_bounds(self) -> UFStateBounds:
+        """加载状态边界 UFStateBounds"""
+        return self._create_instance(UFStateBounds, 'UFStateBounds')
+
+    def load_all(self) -> Dict[str, Any]:
+        """加载所有参数类"""
+        return {
+            'uf_state': self.load_uf_state(),
+            'physics': self.load_physics_params(),
+            'action': self.load_action_spec(),
+            'reward': self.load_reward_params(),
+            'state_bounds': self.load_state_bounds(),
+        }
+
+    # ========== 工具方法 ==========
+
+    def get_raw_section(self, section: str) -> Dict[str, Any]:
+        """获取配置的原始字典部分"""
+        return self.config.get(section, {}).copy()
+
+    def validate_config(self) -> bool:
+        """
+        验证配置文件是否包含所有必需的配置节
+
+        Returns:
+            bool: 验证通过返回True,否则抛出异常
+        """
+        missing_sections = []
+        for section in self.SECTION_TO_CLASS.keys():
+            if section not in self.config:
+                missing_sections.append(section)
+
+        if missing_sections:
+            raise ValueError(f"Missing required sections in config: {missing_sections}")
+
+        print("✅ Config validation passed. All required sections present.")
+        return True
+
+    def print_config_summary(self):
+        """打印配置摘要信息"""
+        print("\n" + "=" * 50)
+        print("📋 配置加载摘要")
+        print("=" * 50)
+
+        for section in self.SECTION_TO_CLASS.keys():
+            if section in self.config:
+                print(f"✅ {section}: {len(self.config[section])} parameters")
+            else:
+                print(f"❌ {section}: MISSING")
+
+        print("=" * 50 + "\n")
+
+
+# ========== 便捷函数 ==========
+
+def load_env_config(config_path: Union[str, Path]) -> Dict[str, Any]:
+    """
+    便捷函数:一次性加载所有环境配置
+
+    Args:
+        config_path: 配置文件路径
+
+    Returns:
+        包含所有参数类实例的字典
+    """
+    loader = EnvConfigLoader(config_path)
+    return loader.load_all()
+
+
+def load_single_config(config_path: Union[str, Path], section: str):
+    """
+    便捷函数:加载单个配置节
+
+    Args:
+        config_path: 配置文件路径
+        section: 配置部分名称 (UFState, UFPhysicsParams, UFActionSpec, UFRewardParams, UFStateBounds)
+
+    Returns:
+        对应的参数类实例
+    """
+    loader = EnvConfigLoader(config_path)
+
+    mapping = {
+        'UFState': loader.load_uf_state,
+        'UFPhysicsParams': loader.load_physics_params,
+        'UFActionSpec': loader.load_action_spec,
+        'UFRewardParams': loader.load_reward_params,
+        'UFStateBounds': loader.load_state_bounds,
+    }
+
+    if section not in mapping:
+        raise ValueError(f"Unknown section: {section}. "
+                         f"Available: {list(mapping.keys())}")
+
+    return mapping[section]()
+
+
+def create_env_params_from_yaml(config_path: Union[str, Path]) -> tuple:
+    """
+    从YAML文件创建所有环境参数,返回元组方便解包
+
+    Args:
+        config_path: 配置文件路径
+
+    Returns:
+        (uf_state, physics_params, action_spec, reward_params, state_bounds)
+    """
+    loader = EnvConfigLoader(config_path)
+    return (
+        loader.load_uf_state(),
+        loader.load_physics_params(),
+        loader.load_action_spec(),
+        loader.load_reward_params(),
+        loader.load_state_bounds(),
+    )
+
+
+# ========== 测试 ==========
+if __name__ == "__main__":
+    # 测试配置加载
+    import sys
+    from pathlib import Path
+
+    # 假设配置文件在当前目录的上一级config文件夹中
+    default_config = Path(__file__).parent.parent.parent / "config" / "xishan_env_config.yaml"
+
+    if default_config.exists():
+        print(f"Testing config loading from: {default_config}")
+
+        # 测试加载器
+        loader = EnvConfigLoader(default_config)
+
+        # 验证配置
+        loader.validate_config()
+        loader.print_config_summary()
+
+        # 测试加载各个参数
+        try:
+            uf_state = loader.load_uf_state()
+            print(f"UFState: q_UF={uf_state.q_UF}, TMP={uf_state.TMP}")
+
+            physics = loader.load_physics_params()
+            print(f"UFPhysicsParams: A={physics.A}, T_hard_limit={physics.global_TMP_hard_limit}")
+
+            action = loader.load_action_spec()
+            print(f"UFActionSpec: L_range=[{action.L_min_s}, {action.L_max_s}]")
+
+            reward = loader.load_reward_params()
+            print(f"UFRewardParams: w_tmp={reward.w_tmp}, residual_ref_ratio={reward.residual_ref_ratio}")
+
+            bounds = loader.load_state_bounds()
+            print(f"UFStateBounds: TMP0_range=[{bounds.TMP0_min}, {bounds.TMP0_max}]")
+
+            # 测试一次性加载所有
+            all_params = loader.load_all()
+            print(f"\nAll {len(all_params)} parameter classes loaded successfully")
+
+        except Exception as e:
+            print(f"Error loading config: {e}")
+            import traceback
+
+            traceback.print_exc()
+    else:
+        print(f"Config file not found: {default_config}")
+        print("Please specify config path manually for testing.")

+ 182 - 0
models/uf-rl/uf_train/rl_model/DQN/dqn_config_loader.py

@@ -0,0 +1,182 @@
+"""
+dqn_config_loader.py
+
+DQN 配置加载器,负责从YAML文件加载DQN超参数。
+与dqn_params.py同级,保持参数类的纯净性。
+"""
+
+import yaml
+from pathlib import Path
+from typing import Dict, Any, Union, Optional
+from uf_train.rl_model.DQN.dqn_params import DQNParams
+
+
+
+
+class DQNConfigLoader:
+    """DQN配置加载器,从YAML加载并创建DQNParams实例"""
+
+    # DQNParams的字段列表(用于过滤)
+    DQN_FIELDS = [
+        'learning_rate',
+        'buffer_size',
+        'learning_starts',
+        'batch_size',
+        'gamma',
+        'train_freq',
+        'target_update_interval',
+        'tau',
+        'exploration_initial_eps',
+        'exploration_fraction',
+        'exploration_final_eps',
+        'remark'
+    ]
+
+    def __init__(self, config_path: Union[str, Path]):
+        """
+        初始化DQN配置加载器
+
+        Args:
+            config_path: YAML配置文件路径
+        """
+        self.config_path = Path(config_path)
+        self._config = None
+
+    @property
+    def config(self) -> Dict[str, Any]:
+        """懒加载配置"""
+        if self._config is None:
+            self._config = self._load_yaml()
+        return self._config
+
+    def _load_yaml(self) -> Dict[str, Any]:
+        """加载YAML文件"""
+        if not self.config_path.exists():
+            raise FileNotFoundError(f"DQN config file not found: {self.config_path}")
+
+        with open(self.config_path, 'r', encoding='utf-8') as f:
+            return yaml.safe_load(f)
+
+    def _filter_dqn_params(self, raw_config: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        过滤出DQNParams需要的参数
+
+        Args:
+            raw_config: 原始配置字典
+
+        Returns:
+            只包含DQNParams字段的字典
+        """
+        return {k: v for k, v in raw_config.items() if k in self.DQN_FIELDS}
+
+    def load_params(self) -> DQNParams:
+        """
+        从YAML配置加载DQN参数
+
+        Returns:
+            DQNParams实例
+        """
+        filtered_config = self._filter_dqn_params(self.config)
+        return DQNParams(**filtered_config)
+
+    def validate_config(self) -> bool:
+        """
+        验证配置文件
+
+        Returns:
+            bool: 验证通过返回True
+        """
+        # 检查是否有未知参数(可选)
+        unknown_params = set(self.config.keys()) - set(self.DQN_FIELDS)
+        if unknown_params:
+            print(f"⚠️ Warning: Unknown parameters in DQN config: {unknown_params}")
+
+        # 检查必需参数(所有参数都有默认值,所以都是可选的)
+        print("✅ DQN config loaded successfully")
+        return True
+
+    def print_config_summary(self):
+        """打印配置摘要"""
+        params = self.load_params()
+        print("\n" + "=" * 50)
+        print("🤖 DQN 超参数配置(从YAML加载)")
+        print("=" * 50)
+        print(f"学习率 (learning_rate): {params.learning_rate}")
+        print(f"缓冲区大小 (buffer_size): {params.buffer_size}")
+        print(f"预热步数 (learning_starts): {params.learning_starts}")
+        print(f"批次大小 (batch_size): {params.batch_size}")
+        print(f"折扣因子 (gamma): {params.gamma}")
+        print(f"训练频率 (train_freq): {params.train_freq}")
+        print(f"目标网络更新间隔: {params.target_update_interval}")
+        print(f"软更新系数 (tau): {params.tau}")
+        print(f"初始探索率: {params.exploration_initial_eps}")
+        print(f"探索率衰减比例: {params.exploration_fraction}")
+        print(f"最终探索率: {params.exploration_final_eps}")
+        print(f"实验备注: {params.remark}")
+        print("=" * 50)
+
+
+# ========== 便捷函数 ==========
+
+def load_dqn_config(config_path: Union[str, Path]) -> DQNParams:
+    """
+    便捷函数:从YAML文件加载DQN配置
+
+    Args:
+        config_path: 配置文件路径
+
+    Returns:
+        DQNParams实例
+    """
+    loader = DQNConfigLoader(config_path)
+    return loader.load_params()
+
+
+def load_dqn_config_with_validation(config_path: Union[str, Path]) -> DQNParams:
+    """
+    加载并验证DQN配置
+
+    Args:
+        config_path: 配置文件路径
+
+    Returns:
+        DQNParams实例
+    """
+    loader = DQNConfigLoader(config_path)
+    loader.validate_config()
+    return loader.load_params()
+
+
+# ========== 测试代码 ==========
+if __name__ == "__main__":
+    # 测试配置加载
+    import sys
+    from pathlib import Path
+
+    # 假设配置文件在项目根目录的config文件夹中
+    current_dir = Path(__file__).parent
+    project_root = current_dir.parent.parent  # uf-rl
+    default_config = Path(__file__).parent.parent.parent.parent / "config" / "xishan_dqn_config.yaml"
+
+    if default_config.exists():
+        print(f"Testing DQN config loading from: {default_config}")
+
+        # 测试加载器
+        loader = DQNConfigLoader(default_config)
+
+        # 验证配置
+        loader.validate_config()
+
+        # 加载参数
+        dqn_params = loader.load_params()
+
+        # 打印摘要
+        loader.print_config_summary()
+
+        print("\n✅ DQN config loaded successfully!")
+        print(f"   Learning rate: {dqn_params.learning_rate}")
+        print(f"   Buffer size: {dqn_params.buffer_size}")
+        print(f"   Remark: {dqn_params.remark}")
+    else:
+        print(f"Config file not found: {default_config}")
+        print("Please create a DQN config file first.")

+ 0 - 0
models/uf-rl/训练/__init__.py


+ 0 - 0
models/uf-rl/训练/uf_train/data_to_rl/__init__.py