wmy 2 nedēļas atpakaļ
vecāks
revīzija
33c6123101
35 mainīti faili ar 442 papildinājumiem un 315 dzēšanām
  1. 12 0
      README.md
  2. 16 6
      auto_training/data_cleanup.py
  3. 111 246
      auto_training/incremental_trainer.py
  4. 2 2
      config/auto_training.yaml
  5. 1 1
      config/config_manager.py
  6. BIN
      config/pickup_config.db
  7. BIN
      config/pickup_config.db-shm
  8. 0 0
      config/pickup_config.db-wal
  9. BIN
      config/yaml_backup/db_output/pickup_config_anzhen.db
  10. BIN
      config/yaml_backup/db_output/pickup_config_jianding.db
  11. BIN
      config/yaml_backup/db_output/pickup_config_longting.db
  12. BIN
      config/yaml_backup/db_output/pickup_config_longting.db-shm
  13. 0 0
      config/yaml_backup/db_output/pickup_config_longting.db-wal
  14. BIN
      config/yaml_backup/db_output/pickup_config_xishan.db
  15. BIN
      config/yaml_backup/db_output/pickup_config_xishan.db-shm
  16. 0 0
      config/yaml_backup/db_output/pickup_config_xishan.db-wal
  17. BIN
      config/yaml_backup/db_output/pickup_config_yancheng.db
  18. 12 1
      config/yaml_backup/rtsp_config_anzhen.yaml
  19. 1 1
      config/yaml_backup/rtsp_config_jianding.yaml
  20. 44 1
      config/yaml_backup/rtsp_config_longting.yaml
  21. 46 1
      config/yaml_backup/rtsp_config_xishan.yaml
  22. 1 1
      config/yaml_backup/rtsp_config_yancheng.yaml
  23. 1 1
      core/pump_state_monitor.py
  24. BIN
      models/LT-2/ae_model.pth
  25. BIN
      models/LT-2/global_scale.npy
  26. BIN
      models/LT-2/thresholds/threshold_default.npy
  27. BIN
      models/LT-5/ae_model.pth
  28. BIN
      models/LT-5/global_scale.npy
  29. BIN
      models/LT-5/thresholds/threshold_default.npy
  30. 31 2
      predictor/multi_model_predictor.py
  31. 21 10
      predictor/utils.py
  32. 1 0
      requirements.txt
  33. 37 20
      run_with_auto_training.py
  34. 104 21
      start.sh
  35. 1 1
      tool/migrate_yaml_to_db.py

+ 12 - 0
README.md

