| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # -*- coding: utf-8 -*-
- """
- config.py - 部署环境配置
- ========================
- 包含部署环境所需的配置参数。
- 参数必须与训练环境保持一致以确保推理结果正确。
- """
- from pathlib import Path
- class DeployConfig:
- """
- 部署配置类
-
- 包含推理所需的路径和参数配置。
- 与训练环境的config.py保持参数一致。
-
- 支持两种模型目录结构:
- 1. 默认: models/ae_model.pth, models/thresholds/
- 2. 子目录: models/{subdir}/ae_model.pth, models/{subdir}/thresholds/
- """
-
- # ========================================
- # 路径配置
- # ========================================
-
- # 部署根目录(predictor的上级目录)
- DEPLOY_ROOT = Path(__file__).resolve().parent.parent
-
- # 模型根目录
- MODEL_ROOT = DEPLOY_ROOT / "models"
-
- # 模型子目录(可通过 set_model_subdir 设置)
- # 为空时使用 MODEL_ROOT,否则使用 MODEL_ROOT / MODEL_SUBDIR
- MODEL_SUBDIR = ""
-
- # 音频目录
- AUDIO_DIR = DEPLOY_ROOT / "data" / "audio"
-
- @classmethod
- def set_model_subdir(cls, subdir: str):
- """
- 设置模型子目录
-
- Args:
- subdir: 子目录名(如 "LT-2"),为空则使用默认目录
- """
- cls.MODEL_SUBDIR = subdir
-
- @property
- def MODEL_DIR(self) -> Path:
- """获取当前模型目录"""
- if self.MODEL_SUBDIR:
- return self.MODEL_ROOT / self.MODEL_SUBDIR
- return self.MODEL_ROOT
-
- @property
- def AE_MODEL_PATH(self) -> Path:
- """自编码器模型文件路径"""
- return self.MODEL_DIR / "ae_model.pth"
-
- @property
- def THRESHOLD_DIR(self) -> Path:
- """阈值目录路径"""
- return self.MODEL_DIR / "thresholds"
-
- @property
- def SCALE_FILE(self) -> Path:
- """全局标准化参数文件路径"""
- return self.MODEL_DIR / "global_scale.npy"
-
- # ========================================
- # 音频参数(必须与训练一致)
- # ========================================
-
- # 采样率
- SR = 16000
-
- # 窗口长度(秒)
- WIN_SEC = 8.0
-
- # 窗口步长(秒)
- HOP_SEC = 4.0
-
- # ========================================
- # Mel频谱参数(必须与训练一致)
- # ========================================
-
- # Mel频带数
- N_MELS = 64
-
- # FFT窗口大小
- N_FFT = 1024
-
- # STFT步长
- HOP_LENGTH = 256
-
- # 目标帧数
- TARGET_FRAMES = 504
-
- # ========================================
- # 预测参数
- # ========================================
-
- # 批次大小
- BATCH_SIZE = 64
-
- # 3σ阈值系数(与训练版保持一致)
- SIGMA_MULTIPLIER = 3.0
-
- # 阈值分位数(用于增量训练时计算阈值)
- THRESHOLD_QUANTILE = 0.95
-
- # 异常patch比例阈值
- ANOMALY_RATIO_THRESHOLD = 0.1
- # 全局配置实例
- CFG = DeployConfig()
|