config.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # -*- coding: utf-8 -*-
  2. """
  3. config.py - 部署环境配置
  4. ========================
  5. 包含部署环境所需的配置参数。
  6. 参数必须与训练环境保持一致以确保推理结果正确。
  7. """
  8. from pathlib import Path
  9. class DeployConfig:
  10. """
  11. 部署配置类
  12. 包含推理所需的路径和参数配置。
  13. 与训练环境的config.py保持参数一致。
  14. 支持两种模型目录结构:
  15. 1. 默认: models/ae_model.pth, models/thresholds/
  16. 2. 子目录: models/{subdir}/ae_model.pth, models/{subdir}/thresholds/
  17. """
  18. # ========================================
  19. # 路径配置
  20. # ========================================
  21. # 部署根目录(predictor的上级目录)
  22. DEPLOY_ROOT = Path(__file__).resolve().parent.parent
  23. # 模型根目录
  24. MODEL_ROOT = DEPLOY_ROOT / "models"
  25. # 模型子目录(可通过 set_model_subdir 设置)
  26. # 为空时使用 MODEL_ROOT,否则使用 MODEL_ROOT / MODEL_SUBDIR
  27. MODEL_SUBDIR = ""
  28. # 音频目录
  29. AUDIO_DIR = DEPLOY_ROOT / "data" / "audio"
  30. @classmethod
  31. def set_model_subdir(cls, subdir: str):
  32. """
  33. 设置模型子目录
  34. Args:
  35. subdir: 子目录名(如 "LT-2"),为空则使用默认目录
  36. """
  37. cls.MODEL_SUBDIR = subdir
  38. @property
  39. def MODEL_DIR(self) -> Path:
  40. """获取当前模型目录"""
  41. if self.MODEL_SUBDIR:
  42. return self.MODEL_ROOT / self.MODEL_SUBDIR
  43. return self.MODEL_ROOT
  44. @property
  45. def AE_MODEL_PATH(self) -> Path:
  46. """自编码器模型文件路径"""
  47. return self.MODEL_DIR / "ae_model.pth"
  48. @property
  49. def THRESHOLD_DIR(self) -> Path:
  50. """阈值目录路径"""
  51. return self.MODEL_DIR / "thresholds"
  52. @property
  53. def SCALE_FILE(self) -> Path:
  54. """全局标准化参数文件路径"""
  55. return self.MODEL_DIR / "global_scale.npy"
  56. # ========================================
  57. # 音频参数(必须与训练一致)
  58. # ========================================
  59. # 采样率
  60. SR = 16000
  61. # 窗口长度(秒)
  62. WIN_SEC = 8.0
  63. # 窗口步长(秒)
  64. HOP_SEC = 4.0
  65. # ========================================
  66. # Mel频谱参数(必须与训练一致)
  67. # ========================================
  68. # Mel频带数
  69. N_MELS = 64
  70. # FFT窗口大小
  71. N_FFT = 1024
  72. # STFT步长
  73. HOP_LENGTH = 256
  74. # 目标帧数
  75. TARGET_FRAMES = 504
  76. # ========================================
  77. # 预测参数
  78. # ========================================
  79. # 批次大小
  80. BATCH_SIZE = 64
  81. # 3σ阈值系数(与训练版保持一致)
  82. SIGMA_MULTIPLIER = 3.0
  83. # 阈值分位数(用于增量训练时计算阈值)
  84. THRESHOLD_QUANTILE = 0.95
  85. # 异常patch比例阈值
  86. ANOMALY_RATIO_THRESHOLD = 0.1
  87. # 全局配置实例
  88. CFG = DeployConfig()