@@ -146,3 +146,15 @@ deploy_pickup/
 ├── tool/migrate_yaml_to_db.py # YAML → DB 迁移
 └── data/                      # 运行时音频
 ```
+# 启用 NPU 推理步骤
+# 未来在 BM1684X 服务器上启用时:
+# 1. 导出 ONNX
+# python tool/convert_to_bmodel.py --all
+# 2. 安装 TPU-MLIR 后生成 BModel
+# python tool/convert_to_bmodel.py --all --with-bmodel --quantize fp16
+# 3. 取消 multi_model_predictor.py 中的注释
+#    - import BM1684XEngine, is_bm1684x_available
+#    - self.bm_engine = self._load_bmodel()
+#    - _load_bmodel() 方法
+# 4. 修改 _compute_reconstruction_error() 中
+#    判断 device_predictor.bm_engine 是否存在,优先调用 NPU 推理

+ 16 - 6
auto_training/data_cleanup.py

@@ -24,10 +24,18 @@ logger = logging.getLogger('DataCleanup')
 class DataCleaner:
     """数据清理器"""
     
-    def __init__(self, config_file: Path):
-        """初始化清理器"""
-        self.config_file = config_file
-        self.config = self._load_config()
+    def __init__(self, config_file: Path = None, config: dict = None):
+        # 支持两种初始化方式:
+        # 1. 传 config dict(从数据库读取后直接传入,主程序使用)
+        # 2. 传 config_file YAML 路径(命令行独立运行使用)
+        if config is not None:
+            self.config = config
+            self.config_file = None
+        elif config_file is not None:
+            self.config_file = config_file
+            self.config = self._load_config()
+        else:
+            raise ValueError("必须提供 config_file 或 config 之一")
         
         # 路径配置
         self.deploy_root = Path(__file__).parent.parent
@@ -37,7 +45,7 @@ class DataCleaner:
         self.logs_dir = self.deploy_root / "logs"
     
     def _load_config(self):
-        """加载配置"""
+        # 从 YAML 文件加载配置(仅 config_file 模式使用)
         with open(self.config_file, 'r', encoding='utf-8') as f:
             return yaml.safe_load(f)
     
@@ -60,12 +68,14 @@ class DataCleaner:
                 continue
             
             for date_dir in device_dir.iterdir():
-                if not date_dir.is_dir() or date_dir.name == "current":
+                # current: 正在写入的目录; verified_normal: 核查确认的正常音频(增训用)
+                if not date_dir.is_dir() or date_dir.name in ("current", "verified_normal"):
                     continue
                 
                 # 检查日期
                 if date_dir.name < cutoff_date:
                     if date_dir.exists():
+                        # rglob 递归统计所有子目录(normal/ + pump_transition/)中的音频
                         for f in date_dir.rglob("*.wav"):
                             total_size += f.stat().st_size
                             total_deleted += 1

+ 111 - 246
auto_training/incremental_trainer.py

@@ -57,15 +57,18 @@ class IncrementalTrainer:
     2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑)
     """
 
-    def __init__(self, config_file: Path):
-        """
-        初始化训练器
-
-        Args:
-            config_file: auto_training.yaml 配置文件路径
-        """
-        self.config_file = config_file
-        self.config = self._load_config()
+    def __init__(self, config_file: Path = None, config: dict = None):
+        # 支持两种初始化方式:
+        # 1. 传 config dict(从数据库读取后直接传入,主程序使用)
+        # 2. 传 config_file YAML 路径(standalone_train.py 等独立工具使用)
+        if config is not None:
+            self.config = config
+            self.config_file = None
+        elif config_file is not None:
+            self.config_file = config_file
+            self.config = self._load_config()
+        else:
+            raise ValueError("必须提供 config_file 或 config 之一")
 
         # 路径配置
         self.deploy_root = Path(__file__).parent.parent
@@ -88,7 +91,7 @@ class IncrementalTrainer:
         self.cold_start_mode = False
 
     def _load_config(self) -> Dict:
-        # 加载配置文件
+        # 从 YAML 文件加载配置(仅 config_file 模式使用)
         with open(self.config_file, 'r', encoding='utf-8') as f:
             return yaml.safe_load(f)
 
@@ -192,30 +195,43 @@ class IncrementalTrainer:
             device_code = device_dir.name
             audio_files = []
 
-            # 冷启动模式:收集所有已归档日期目录的数据(跳过 current/)
+            # 冷启动模式:收集所有已归档日期目录的正常音频(跳过 current/)
             if self.cold_start_mode:
                 # 注意:跳过 current/ 目录,因其中可能包含 FFmpeg 正在写入的不完整文件
                 for sub_dir in device_dir.iterdir():
                     if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
+                        # 新结构:从 {date}/normal/ 子目录读取
+                        normal_dir = sub_dir / "normal"
+                        if normal_dir.exists():
+                            audio_files.extend(list(normal_dir.glob("*.wav")))
+                            audio_files.extend(list(normal_dir.glob("*.mp4")))
+                        # 兼容旧结构:日期目录下直接存放的音频文件
                         audio_files.extend(list(sub_dir.glob("*.wav")))
                         audio_files.extend(list(sub_dir.glob("*.mp4")))
             else:
-                # 正常模式:只收集指定日期的目录
+                # 正常模式:只收集指定日期的正常音频
                 date_dir = device_dir / target_date
                 if date_dir.exists():
+                    # 新结构:从 {date}/normal/ 子目录读取
+                    normal_dir = date_dir / "normal"
+                    if normal_dir.exists():
+                        audio_files.extend(list(normal_dir.glob("*.wav")))
+                        audio_files.extend(list(normal_dir.glob("*.mp4")))
+                    # 兼容旧结构:日期目录下直接存放的音频文件
                     audio_files.extend(list(date_dir.glob("*.wav")))
                     audio_files.extend(list(date_dir.glob("*.mp4")))
 
-            # 加上 verified_normal 目录
+            # 加上 verified_normal 目录(单独收集,不参与采样和质量预筛)
             verified_dir = device_dir / "verified_normal"
+            verified_files = []
             if verified_dir.exists():
-                audio_files.extend(list(verified_dir.glob("*.wav")))
-                audio_files.extend(list(verified_dir.glob("*.mp4")))
+                verified_files.extend(list(verified_dir.glob("*.wav")))
+                verified_files.extend(list(verified_dir.glob("*.mp4")))
 
-            # 去重
+            # 去重(仅日期目录音频)
             audio_files = list(set(audio_files))
 
-            # 数据质量预筛:过滤能量/频谱异常的音频
+            # 数据质量预筛:仅对日期目录音频过滤,verified_normal 已经人工确认,跳过
             if audio_files and not self.cold_start_mode:
                 before_count = len(audio_files)
                 audio_files = self._filter_audio_quality(audio_files, device_code)
@@ -223,7 +239,7 @@ class IncrementalTrainer:
                 if filtered > 0:
                     logger.info(f"  {device_code}: 质量预筛过滤 {filtered} 个异常音频")
 
-            # 随机采样(如果配置了采样时长
+            # 随机采样(仅对日期目录音频采样,verified_normal 不参与)
             if sample_hours > 0 and audio_files:
                 files_needed = int(sample_hours * 3600 / 60)
                 if len(audio_files) > files_needed:
@@ -234,6 +250,11 @@ class IncrementalTrainer:
             else:
                 logger.info(f"  {device_code}: {len(audio_files)} 个音频")
 
+            # 合并 verified_normal(采样后追加,保证全量参与训练)
+            if verified_files:
+                audio_files.extend(verified_files)
+                logger.info(f"  {device_code}: +{len(verified_files)} 个核查确认音频(verified_normal)")
+
             if audio_files:
                 device_files[device_code] = audio_files
 
@@ -302,8 +323,7 @@ class IncrementalTrainer:
     # ========================================
 
     def _extract_mel_for_device(self, device_code: str,
-                                wav_files: List[Path],
-                                inherit_scale: bool = False
+                                wav_files: List[Path]
                                 ) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
         """
         为单个设备提取 Mel 特征并计算独立的 Min-Max 标准化参数
@@ -315,7 +335,6 @@ class IncrementalTrainer:
         Args:
             device_code: 设备编码
             wav_files: 该设备的音频文件列表
-            inherit_scale: 增量训练时是否继承已部署的 scale 参数
 
         Returns:
             (mel_dir, (global_min, global_max)),失败返回 (None, None)
@@ -369,18 +388,6 @@ class IncrementalTrainer:
             logger.warning(f"  {device_code}: 无有效数据")
             return None, None
 
-        # 增量训练时,用已部署的 scale 做 EMA 平滑,避免剧烈偏移
-        if inherit_scale:
-            old_scale = self._load_deployed_scale(device_code)
-            if old_scale is not None:
-                ema_alpha = 0.3  # 新数据权重
-                old_min, old_max = old_scale
-                global_min = ema_alpha * global_min + (1 - ema_alpha) * old_min
-                global_max = ema_alpha * global_max + (1 - ema_alpha) * old_max
-                logger.info(f"  {device_code}: scale EMA 融合 | "
-                            f"old=[{old_min:.4f}, {old_max:.4f}] → "
-                            f"new=[{global_min:.4f}, {global_max:.4f}]")
-
         logger.info(f"  {device_code}: {patch_count} patches | "
                     f"min={global_min:.4f} max={global_max:.4f}")
 
@@ -396,20 +403,7 @@ class IncrementalTrainer:
 
         return device_mel_dir, (global_min, global_max)
 
-    def _load_deployed_scale(self, device_code: str) -> Optional[Tuple[float, float]]:
-        """加载已部署的 global_scale.npy,用于增量训练时的 scale 继承"""
-        scale_path = self.model_root / device_code / "global_scale.npy"
-        if not scale_path.exists():
-            return None
-        try:
-            scale = np.load(scale_path)
-            return float(scale[0]), float(scale[1])
-        except Exception as e:
-            logger.warning(f"加载旧 scale 失败: {device_code} | {e}")
-            return None
-
-    def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]],
-                                        inherit_scale: bool = False
+    def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]]
                                         ) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
         """
         为每个设备独立提取 Mel 特征
@@ -418,7 +412,6 @@ class IncrementalTrainer:
 
         Args:
             device_files: {device_code: [wav_files]}
-            inherit_scale: 增量训练时传 True,将新旧 scale 做 EMA 融合
 
         Returns:
             {device_code: (mel_dir, (global_min, global_max))}
@@ -433,9 +426,7 @@ class IncrementalTrainer:
         device_results = {}
 
         for device_code, wav_files in device_files.items():
-            mel_dir, scale = self._extract_mel_for_device(
-                device_code, wav_files, inherit_scale=inherit_scale
-            )
+            mel_dir, scale = self._extract_mel_for_device(device_code, wav_files)
             if mel_dir is not None:
                 device_results[device_code] = (mel_dir, scale)
 
@@ -451,8 +442,8 @@ class IncrementalTrainer:
     # ========================================
 
     def _select_training_device(self) -> torch.device:
-        # 智能选择训练设备:GPU/NPU 显存充足则使用,否则回退 CPU
-        # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda/npu)
+        # 智能选择训练设备:GPU 显存充足则使用,否则回退 CPU
+        # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda)
         training_cfg = self.config['auto_training']['incremental']
         forced_device = training_cfg.get('training_device', 'auto')
 
@@ -465,14 +456,8 @@ class IncrementalTrainer:
                 return torch.device('cuda')
             logger.warning("配置指定 CUDA 但不可用,回退 CPU")
             return torch.device('cpu')
-        if forced_device == 'npu':
-            if self._npu_available():
-                return torch.device('npu')
-            logger.warning("配置指定 NPU 但不可用,回退 CPU")
-            return torch.device('cpu')
 
-        # auto 模式:依次检测 CUDA → NPU → CPU
-        # 1. 检测 CUDA
+        # auto 模式:检测 CUDA → CPU
         if torch.cuda.is_available():
             try:
                 free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024)
@@ -486,32 +471,9 @@ class IncrementalTrainer:
             except Exception as e:
                 logger.warning(f"CUDA 显存检测失败: {e}")
 
-        # 2. 检测 NPU (华为昇腾)
-        if self._npu_available():
-            try:
-                free_mem = torch.npu.mem_get_info()[0] / (1024 * 1024)
-                min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
-                if free_mem >= min_gpu_mem_mb:
-                    logger.info(f"训练设备: NPU(空闲显存 {free_mem:.0f}MB)")
-                    return torch.device('npu')
-                logger.info(
-                    f"NPU 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
-                )
-            except Exception as e:
-                logger.warning(f"NPU 显存检测失败: {e},回退 CPU")
-
         logger.info("训练设备: CPU")
         return torch.device('cpu')
 
-    @staticmethod
-    def _npu_available() -> bool:
-        """检查华为昇腾 NPU 是否可用"""
-        try:
-            import torch_npu  # noqa: F401
-            return torch.npu.is_available()
-        except ImportError:
-            return False
-
     def _run_training_loop(self, device_code: str, model: nn.Module,
                            train_loader, val_loader, epochs: int, lr: float,
                            device: torch.device) -> Tuple[nn.Module, float]:
@@ -522,8 +484,8 @@ class IncrementalTrainer:
         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
         criterion = nn.MSELoss()
 
-        # AMP 混合精度(GPU/NPU 生效,减少约 40% 显存占用)
-        use_amp = device.type in ('cuda', 'npu')
+        # AMP 混合精度(GPU 生效,减少约 40% 显存占用)
+        use_amp = device.type == 'cuda'
         scaler = torch.amp.GradScaler(device.type) if use_amp else None
 
         # 早停配置
@@ -608,11 +570,9 @@ class IncrementalTrainer:
                            f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}")
                 break
 
-        # 训练后清理加速器缓存
+        # 训练后清理 GPU 缓存
         if device.type == 'cuda':
             torch.cuda.empty_cache()
-        elif device.type == 'npu':
-            torch.npu.empty_cache()
 
         if actual_epochs < epochs:
             logger.info(f"  [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练")
@@ -631,15 +591,11 @@ class IncrementalTrainer:
         # 智能选择训练设备
         device = self._select_training_device()
 
-        # 训练前清理加速器缓存,释放推理残留的显存碎片
+        # 训练前清理 GPU 缓存,释放推理残留的显存碎片
         if device.type == 'cuda':
             torch.cuda.empty_cache()
             import gc
             gc.collect()
-        elif device.type == 'npu':
-            torch.npu.empty_cache()
-            import gc
-            gc.collect()
 
         model = ConvAutoencoder()
 
@@ -684,26 +640,22 @@ class IncrementalTrainer:
             logger.info(f"  数据量不足20,跳过验证集划分(共{len(dataset)}样本)")
 
         # 尝试在选定设备上训练
-        if device.type in ('cuda', 'npu'):
+        if device.type == 'cuda':
             try:
                 return self._run_training_loop(
                     device_code, model, train_loader, val_loader,
                     epochs, lr, device
                 )
             except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
-                # GPU/NPU OOM -> 清理显存后回退 CPU 重试
+                # GPU OOM -> 清理显存后回退 CPU 重试
                 if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError):
                     raise  # 非 OOM 的 RuntimeError 不拦截
                 logger.warning(
-                    f"  [{device_code}] {device.type.upper()} OOM,"
-                    f"清理显存后回退 CPU 训练"
+                    f"  [{device_code}] CUDA OOM,清理显存后回退 CPU 训练"
                 )
                 import gc
                 gc.collect()
-                if device.type == 'cuda':
-                    torch.cuda.empty_cache()
-                elif device.type == 'npu':
-                    torch.npu.empty_cache()
+                torch.cuda.empty_cache()
                 # 模型可能处于脏状态,重新初始化
                 model = ConvAutoencoder()
                 if not from_scratch:
@@ -812,6 +764,31 @@ class IncrementalTrainer:
 
         return threshold
 
+    def _eval_model_error(self, model: nn.Module, mel_dir: Path) -> float:
+        """在验证数据上计算模型的平均重建误差,用于新旧模型对比"""
+        device = next(model.parameters()).device
+        model.eval()
+
+        dataset = MelNPYDataset(mel_dir)
+        if len(dataset) == 0:
+            return float('inf')
+
+        dataloader = torch.utils.data.DataLoader(
+            dataset, batch_size=64, shuffle=False
+        )
+
+        all_errors = []
+        with torch.no_grad():
+            for batch in dataloader:
+                batch = batch.to(device)
+                output = model(batch)
+                output = align_to_target(output, batch)
+                mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
+                all_errors.append(mse.cpu().numpy())
+
+        errors = np.concatenate(all_errors)
+        return float(np.mean(errors))
+
     # ========================================
     # 全量训练入口
     # ========================================
@@ -954,7 +931,6 @@ class IncrementalTrainer:
                   else self.config['auto_training']['incremental']['learning_rate'])
 
             from_scratch = self.cold_start_mode
-            inherit_scale = not self.cold_start_mode
 
             model_cfg = self.config['auto_training']['model']
             rollback_enabled = model_cfg.get('rollback_on_degradation', True)
@@ -973,7 +949,7 @@ class IncrementalTrainer:
                 try:
                     # ── 4a. 单设备特征提取 ──
                     mel_dir, scale = self._extract_mel_for_device(
-                        device_code, wav_files, inherit_scale=inherit_scale
+                        device_code, wav_files
                     )
                     if mel_dir is None:
                         logger.warning(f"{device_code}: 特征提取无有效数据,跳过")
@@ -989,52 +965,49 @@ class IncrementalTrainer:
                         logger.error(f"{device_code}: 形状验证失败,跳过部署")
                         continue
 
-                    # ── 4d. 损失退化检测(增量训练时生效) ──
+                    # ── 4d. 新旧模型对比(增量训练时生效) ──
+                    # 在相同验证数据上比较新旧模型的重建误差,新模型更差则跳过部署
                     if rollback_enabled and not self.cold_start_mode:
-                        old_threshold = self._get_old_threshold(device_code)
-                        if old_threshold and old_threshold > 0:
-                            if final_loss > old_threshold * rollback_factor:
+                        old_model_path = self.model_root / device_code / "ae_model.pth"
+                        if old_model_path.exists():
+                            new_avg_err = self._eval_model_error(model, mel_dir)
+                            old_model = ConvAutoencoder()
+                            old_model.load_state_dict(
+                                torch.load(old_model_path, map_location='cpu')
+                            )
+                            old_avg_err = self._eval_model_error(old_model, mel_dir)
+
+                            logger.info(
+                                f"  {device_code}: 新旧模型对比 | "
+                                f"旧模型误差={old_avg_err:.6f} 新模型误差={new_avg_err:.6f}"
+                            )
+
+                            if new_avg_err > old_avg_err * rollback_factor:
                                 logger.warning(
-                                    f"{device_code}: 损失退化检测触发 | "
-                                    f"训练损失={final_loss:.6f} > "
-                                    f"旧阈值={old_threshold:.6f} × {rollback_factor} = "
-                                    f"{old_threshold * rollback_factor:.6f}"
+                                    f"{device_code}: 新模型退化 | "
+                                    f"新={new_avg_err:.6f} > 旧={old_avg_err:.6f} × {rollback_factor},跳过部署"
                                 )
                                 degraded_count += 1
                                 continue
 
-                    # ── 4e. 阈值偏移检测 + 部署 ──
+                    # ── 4e. 部署 ──
                     if model_cfg.get('auto_deploy', True):
-                        if rollback_enabled and not self.cold_start_mode:
-                            old_threshold = self._get_old_threshold(device_code)
-                            if old_threshold and old_threshold > 0:
-                                new_threshold = self._compute_threshold(model, mel_dir)
-                                drift_ratio = abs(new_threshold - old_threshold) / old_threshold
-                                # 记录阈值变化趋势(用于长期漂移监控)
-                                self._log_threshold_history(
-                                    device_code, target_date,
-                                    old_threshold, new_threshold, final_loss
-                                )
-                                if drift_ratio > 0.3:
-                                    logger.warning(
-                                        f"{device_code}: 阈值偏移告警 | "
-                                        f"旧={old_threshold:.6f} → "
-                                        f"新={new_threshold:.6f} | "
-                                        f"偏移={drift_ratio:.1%}"
-                                    )
-                                    if drift_ratio > 1.0:
-                                        logger.warning(
-                                            f"{device_code}: 阈值偏移过大"
-                                            f"(>{drift_ratio:.0%}),跳过部署"
-                                        )
-                                        degraded_count += 1
-                                        continue
                         self.deploy_device_model(device_code, model, scale, mel_dir)
 
                     success_count += 1
                     logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}")
 
-                    # ── 4f. 即时通知该设备模型重载 ──
+                    # ── 4f. 清理已参与训练的 verified_normal 目录 ──
+                    # 核查确认的音频已被模型吸收,训练后清空释放磁盘空间
+                    verified_dir = self.audio_root / device_code / "verified_normal"
+                    if verified_dir.exists():
+                        v_count = len(list(verified_dir.glob("*")))
+                        if v_count > 0:
+                            shutil.rmtree(verified_dir)
+                            verified_dir.mkdir(parents=True, exist_ok=True)
+                            logger.info(f"{device_code}: 已清理 verified_normal ({v_count} 个文件)")
+
+                    # ── 4g. 即时通知该设备模型重载 ──
                     if on_device_trained:
                         try:
                             on_device_trained(device_code)
@@ -1045,7 +1018,7 @@ class IncrementalTrainer:
                     logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True)
 
                 finally:
-                    # ── 4g. 清理该设备的临时 Mel 文件,释放磁盘空间 ──
+                    # ── 4h. 清理该设备的临时 Mel 文件,释放磁盘空间 ──
                     device_mel_dir = self.temp_mel_dir / device_code
                     if device_mel_dir.exists():
                         shutil.rmtree(device_mel_dir)
@@ -1064,9 +1037,6 @@ class IncrementalTrainer:
                     f"{success_count} 个设备部署成功"
                 )
 
-            # 6. 更新分类器基线
-            self._update_classifier_baseline(device_files)
-
             logger.info("=" * 60)
             logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功")
             if degraded_count > 0:
@@ -1086,60 +1056,6 @@ class IncrementalTrainer:
     # 辅助方法
     # ========================================
 
-    def _get_old_threshold(self, device_code: str) -> float:
-        """
-        读取设备当前已部署的阈值(训练前的旧阈值)
-
-        用于损失退化校验:新模型的训练损失不应远超旧阈值。
-        阈值文件路径: models/{device_code}/thresholds/threshold_{device_code}.npy
-
-        返回:
-            阈值浮点数,文件不存在时返回 0.0
-        """
-        threshold_file = self.model_root / device_code / "thresholds" / f"threshold_{device_code}.npy"
-        if not threshold_file.exists():
-            return 0.0
-        try:
-            data = np.load(threshold_file)
-            return float(data.flat[0])
-        except Exception as e:
-            logger.warning(f"读取旧阈值失败: {device_code} | {e}")
-            return 0.0
-
-    def _log_threshold_history(self, device_code: str, date_str: str,
-                                old_threshold: float, new_threshold: float,
-                                train_loss: float):
-        """
-        记录阈值变化历史到 CSV,用于监控模型长期漂移趋势
-
-        文件路径: logs/threshold_history.csv
-        格式: date,device_code,old_threshold,new_threshold,drift_ratio,train_loss
-        """
-        import csv
-
-        log_dir = self.deploy_root / "logs"
-        log_dir.mkdir(parents=True, exist_ok=True)
-        csv_path = log_dir / "threshold_history.csv"
-
-        drift_ratio = (new_threshold - old_threshold) / old_threshold if old_threshold > 0 else 0.0
-        write_header = not csv_path.exists()
-
-        try:
-            with open(csv_path, 'a', newline='', encoding='utf-8') as f:
-                writer = csv.writer(f)
-                if write_header:
-                    writer.writerow([
-                        'date', 'device_code', 'old_threshold', 'new_threshold',
-                        'drift_ratio', 'train_loss'
-                    ])
-                writer.writerow([
-                    date_str, device_code,
-                    f"{old_threshold:.8f}", f"{new_threshold:.8f}",
-                    f"{drift_ratio:.4f}", f"{train_loss:.8f}"
-                ])
-        except Exception as e:
-            logger.warning(f"写入阈值历史失败: {e}")
-
     def _validate_model(self, model: nn.Module) -> bool:
         # 验证模型输出形状是否合理
         if not self.config['auto_training']['validation']['enabled']:
@@ -1235,57 +1151,6 @@ class IncrementalTrainer:
         logger.info(f"恢复完成: {restored} 个设备")
         return restored > 0
 
-    def _update_classifier_baseline(self, device_files: Dict[str, List[Path]]):
-        # 从训练数据计算并更新分类器基线
-        logger.info("更新分类器基线")
-
-        try:
-            import librosa
-            from core.anomaly_classifier import AnomalyClassifier
-
-            classifier = AnomalyClassifier()
-
-            all_files = []
-            for files in device_files.values():
-                all_files.extend(files)
-
-            if not all_files:
-                logger.warning("无音频文件,跳过基线更新")
-                return
-
-            sample_files = random.sample(all_files, min(50, len(all_files)))
-
-            all_features = []
-            for wav_file in sample_files:
-                try:
-                    y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
-                    if len(y) < CFG.SR:
-                        continue
-                    features = classifier.extract_features(y, sr=CFG.SR)
-                    if features:
-                        all_features.append(features)
-                except Exception:
-                    continue
-
-            if not all_features:
-                logger.warning("无法提取特征,跳过基线更新")
-                return
-
-            baseline = {}
-            keys = all_features[0].keys()
-            for key in keys:
-                if key == 'has_periodic':
-                    values = [f[key] for f in all_features]
-                    baseline[key] = sum(values) > len(values) / 2
-                else:
-                    values = [f[key] for f in all_features]
-                    baseline[key] = float(np.mean(values))
-
-            classifier.save_baseline(baseline)
-            logger.info(f"  基线已更新 (样本数: {len(all_features)})")
-
-        except Exception as e:
-            logger.warning(f"更新基线失败: {e}")
 
 
 def main():

+ 2 - 2
config/auto_training.yaml

@@ -25,8 +25,8 @@ auto_training:
     learning_rate: 0.0001       # 学习率
     batch_size: 32              # 批大小(降低显存占用)
     early_stop_patience: 5      # 早停耐心值:连续N轮loss无改善则停止
-    training_device: cpu           # 训练设备选择:auto(自动检测显存)/cpu/cuda/npu
-                                    # 低配服务器推荐 cpu,模型小(~214KB) CPU训练30epoch耗时可接受
+    training_device: cpu           # 训练设备选择:auto(自动检测GPU显存)/cpu/cuda
+                                    # 低配服务器推荐 cpu,模型小(~192KB) CPU训练30epoch耗时可接受
     min_gpu_mem_mb: 512          # auto模式下,GPU空闲显存低于此值(MB)时回退CPU
     
   # 模型管理

+ 1 - 1
config/config_manager.py

@@ -49,7 +49,7 @@ class ConfigManager:
         config['plants'] = self._build_plants_list()
 
         # 2. 组装系统级配置(audio, prediction, push_notification, scada_api, human_detection)
-        for section in ['audio', 'prediction', 'push_notification', 'scada_api', 'human_detection']:
+        for section in ['audio', 'prediction', 'push_notification', 'scada_api', 'human_detection', 'auto_training']:
             config[section] = self._get_section_config(section)
 
         return config

BIN
config/pickup_config.db


BIN
config/pickup_config.db-shm


+ 0 - 0
config/pickup_config.db-wal


BIN
config/yaml_backup/db_output/pickup_config_anzhen.db


BIN
config/yaml_backup/db_output/pickup_config_jianding.db


BIN
config/yaml_backup/db_output/pickup_config_longting.db


BIN
config/yaml_backup/db_output/pickup_config_longting.db-shm


+ 0 - 0
config/yaml_backup/db_output/pickup_config_longting.db-wal


BIN
config/yaml_backup/db_output/pickup_config_xishan.db


BIN
config/yaml_backup/db_output/pickup_config_xishan.db-shm


+ 0 - 0
config/yaml_backup/db_output/pickup_config_xishan.db-wal


BIN
config/yaml_backup/db_output/pickup_config_yancheng.db


+ 12 - 1
config/yaml_backup/rtsp_config_anzhen.yaml

@@ -83,10 +83,21 @@ push_notification:
     window_seconds: 300
     min_devices: 2
 
+
+  # ----------------------------------------------------------
+  # 项目模式调度(参观/检修/调试模式下自动暂停异响检测)
+  # ----------------------------------------------------------
+  project_mode:
+    base_url: http://120.55.44.4:8900    # 平台 API 根地址
+    poll_interval: 60                     # 查询间隔(秒)
+    request_timeout: 10                   # 请求超时(秒)
+
+  # ----------
+
 scada_api:
   enabled: true
   base_url: http://120.55.44.4:8900/api/v1/jinke-cloud/db/device/history-data
-  realtime_url: http://47.96.12.136:8788/api/v1/jinke-cloud/device/current-data
+  realtime_url: http://120.55.44.4:8900/api/v1/jinke-cloud/device/current-data
   jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJJRCI6NywiVXNlcm5hbWUiOiJhZG1pbiIsIkRlcCI6IjEzNSIsImV4cCI6MTc3NjExOTExNCwiaXNzIjoiZ2luLWJsb2cifQ.0HTtzHZjyd2mHo8VCy8icYROxmntRMuQhyoZsAYRL_M
   timeout: 10
 

+ 1 - 1
config/yaml_backup/rtsp_config_jianding.yaml

@@ -82,7 +82,7 @@ push_notification:
 scada_api:
   enabled: true
   base_url: http://120.55.44.4:8900/api/v1/jinke-cloud/db/device/history-data
-  realtime_url: http://47.96.12.136:8788/api/v1/jinke-cloud/device/current-data
+  realtime_url: http://120.55.44.4:8900/api/v1/jinke-cloud/device/current-data
   jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJJRCI6NywiVXNlcm5hbWUiOiJhZG1pbiIsIkRlcCI6IjEzNSIsImV4cCI6MTc3NjExOTExNCwiaXNzIjoiZ2luLWJsb2cifQ.0HTtzHZjyd2mHo8VCy8icYROxmntRMuQhyoZsAYRL_M
   timeout: 10
 

+ 44 - 1
config/yaml_backup/rtsp_config_longting.yaml

@@ -105,13 +105,21 @@ push_notification:
     window_seconds: 300
     min_devices: 2
 
+# ----------------------------------------------------------
+# 项目模式调度(参观/检修/调试模式下自动暂停异响检测)
+# ----------------------------------------------------------
+project_mode:
+  base_url: http://120.55.44.4:8900    # 平台 API 根地址
+  poll_interval: 60                     # 查询间隔(秒)
+  request_timeout: 10                   # 请求超时(秒)
+
 # ----------------------------------------------------------
 # SCADA/PLC 接口
 # ----------------------------------------------------------
 scada_api:
   enabled: true
   base_url: http://120.55.44.4:8900/api/v1/jinke-cloud/db/device/history-data
-  realtime_url: http://47.96.12.136:8788/api/v1/jinke-cloud/device/current-data
+  realtime_url: http://120.55.44.4:8900/api/v1/jinke-cloud/device/current-data
   jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJJRCI6NywiVXNlcm5hbWUiOiJhZG1pbiIsIkRlcCI6IjEzNSIsImV4cCI6MTc3NjExOTExNCwiaXNzIjoiZ2luLWJsb2cifQ.0HTtzHZjyd2mHo8VCy8icYROxmntRMuQhyoZsAYRL_M
   timeout: 10
 
@@ -122,3 +130,38 @@ human_detection:
   enabled: false
   db_path: /data/human_detector/detection_status.db
   cooldown_minutes: 5
+
+# ----------------------------------------------------------
+# 自动增量训练
+# ----------------------------------------------------------
+auto_training:
+  enabled: true                      # 总开关(启用自动增训)
+  data:
+    keep_normal_days: 7               # 正常音频保留天数
+    keep_anomaly_days: -1             # 异常音频保留天数(-1=永久)
+    cleanup_time: "00:00"             # 每日清理时间(0点)
+  incremental:
+    enabled: true
+    schedule_time: "02:00"            # 每日训练时间
+    use_days_ago: 1                   # 使用N天前的数据(1=昨天)
+    sample_hours: 1                   # 随机采样时长(小时),0=使用全部
+    min_samples: 50                   # 最少样本数,不足则跳过
+    epochs: 30                        # 训练轮数(配合早停,实际通常更少)
+    learning_rate: 0.0001             # 学习率
+    batch_size: 32                    # 批大小(降低显存占用)
+    early_stop_patience: 5            # 早停耐心值:连续N轮loss无改善则停止
+    training_device: auto
+    min_gpu_mem_mb: 512               # auto模式下GPU空闲显存低于此值(MB)时回退CPU
+  model:
+    backup_before_train: true         # 训练前备份
+    keep_backups: 7                   # 保留备份数量
+    auto_deploy: true                 # 自动部署新模型
+    update_thresholds: true           # 训练后更新阈值npy
+    rollback_on_degradation: true     # 训练后损失异常时自动回滚到备份
+    rollback_factor: 2.0              # 新模型损失 > 旧阈值 * 此因子则判定为退化
+  validation:
+    enabled: true
+  cold_start:
+    enabled: true
+    wait_hours: 2                     # 等待收集数据时长
+    min_samples: 100                  # 最少样本数

+ 46 - 1
config/yaml_backup/rtsp_config_xishan.yaml

@@ -121,13 +121,21 @@ push_notification:
     window_seconds: 300              # 聚合窗口(秒)
     min_devices: 2                   # 至少 2 台设备同时异常才触发聚合告警
 
+# ----------------------------------------------------------
+# 项目模式调度(参观/检修/调试模式下自动暂停异响检测)
+# ----------------------------------------------------------
+project_mode:
+  base_url: http://120.55.44.4:8900    # 平台 API 根地址
+  poll_interval: 60                     # 查询间隔(秒)
+  request_timeout: 10                   # 请求超时(秒)
+
 # ----------------------------------------------------------
 # SCADA/PLC 接口(泵状态查询)
 # ----------------------------------------------------------
 scada_api:
   enabled: true                      # 是否启用 PLC 查询(false 时用音频能量判断启停)
   base_url: http://120.55.44.4:8900/api/v1/jinke-cloud/db/device/history-data    # 历史数据接口
-  realtime_url: http://47.96.12.136:8788/api/v1/jinke-cloud/device/current-data  # 实时数据接口
+  realtime_url: http://120.55.44.4:8900/api/v1/jinke-cloud/device/current-data  # 实时数据接口
   jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJJRCI6NywiVXNlcm5hbWUiOiJhZG1pbiIsIkRlcCI6IjEzNSIsImV4cCI6MTc3NjExOTExNCwiaXNzIjoiZ2luLWJsb2cifQ.0HTtzHZjyd2mHo8VCy8icYROxmntRMuQhyoZsAYRL_M
   timeout: 10                        # 查询超时(秒)
 
@@ -138,3 +146,40 @@ human_detection:
   enabled: false                     # 是否启用(需要独立的人体检测服务)
   db_path: /data/human_detector/detection_status.db  # 人体检测状态 DB 路径
   cooldown_minutes: 5                # 检测到有人后抑制告警的时间(分钟)
+
+
+# ----------------------------------------------------------
+# 自动增量训练
+# ----------------------------------------------------------
+auto_training:
+  enabled: true                      # 总开关(启用自动增训)
+  data:
+    keep_normal_days: 7               # 正常音频保留天数
+    keep_anomaly_days: -1             # 异常音频保留天数(-1=永久)
+    cleanup_time: "00:00"             # 每日清理时间(0点)
+  incremental:
+    enabled: true
+    schedule_time: "18:00"            # 每日训练时间
+    use_days_ago: 1                   # 使用N天前的数据(1=昨天)
+    sample_hours: 1                   # 随机采样时长(小时),0=使用全部
+    min_samples: 50                   # 最少样本数,不足则跳过
+    epochs: 30                        # 训练轮数(配合早停,实际通常更少)
+    learning_rate: 0.0001             # 学习率
+    batch_size: 32                    # 批大小(降低显存占用)
+    early_stop_patience: 5            # 早停耐心值:连续N轮loss无改善则停止
+    training_device: auto
+    min_gpu_mem_mb: 512               # auto模式下GPU空闲显存低于此值(MB)时回退CPU
+  model:
+    backup_before_train: true         # 训练前备份
+    keep_backups: 7                   # 保留备份数量
+    auto_deploy: true                 # 自动部署新模型
+    update_thresholds: true           # 训练后更新阈值npy
+    rollback_on_degradation: true     # 训练后损失异常时自动回滚到备份
+    rollback_factor: 2.0              # 新模型损失 > 旧阈值 * 此因子则判定为退化
+  validation:
+    enabled: true
+  cold_start:
+    enabled: true
+    wait_hours: 2                     # 等待收集数据时长
+    min_samples: 100                  # 最少样本数
+

+ 1 - 1
config/yaml_backup/rtsp_config_yancheng.yaml

@@ -86,7 +86,7 @@ push_notification:
 scada_api:
   enabled: true
   base_url: http://120.55.44.4:8900/api/v1/jinke-cloud/db/device/history-data
-  realtime_url: http://47.96.12.136:8788/api/v1/jinke-cloud/device/current-data
+  realtime_url: http://120.55.44.4:8900/api/v1/jinke-cloud/device/current-data
   jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJJRCI6NywiVXNlcm5hbWUiOiJhZG1pbiIsIkRlcCI6IjEzNSIsImV4cCI6MTc3NjExOTExNCwiaXNzIjoiZ2luLWJsb2cifQ.0HTtzHZjyd2mHo8VCy8icYROxmntRMuQhyoZsAYRL_M
   timeout: 10
 

+ 1 - 1
core/pump_state_monitor.py

@@ -11,7 +11,7 @@ pump_state_monitor.py - 泵状态监控模块
 from pump_state_monitor import PumpStateMonitor
 
 monitor = PumpStateMonitor(
-    scada_url="http://47.96.12.136:8788/api/v1/jinke-cloud/device/current-data",
+    scada_url="http://120.55.44.4:8900/api/v1/jinke-cloud/device/current-data",
     scada_jwt="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJJRCI6NywiVXNlcm5hbWUiOiJhZG1pbiIsIkRlcCI6IjEzNSIsImV4cCI6MTc3NjExOTExNCwiaXNzIjoiZ2luLWJsb2cifQ.0HTtzHZjyd2mHo8VCy8icYROxmntRMuQhyoZsAYRL_M",
     project_id=92,
     transition_window_minutes=15

BIN
models/LT-2/ae_model.pth


BIN
models/LT-2/global_scale.npy


BIN
models/LT-2/thresholds/threshold_default.npy


BIN
models/LT-5/ae_model.pth


BIN
models/LT-5/global_scale.npy


BIN
models/LT-5/thresholds/threshold_default.npy


+ 31 - 2
predictor/multi_model_predictor.py

@@ -24,6 +24,10 @@ from .config import CFG, DeployConfig
 from .model_def import ConvAutoencoder
 from .utils import get_device
 
+# --- BM1684X NPU 推理适配 ---
+# 运行时自动检测:仅当设备目录存在 .bmodel 且 BM1684X 硬件可用时才走 NPU(需先用 convert_to_bmodel.py 生成 .bmodel)
+from .bm1684x_engine import BM1684XEngine, is_bm1684x_available
+
 logger = logging.getLogger('MultiModelPredictor')
 
 
@@ -52,11 +56,18 @@ class DevicePredictor:
         # 阈值(标量)
         self.threshold = self._load_threshold()
 
+        # --- BM1684X NPU 推理适配(运行时自动检测,不可用时回退 PyTorch) ---
+        # 如果 .bmodel 文件存在且 BM1684X 硬件可用,优先使用 NPU 推理
+        # 启用后 self.bm_engine 不为 None,推理时调用 bm_engine.predict() 而非 self.model()
+        self.bm_engine = None
+        self.bm_engine = self._load_bmodel()
+
         # 记录文件 mtime(用于热加载检测)
         self._model_mtime = self._get_mtime(self.model_path)
         self._scale_mtime = self._get_mtime(self.scale_path)
 
-        logger.info(f"设备 {device_code} 模型加载完成 | 目录: {model_subdir} | "
+        engine_type = "BM1684X BModel" if self.bm_engine else "PyTorch (.pth)"
+        logger.info(f"设备 {device_code} 模型加载完成 | 引擎: {engine_type} | 目录: {model_subdir} | "
                     f"阈值: {self.threshold:.6f}")
 
     def _get_mtime(self, path: Path) -> float:
@@ -65,7 +76,7 @@ class DevicePredictor:
             return os.path.getmtime(path)
         except OSError:
             return 0.0
-
+
     def has_files_changed(self) -> bool:
         # 检查模型或标准化参数文件是否有更新
         new_model_mtime = self._get_mtime(self.model_path)
@@ -83,6 +94,24 @@ class DevicePredictor:
         model.eval()
         return model
 
+    # --- BM1684X NPU 推理适配 ---
+    # 存在 .bmodel 且硬件可用时返回 NPU 引擎,否则返回 None 回退 PyTorch
+    def _load_bmodel(self):
+        """尝试加载 BModel,成功返回 BM1684XEngine 实例,否则返回 None"""
+        bmodel_path = self.model_dir / "ae_model.bmodel"
+        if not bmodel_path.exists():
+            return None
+        if not is_bm1684x_available():
+            logger.info(f"设备 {self.device_code}: 存在 .bmodel 但 BM1684X 不可用,回退 PyTorch")
+            return None
+        try:
+            engine = BM1684XEngine(str(bmodel_path))
+            logger.info(f"设备 {self.device_code}: 已加载 BM1684X BModel 推理引擎")
+            return engine
+        except Exception as e:
+            logger.warning(f"设备 {self.device_code}: BModel 加载失败,回退 PyTorch | {e}")
+            return None
+
     def _load_scale(self) -> Tuple[float, float]:
         # 加载 Min-Max 标准化参数 [min, max]
         if not self.scale_path.exists():

+ 21 - 10
predictor/utils.py

@@ -32,25 +32,36 @@ def ensure_dirs():
 
 def get_device():
     """
-    获取可用的计算设备
+    获取可用的 PyTorch 计算设备
 
-    优先级: CUDA > NPU (华为昇腾) > CPU
+    优先级: CUDA > CPU
+
+    注意: BM1684X (算能) 不走此函数,它不是 PyTorch 后端。
+    BM1684X 推理由 bm1684x_engine.py 中的 sophon.sail 独立处理。
 
     返回:
-        str: "cuda", "npu" 或 "cpu"
+        str: "cuda" 或 "cpu"
     """
     if torch.cuda.is_available():
         return "cuda"
-    # 华为昇腾 NPU
-    try:
-        import torch_npu  # noqa: F401
-        if torch.npu.is_available():
-            return "npu"
-    except ImportError:
-        pass
     return "cpu"
 
 
+# --- BM1684X NPU 推理适配 ---
+# BM1684X 不是 PyTorch 后端,不能通过 get_device() 返回。
+# 以下为硬件可用性检测函数的参考实现(仅注释保留);
+# 实际实现位于 bm1684x_engine.py,multi_model_predictor 已从该模块导入 is_bm1684x_available。
+#
+# def is_bm1684x_available() -> bool:
+#     """检测 BM1684X 硬件是否可用(SDK + 设备节点)"""
+#     try:
+#         import sophon.sail  # noqa: F401
+#     except ImportError:
+#         return False
+#     import glob
+#     return len(glob.glob("/dev/bm-sophon*")) > 0
+
+
 def align_to_target(pred, target):
     """
     将预测tensor对齐到目标tensor的尺寸

+ 1 - 0
requirements.txt

@@ -8,6 +8,7 @@
 
 # 深度学习
 torch>=2.0.0              # 模型推理
+onnxruntime>=1.14.0       # ONNX 模型验证 / BModel 转换前精度检查
 
 # 数值计算
 numpy>=1.23.0,<2.0        # PyTorch 2.x 不兼容 NumPy 2.x

+ 37 - 20
run_with_auto_training.py

@@ -27,12 +27,13 @@ sys.path.insert(0, str(Path(__file__).parent))
 try:
     from apscheduler.schedulers.background import BackgroundScheduler
     from apscheduler.triggers.cron import CronTrigger
-    import yaml
 except ImportError:
     print("错误:缺少依赖库")
-    print("请运行:pip install apscheduler pyyaml")
+    print("请运行:pip install apscheduler")
     sys.exit(1)
 
+from config.config_manager import ConfigManager
+
 
 def setup_logging():
     # 配置日志系统(按文件大小轮转),与 run_pickup_monitor.py 共用同一日志文件
@@ -90,6 +91,7 @@ class ColdStartManager:
         self.model_root = deploy_root / "models"
         # 音频数据根目录
         self.audio_root = deploy_root / "data" / "audio"
+        # self.audio_root = "/Volumes/mo/水厂正常音频/龙亭"
 
         # 冷启动配置
         cold_start_cfg = config.get('auto_training', {}).get('cold_start', {})
@@ -175,6 +177,11 @@ class ColdStartManager:
             total_samples = 0
             for sub_dir in device_dir.iterdir():
                 if sub_dir.is_dir():
+                    # 新结构:{date}/normal/ 子目录
+                    normal_dir = sub_dir / "normal"
+                    if normal_dir.exists():
+                        total_samples += len(list(normal_dir.glob("*.wav")))
+                    # 兼容旧结构 + current 目录:直接存放的 wav
                     total_samples += len(list(sub_dir.glob("*.wav")))
 
             if total_samples < self.min_samples:
@@ -203,8 +210,8 @@ class ColdStartManager:
         try:
             from auto_training.incremental_trainer import IncrementalTrainer
 
-            config_file = self.deploy_root / "config" / "auto_training.yaml"
-            trainer = IncrementalTrainer(config_file)
+            # 从当前内存中的配置 dict 初始化训练器(配置来源为数据库)
+            trainer = IncrementalTrainer(config=self.config)
 
             # 冷启动模式:收集所有目录的数据,用全量训练
             trainer.cold_start_mode = True
@@ -243,10 +250,27 @@ class IntegratedSystem:
 
     def __init__(self):
         self.deploy_root = Path(__file__).parent
-        self.auto_config_file = self.deploy_root / "config" / "auto_training.yaml"
 
-        # 加载自动训练配置
-        self.auto_config = self._load_yaml(self.auto_config_file)
+        # =========================================================================
+        # 配置加载来源:自动检测
+        #   优先使用 YAML(config/rtsp_config.yaml 存在时)
+        #   否则使用 SQLite 数据库
+        # =========================================================================
+        yaml_path = self.deploy_root / "config" / "rtsp_config.yaml"
+
+        if yaml_path.exists():
+            import yaml
+            with open(yaml_path, 'r', encoding='utf-8') as f:
+                full_config = yaml.safe_load(f)
+            self.full_yaml_config = full_config
+            self.auto_config = {'auto_training': full_config.get('auto_training', {})}
+            logger.info(f"已从 YAML ({yaml_path.name}) 加载配置")
+        else:
+            self.full_yaml_config = None
+            mgr = ConfigManager()
+            self.auto_config = {'auto_training': mgr.get_system_config('auto_training')}
+            mgr.close()
+            logger.info(f"已从数据库加载 auto_training 配置 ({len(self.auto_config.get('auto_training', {}))} 项)")
 
         # 运行时对象
         self.scheduler = None
@@ -254,15 +278,6 @@ class IntegratedSystem:
         self.cold_start_manager = None
         self.cold_start_thread = None
 
-    def _load_yaml(self, config_file: Path) -> dict:
-        # 加载 YAML 配置文件,不存在时返回空字典
-        if not config_file.exists():
-            logger.warning(f"配置文件不存在: {config_file}")
-            return {}
-
-        with open(config_file, 'r', encoding='utf-8') as f:
-            return yaml.safe_load(f) or {}
-
     def _check_and_handle_cold_start(self) -> bool:
         """
         检查并处理冷启动
@@ -406,7 +421,8 @@ class IntegratedSystem:
             logger.info("定时任务触发:增量训练开始")
 
             from auto_training.incremental_trainer import IncrementalTrainer
-            trainer = IncrementalTrainer(self.auto_config_file)
+            # 传 config dict 而非 YAML 路径,配置来源为数据库
+            trainer = IncrementalTrainer(config=self.auto_config)
             success = trainer.run_daily_training(
                 on_device_trained=self._reload_single_device
             )
@@ -425,7 +441,8 @@ class IntegratedSystem:
             logger.info("定时任务触发:数据清理开始")
             from auto_training.data_cleanup import DataCleaner
 
-            cleaner = DataCleaner(self.auto_config_file)
+            # 传 config dict 而非 YAML 路径,配置来源为数据库
+            cleaner = DataCleaner(config=self.auto_config)
             cleaner.run_cleanup()
         except Exception as e:
             logger.error(f"数据清理异常: {e}", exc_info=True)
@@ -439,7 +456,7 @@ class IntegratedSystem:
         # 1. 创建 PickupMonitoringSystem(会初始化 multi_predictor + 注册设备)
         logger.info("初始化监控系统...")
         from run_pickup_monitor import PickupMonitoringSystem
-        self.pickup_system = PickupMonitoringSystem()
+        self.pickup_system = PickupMonitoringSystem(yaml_config=self.full_yaml_config)
 
         # 2. 检查冷启动(需要在 pickup_system 初始化之后,因为需要设备注册信息)
         is_cold_start = self._check_and_handle_cold_start()
@@ -447,7 +464,7 @@ class IntegratedSystem:
         # 3. 设置定时任务
         self._setup_auto_training_tasks()
 
-        # 4. 覆盖信号处理(确保优雅关闭 scheduler)
+        # 4. 覆盖信号处理(确保关闭 scheduler)
         signal.signal(signal.SIGINT, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
 

+ 104 - 21
start.sh

@@ -21,6 +21,24 @@ cd "$(dirname "$0")"
 
 # PID文件路径
 PID_FILE="logs/pid.txt"
+STARTUP_TIMEOUT=5
+HEALTH_CHECK_INTERVAL=1
+
+# ========================================
+# 函数:按PID精确清理PID文件
+# ========================================
+cleanup_pid_file_if_matches() {
+    local expected_pid="$1"
+    if [ ! -f "$PID_FILE" ]; then
+        return 0
+    fi
+
+    local current_pid
+    current_pid=$(cat "$PID_FILE" 2>/dev/null)
+    if [ -z "$expected_pid" ] || [ "$current_pid" = "$expected_pid" ]; then
+        rm -f "$PID_FILE"
+    fi
+}
 
 # ========================================
 # 函数:激活conda环境
@@ -34,16 +52,42 @@ activate_conda() {
     fi
 }
 
+# ========================================
+# 函数:检查PID是否为当前服务进程
+# ========================================
+is_expected_process() {
+    local pid="$1"
+    if [ -z "$pid" ]; then
+        return 1
+    fi
+
+    if ! ps -p "$pid" > /dev/null 2>&1; then
+        return 1
+    fi
+
+    local command
+    command=$(ps -p "$pid" -o command= 2>/dev/null)
+    case "$command" in
+        *"run_with_auto_training.py"*)
+            return 0
+            ;;
+        *)
+            return 1
+            ;;
+    esac
+}
+
 # ========================================
 # 函数:检查进程是否运行
 # ========================================
 is_running() {
     if [ -f "$PID_FILE" ]; then
         PID=$(cat "$PID_FILE")
-        # 检查进程是否存在
-        if ps -p "$PID" > /dev/null 2>&1; then
+        # 不仅检查PID是否存在,还要确认是本服务进程,避免PID复用误判
+        if is_expected_process "$PID"; then
             return 0  # 运行中
         fi
+        cleanup_pid_file_if_matches "$PID"
     fi
     return 1  # 未运行
 }
@@ -59,6 +103,43 @@ get_pid() {
     fi
 }
 
+# ========================================
+# 函数:等待服务稳定启动
+# ========================================
+wait_for_service_ready() {
+    local pid="$1"
+    local elapsed=0
+
+    while [ "$elapsed" -lt "$STARTUP_TIMEOUT" ]; do
+        if ! is_expected_process "$pid"; then
+            return 1
+        fi
+        sleep "$HEALTH_CHECK_INTERVAL"
+        elapsed=$((elapsed + HEALTH_CHECK_INTERVAL))
+    done
+
+    return 0
+}
+
+# ========================================
+# 函数:后台监控PID,进程退出后自动清理PID文件
+# ========================================
+spawn_pid_watcher() {
+    local watched_pid="$1"
+    nohup bash -c '
+        watched_pid="$1"
+        pid_file="$2"
+
+        while ps -p "$watched_pid" > /dev/null 2>&1; do
+            sleep 2
+        done
+
+        if [ -f "$pid_file" ] && [ "$(cat "$pid_file" 2>/dev/null)" = "$watched_pid" ]; then
+            rm -f "$pid_file"
+        fi
+    ' _ "$watched_pid" "$PID_FILE" > /dev/null 2>&1 &
+}
+
 # ========================================
 # 函数:启动服务
 # ========================================
@@ -78,16 +159,17 @@ start_service() {
         echo "错误: run_with_auto_training.py 不存在"
         exit 1
     fi
-    
-    if [ ! -f "config/pickup_config.db" ]; then
-        echo "错误: config/pickup_config.db 不存在"
-        echo "请先运行迁移脚本: python tool/migrate_yaml_to_db.py"
+
+    # 检查配置文件(YAML 或 DB 至少存在一个)
+    if [ ! -f "config/pickup_config.db" ] && [ ! -f "config/rtsp_config.yaml" ]; then
+        echo "错误: 找不到配置文件"
+        echo "需要 config/pickup_config.db 或 config/rtsp_config.yaml 之一"
         exit 1
     fi
-    
+
     # 创建日志目录
     mkdir -p logs
-    
+
     # 启动服务
     echo "后台运行模式..."
     # stdout/stderr 丢弃,所有日志由 RotatingFileHandler 写入 logs/system.log
@@ -95,9 +177,9 @@ start_service() {
     PID=$!
     echo $PID > "$PID_FILE"
     
-    # 等待1秒检查是否正常启动
-    sleep 1
-    if ps -p "$PID" > /dev/null 2>&1; then
+    # 等待一段观察窗口,避免“刚启动1秒就退出”仍被误判为成功
+    if wait_for_service_ready "$PID"; then
+        spawn_pid_watcher "$PID"
         echo "服务启动成功, PID: $PID"
         echo "日志文件: logs/system.log"
         echo ""
@@ -106,7 +188,7 @@ start_service() {
         echo "重启服务: ./start.sh restart"
     else
         echo "服务启动失败,请检查日志: logs/system.log"
-        rm -f "$PID_FILE"
+        cleanup_pid_file_if_matches "$PID"
         return 1
     fi
 }
@@ -117,7 +199,7 @@ start_service() {
 stop_service() {
     if ! is_running; then
         echo "服务未运行"
-        rm -f "$PID_FILE"
+        cleanup_pid_file_if_matches ""
         return 0
     fi
     
@@ -140,7 +222,7 @@ stop_service() {
         echo "等待进程结束... ($WAIT_COUNT/10)"
     done
     
-    rm -f "$PID_FILE"
+    cleanup_pid_file_if_matches "$PID"
     echo "服务已停止"
 }
 
@@ -185,7 +267,7 @@ show_status() {
         echo "状态: 未运行"
         if [ -f "$PID_FILE" ]; then
             echo "注意: PID文件存在但进程已停止,可能是异常退出"
-            rm -f "$PID_FILE"
+            cleanup_pid_file_if_matches ""
         fi
     fi
 }
@@ -209,16 +291,17 @@ run_foreground() {
         echo "错误: run_with_auto_training.py 不存在"
         exit 1
     fi
-    
-    if [ ! -f "config/pickup_config.db" ]; then
-        echo "错误: config/pickup_config.db 不存在"
-        echo "请先运行迁移脚本: python tool/migrate_yaml_to_db.py"
+
+    # 检查配置文件(YAML 或 DB 至少存在一个)
+    if [ ! -f "config/pickup_config.db" ] && [ ! -f "config/rtsp_config.yaml" ]; then
+        echo "错误: 找不到配置文件"
+        echo "需要 config/pickup_config.db 或 config/rtsp_config.yaml 之一"
         exit 1
     fi
-    
+
     # 创建日志目录
     mkdir -p logs
-    
+
     echo "前台运行模式..."
     python run_with_auto_training.py
 }

+ 1 - 1
tool/migrate_yaml_to_db.py

@@ -125,7 +125,7 @@ def migrate_yaml_to_db(yaml_path: str, db_path: str = None, force: bool = False)
     # ========================================
     # 2. 迁移系统级配置
     # ========================================
-    system_sections = ['audio', 'prediction', 'push_notification', 'scada_api', 'human_detection']
+    system_sections = ['audio', 'prediction', 'push_notification', 'scada_api', 'human_detection', 'auto_training']
 
     for section in system_sections:
         section_data = config.get(section, {})