#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ incremental_trainer.py ---------------------- 模型训练模块 功能: 1. 支持全量训练(指定外部数据目录,每设备独立模型) 2. 支持增量训练(每日自动训练) 3. 滑动窗口提取Mel特征(8秒patches) 4. 每设备独立计算标准化参数和阈值 5. 产出目录结构与推理端 MultiModelPredictor 对齐 产出目录结构: models/ ├── {device_code_1}/ │ ├── ae_model.pth │ ├── global_scale.npy │ └── thresholds/ │ └── threshold_{device_code_1}.npy └── {device_code_2}/ ├── ae_model.pth ├── global_scale.npy └── thresholds/ └── threshold_{device_code_2}.npy """ import sys import random import shutil import logging import numpy as np import torch import torch.nn as nn from pathlib import Path from datetime import datetime, timedelta from typing import List, Dict, Tuple, Optional import yaml # 添加父目录到路径 sys.path.insert(0, str(Path(__file__).parent.parent)) from predictor import CFG from predictor.model_def import ConvAutoencoder from predictor.datasets import MelNPYDataset from predictor.utils import align_to_target logger = logging.getLogger('IncrementalTrainer') class IncrementalTrainer: """ 模型训练器 支持两种训练模式: 1. 全量训练:指定外部数据目录,每设备从零训练独立模型 2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑) """ def __init__(self, config_file: Path): """ 初始化训练器 Args: config_file: auto_training.yaml 配置文件路径 """ self.config_file = config_file self.config = self._load_config() # 路径配置 self.deploy_root = Path(__file__).parent.parent self.audio_root = self.deploy_root / "data" / "audio" # 模型根目录(所有设备子目录的父目录) self.model_root = self.deploy_root / "models" self.backup_dir = self.model_root / "backups" # 临时目录 self.temp_mel_dir = self.deploy_root / "data" / "temp_mels" # 确保目录存在 self.backup_dir.mkdir(parents=True, exist_ok=True) # 运行时可覆盖的配置(用于冷启动) self.use_days_ago = None self.sample_hours = None self.epochs = None self.learning_rate = None self.cold_start_mode = False def _load_config(self) -> Dict: # 加载配置文件 with open(self.config_file, 'r', encoding='utf-8') as f: return yaml.safe_load(f) # ======================================== # 数据收集 # ======================================== def collect_from_external_dir(self, data_dir: Path, device_filter: Optional[List[str]] = None ) -> Dict[str, List[Path]]: """ 从外部数据目录收集训练数据 目录结构约定: data_dir/ ├── LT-2/ <-- 子文件夹名 = device_code │ ├── xxx.wav │ └── ... └── LT-5/ ├── yyy.wav └── ... 支持两种子目录结构: 1. 扁平结构:data_dir/{device_code}/*.wav 2. 日期嵌套:data_dir/{device_code}/{YYYYMMDD}/*.wav Args: data_dir: 外部数据目录路径 device_filter: 只训练指定设备(None=全部) Returns: {device_code: [wav_files]} 按设备分组的音频文件列表 """ data_dir = Path(data_dir) if not data_dir.exists(): raise FileNotFoundError(f"数据目录不存在: {data_dir}") logger.info(f"扫描外部数据目录: {data_dir}") device_files = {} for sub_dir in sorted(data_dir.iterdir()): # 跳过非目录 if not sub_dir.is_dir(): continue device_code = sub_dir.name # 应用设备过滤 if device_filter and device_code not in device_filter: logger.info(f" 跳过设备(不在过滤列表中): {device_code}") continue audio_files = [] # 扁平结构:直接查找 wav 文件 audio_files.extend(list(sub_dir.glob("*.wav"))) audio_files.extend(list(sub_dir.glob("*.mp4"))) # 日期嵌套结构:查找子目录中的 wav 文件 for nested_dir in sub_dir.iterdir(): if nested_dir.is_dir(): audio_files.extend(list(nested_dir.glob("*.wav"))) audio_files.extend(list(nested_dir.glob("*.mp4"))) # 去重 audio_files = list(set(audio_files)) if audio_files: device_files[device_code] = audio_files logger.info(f" {device_code}: {len(audio_files)} 个音频文件") else: logger.warning(f" {device_code}: 无音频文件,跳过") total = sum(len(f) for f in device_files.values()) logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备") return device_files def collect_training_data(self, target_date: str) -> Dict[str, List[Path]]: """ 从内部数据目录收集训练数据(增量训练用) Args: target_date: 日期字符串 'YYYYMMDD' Returns: {device_code: [wav_files]} 按设备分组的音频文件 """ logger.info(f"收集 {target_date} 的训练数据") sample_hours = self.config['auto_training']['incremental'].get('sample_hours', 0) device_files = {} if not self.audio_root.exists(): logger.warning(f"音频目录不存在: {self.audio_root}") return device_files for device_dir in self.audio_root.iterdir(): if not device_dir.is_dir(): continue device_code = device_dir.name audio_files = [] # 冷启动模式:收集所有目录的数据 if self.cold_start_mode: # 收集current目录 current_dir = device_dir / "current" if current_dir.exists(): audio_files.extend(list(current_dir.glob("*.wav"))) audio_files.extend(list(current_dir.glob("*.mp4"))) # 收集所有日期目录 for sub_dir in device_dir.iterdir(): if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8: audio_files.extend(list(sub_dir.glob("*.wav"))) audio_files.extend(list(sub_dir.glob("*.mp4"))) else: # 正常模式:只收集指定日期的目录 date_dir = device_dir / target_date if date_dir.exists(): audio_files.extend(list(date_dir.glob("*.wav"))) audio_files.extend(list(date_dir.glob("*.mp4"))) # 加上 verified_normal 目录 verified_dir = device_dir / "verified_normal" if verified_dir.exists(): audio_files.extend(list(verified_dir.glob("*.wav"))) audio_files.extend(list(verified_dir.glob("*.mp4"))) # 去重 audio_files = list(set(audio_files)) # 随机采样(如果配置了采样时长) if sample_hours > 0 and audio_files: files_needed = int(sample_hours * 3600 / 60) if len(audio_files) > files_needed: audio_files = random.sample(audio_files, files_needed) logger.info(f" {device_code}: 随机采样 {len(audio_files)} 个音频") else: logger.info(f" {device_code}: {len(audio_files)} 个音频(全部使用)") else: logger.info(f" {device_code}: {len(audio_files)} 个音频") if audio_files: device_files[device_code] = audio_files total = sum(len(f) for f in device_files.values()) logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备") return device_files # ======================================== # 特征提取(每设备独立标准化参数) # ======================================== def _extract_mel_for_device(self, device_code: str, wav_files: List[Path]) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]: """ 为单个设备提取 Mel 特征并计算独立的 Z-score 标准化参数 两遍扫描: 1. 第一遍:收集所有 mel_db 计算 mean/std 2. 第二遍:Z-score 标准化后保存 npy 文件 Args: device_code: 设备编码 wav_files: 该设备的音频文件列表 Returns: (mel_dir, (global_mean, global_std)),失败返回 (None, None) """ import librosa # 滑动窗口参数 win_samples = int(CFG.WIN_SEC * CFG.SR) hop_samples = int(CFG.HOP_SEC * CFG.SR) # 第一遍:收集所有 mel_db 值,用于计算 mean/std all_mel_data = [] all_values = [] # 收集所有像素值用于全局统计 for wav_file in wav_files: try: y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True) # 跳过过短的音频 if len(y) < CFG.SR: continue for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)): segment = y[start:start + win_samples] mel_spec = librosa.feature.melspectrogram( y=segment, sr=CFG.SR, n_fft=CFG.N_FFT, hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0 ) mel_db = librosa.power_to_db(mel_spec, ref=np.max) # 对齐帧数 if mel_db.shape[1] < CFG.TARGET_FRAMES: pad = CFG.TARGET_FRAMES - mel_db.shape[1] mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant") else: mel_db = mel_db[:, :CFG.TARGET_FRAMES] # 收集所有值用于 min/max 计算 all_values.append(mel_db.flatten()) all_mel_data.append((wav_file, idx, mel_db)) except Exception as e: logger.warning(f"跳过文件 {wav_file.name}: {e}") continue if not all_mel_data: logger.warning(f" {device_code}: 无有效数据") return None, None # 计算全局 min/max(Min-Max 标准化参数) all_values_concat = np.concatenate(all_values) global_min = float(np.min(all_values_concat)) global_max = float(np.max(all_values_concat)) logger.info(f" {device_code}: {len(all_mel_data)} patches | " f"min={global_min:.4f} max={global_max:.4f}") # 第二遍:Min-Max 标准化并保存 device_mel_dir = self.temp_mel_dir / device_code device_mel_dir.mkdir(parents=True, exist_ok=True) for wav_file, idx, mel_db in all_mel_data: # Min-Max: (x - min) / (max - min) mel_norm = (mel_db - global_min) / (global_max - global_min + 1e-6) npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy" np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32)) return device_mel_dir, (global_min, global_max) def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]] ) -> Dict[str, Tuple[Path, Tuple[float, float]]]: """ 为每个设备独立提取 Mel 特征 每设备分别计算自己的 Min-Max 标准化参数 (global_min, global_max) Args: device_files: {device_code: [wav_files]} Returns: {device_code: (mel_dir, (global_min, global_max))} """ logger.info("提取 Mel 特征(每设备独立标准化)") # 清空临时目录 if self.temp_mel_dir.exists(): shutil.rmtree(self.temp_mel_dir) self.temp_mel_dir.mkdir(parents=True, exist_ok=True) device_results = {} for device_code, wav_files in device_files.items(): mel_dir, scale = self._extract_mel_for_device(device_code, wav_files) if mel_dir is not None: device_results[device_code] = (mel_dir, scale) total_patches = sum( len(list(mel_dir.glob("*.npy"))) for mel_dir, _ in device_results.values() ) logger.info(f"提取完成: {total_patches} patches,{len(device_results)} 个设备") return device_results # ======================================== # 模型训练(每设备独立) # ======================================== def train_single_device(self, device_code: str, mel_dir: Path, epochs: int, lr: float, from_scratch: bool = True ) -> Tuple[nn.Module, float]: """ 训练单个设备的独立模型 Args: device_code: 设备编码 mel_dir: 该设备的 Mel 特征目录 epochs: 训练轮数 lr: 学习率 from_scratch: True=从零训练(全量),False=加载已有模型微调(增量) Returns: (model, final_loss) """ logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, " f"mode={'全量' if from_scratch else '增量'}") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = ConvAutoencoder().to(device) # 增量模式下加载已有模型 if not from_scratch: model_path = self.model_root / device_code / "ae_model.pth" if model_path.exists(): model.load_state_dict(torch.load(model_path, map_location=device)) logger.info(f" 已加载已有模型: {model_path}") else: logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}") # 加载数据 dataset = MelNPYDataset(mel_dir) if len(dataset) == 0: raise ValueError(f"设备 {device_code} 无训练数据") batch_size = self.config['auto_training']['incremental']['batch_size'] dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=0 ) # 训练 model.train() optimizer = torch.optim.Adam(model.parameters(), lr=lr) criterion = nn.MSELoss() avg_loss = 0.0 for epoch in range(epochs): epoch_loss = 0.0 batch_count = 0 for batch in dataloader: batch = batch.to(device) optimizer.zero_grad() output = model(batch) output = align_to_target(output, batch) loss = criterion(output, batch) loss.backward() optimizer.step() epoch_loss += loss.item() batch_count += 1 avg_loss = epoch_loss / batch_count # 每10轮或最后一轮打印日志,避免日志刷屏 if (epoch + 1) % 10 == 0 or epoch == epochs - 1: logger.info(f" [{device_code}] Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.6f}") return model, avg_loss # ======================================== # 产出部署(每设备独立目录) # ======================================== def deploy_device_model(self, device_code: str, model: nn.Module, scale: Tuple[float, float], mel_dir: Path): """ 部署单个设备的模型到 models/{device_code}/ 目录 产出文件: - models/{device_code}/ae_model.pth - models/{device_code}/global_scale.npy → [mean, std] - models/{device_code}/thresholds/threshold_{device_code}.npy Args: device_code: 设备编码 model: 训练后的模型 scale: (global_min, global_max) Min-Max 标准化参数 mel_dir: 该设备的 Mel 特征目录(用于计算阈值) """ # 创建设备模型目录 device_model_dir = self.model_root / device_code device_model_dir.mkdir(parents=True, exist_ok=True) # 1. 保存模型权重 model_path = device_model_dir / "ae_model.pth" torch.save(model.state_dict(), model_path) logger.info(f" 模型已保存: {model_path}") # 2. 保存 Min-Max 标准化参数 [min, max] scale_path = device_model_dir / "global_scale.npy" np.save(scale_path, np.array([scale[0], scale[1]])) logger.info(f" 标准化参数已保存: {scale_path}") # 3. 计算并保存阈值 threshold_dir = device_model_dir / "thresholds" threshold_dir.mkdir(parents=True, exist_ok=True) threshold = self._compute_threshold(model, mel_dir) threshold_file = threshold_dir / f"threshold_{device_code}.npy" np.save(threshold_file, threshold) logger.info(f" 阈值已保存: {threshold_file} (值={threshold:.6f})") def _compute_threshold(self, model: nn.Module, mel_dir: Path) -> float: """ 计算单个设备的阈值 使用 3σ 法则:threshold = mean + 3 * std Args: model: 训练后的模型 mel_dir: 该设备的 Mel 特征目录 Returns: 阈值(标量) """ device = next(model.parameters()).device model.eval() dataset = MelNPYDataset(mel_dir) if len(dataset) == 0: logger.warning("无数据计算阈值,使用默认值 0.01") return 0.01 dataloader = torch.utils.data.DataLoader( dataset, batch_size=64, shuffle=False ) all_errors = [] with torch.no_grad(): for batch in dataloader: batch = batch.to(device) output = model(batch) output = align_to_target(output, batch) mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3]) all_errors.append(mse.cpu().numpy()) errors = np.concatenate(all_errors) # 3σ 法则 mean_err = float(np.mean(errors)) std_err = float(np.std(errors)) threshold = mean_err + 3 * std_err logger.info(f" 阈值统计: 3σ={threshold:.6f} | " f"mean={mean_err:.6f} std={std_err:.6f} | " f"样本数={len(errors)}") return threshold # ======================================== # 全量训练入口 # ======================================== def run_full_training(self, data_dir: Path, epochs: Optional[int] = None, lr: Optional[float] = None, device_filter: Optional[List[str]] = None) -> bool: """ 全量训练入口:每设备从零训练独立模型 流程: 1. 扫描外部数据目录 2. 每设备独立提取 Mel 特征+标准化参数 3. 每设备独立训练模型 4. 每设备独立部署(模型+标准化+阈值) Args: data_dir: 数据目录路径 epochs: 训练轮数(None=使用配置文件值) lr: 学习率(None=使用配置文件值) device_filter: 只训练指定设备(None=全部) Returns: bool: 是否成功 """ try: epochs = epochs or self.config['auto_training']['incremental']['epochs'] lr = lr or self.config['auto_training']['incremental']['learning_rate'] logger.info("=" * 60) logger.info(f"全量训练 | 数据目录: {data_dir}") logger.info(f"参数: epochs={epochs}, lr={lr}") if device_filter: logger.info(f"设备过滤: {device_filter}") logger.info("=" * 60) # 1. 收集数据 device_files = self.collect_from_external_dir(data_dir, device_filter) if not device_files: logger.error("无可用训练数据") return False # 2. 每设备提取特征 device_results = self.prepare_mel_features_per_device(device_files) if not device_results: logger.error("特征提取失败") return False # 3. 每设备独立训练+部署 success_count = 0 fail_count = 0 for device_code, (mel_dir, scale) in device_results.items(): try: logger.info(f"\n--- 训练设备: {device_code} ---") # 训练 model, final_loss = self.train_single_device( device_code, mel_dir, epochs, lr, from_scratch=True ) # 验证 if not self._validate_model(model): logger.error(f" {device_code}: 模型验证失败,跳过部署") fail_count += 1 continue # 部署到 models/{device_code}/ self.deploy_device_model(device_code, model, scale, mel_dir) success_count += 1 logger.info(f" {device_code}: 训练+部署完成 | loss={final_loss:.6f}") except Exception as e: logger.error(f" {device_code}: 训练失败 | {e}", exc_info=True) fail_count += 1 logger.info("=" * 60) logger.info(f"全量训练完成: 成功={success_count}, 失败={fail_count}") logger.info("=" * 60) return fail_count == 0 except Exception as e: logger.error(f"全量训练异常: {e}", exc_info=True) return False finally: # 清理临时文件 if self.temp_mel_dir.exists(): shutil.rmtree(self.temp_mel_dir) # ======================================== # 增量训练入口(保留兼容) # ======================================== def run_daily_training(self) -> bool: """ 执行每日增量训练(保留原有逻辑) 改造点:每设备独立训练+部署,不再共享模型 Returns: bool: 是否成功 """ try: days_ago = (self.use_days_ago if self.use_days_ago is not None else self.config['auto_training']['incremental']['use_days_ago']) target_date = (datetime.now() - timedelta(days=days_ago)).strftime('%Y%m%d') mode_str = "冷启动训练" if self.cold_start_mode else "增量训练" logger.info("=" * 60) logger.info(f"{mode_str} - {target_date}") logger.info("=" * 60) # 1. 收集数据 device_files = self.collect_training_data(target_date) total = sum(len(f) for f in device_files.values()) min_samples = self.config['auto_training']['incremental']['min_samples'] if total < min_samples: logger.warning(f"数据不足 ({total} < {min_samples}),跳过") return False # 2. 备份模型 if self.config['auto_training']['model']['backup_before_train']: self.backup_model(target_date) # 3. 每设备提取特征 device_results = self.prepare_mel_features_per_device(device_files) if not device_results: logger.error("特征提取失败") return False # 4. 训练参数 epochs = (self.epochs if self.epochs is not None else self.config['auto_training']['incremental']['epochs']) lr = (self.learning_rate if self.learning_rate is not None else self.config['auto_training']['incremental']['learning_rate']) # 5. 每设备独立训练+部署 # 冷启动=全量训练(从零),增量=加载已有模型微调 from_scratch = self.cold_start_mode success_count = 0 for device_code, (mel_dir, scale) in device_results.items(): try: model, _ = self.train_single_device( device_code, mel_dir, epochs, lr, from_scratch=from_scratch ) if self._validate_model(model): if self.config['auto_training']['model']['auto_deploy']: self.deploy_device_model(device_code, model, scale, mel_dir) success_count += 1 else: logger.error(f"{device_code}: 验证失败,跳过部署") except Exception as e: logger.error(f"{device_code}: 训练失败 | {e}") # 6. 更新分类器基线 self._update_classifier_baseline(device_files) logger.info("=" * 60) logger.info(f"增量训练完成: {success_count}/{len(device_results)} 个设备成功") logger.info("=" * 60) return success_count > 0 except Exception as e: logger.error(f"训练失败: {e}", exc_info=True) return False finally: if self.temp_mel_dir.exists(): shutil.rmtree(self.temp_mel_dir) # ======================================== # 辅助方法 # ======================================== def _validate_model(self, model: nn.Module) -> bool: # 验证模型输出形状是否合理 if not self.config['auto_training']['validation']['enabled']: return True try: device = next(model.parameters()).device test_input = torch.randn(1, 1, CFG.N_MELS, CFG.TARGET_FRAMES).to(device) with torch.no_grad(): output = model(test_input) h_diff = abs(output.shape[2] - test_input.shape[2]) w_diff = abs(output.shape[3] - test_input.shape[3]) if h_diff > 8 or w_diff > 8: logger.error(f"形状差异过大: {output.shape} vs {test_input.shape}") return False return True except Exception as e: logger.error(f"验证失败: {e}") return False def backup_model(self, date_tag: str): """ 完整备份当前所有设备的模型 备份目录结构: backups/{date_tag}/{device_code}/ ├── ae_model.pth ├── global_scale.npy └── thresholds/ """ backup_date_dir = self.backup_dir / date_tag backup_date_dir.mkdir(parents=True, exist_ok=True) backed_up = 0 # 遍历 models/ 下的所有设备子目录 for device_dir in self.model_root.iterdir(): if not device_dir.is_dir(): continue # 跳过 backups 目录本身 if device_dir.name == "backups": continue # 检查是否包含模型文件(判断是否为设备目录) if not (device_dir / "ae_model.pth").exists(): continue device_backup = backup_date_dir / device_dir.name # 递归复制整个设备目录 shutil.copytree(device_dir, device_backup, dirs_exist_ok=True) backed_up += 1 logger.info(f"备份完成: {backed_up} 个设备 -> {backup_date_dir}") # 清理旧备份 keep = self.config['auto_training']['model']['keep_backups'] backup_dirs = sorted( [d for d in self.backup_dir.iterdir() if d.is_dir() and d.name.isdigit()], reverse=True ) for old_dir in backup_dirs[keep:]: shutil.rmtree(old_dir) logger.info(f"已删除旧备份: {old_dir.name}") def restore_backup(self, date_tag: str) -> bool: """ 从备份恢复所有设备的模型 Args: date_tag: 备份日期标签 'YYYYMMDD' Returns: bool: 是否恢复成功 """ backup_date_dir = self.backup_dir / date_tag if not backup_date_dir.exists(): logger.error(f"备份目录不存在: {backup_date_dir}") return False logger.info(f"从备份恢复: {date_tag}") restored = 0 for device_backup in backup_date_dir.iterdir(): if not device_backup.is_dir(): continue target_dir = self.model_root / device_backup.name # 递归复制恢复 shutil.copytree(device_backup, target_dir, dirs_exist_ok=True) restored += 1 logger.info(f"恢复完成: {restored} 个设备") return restored > 0 def _update_classifier_baseline(self, device_files: Dict[str, List[Path]]): # 从训练数据计算并更新分类器基线 logger.info("更新分类器基线") try: import librosa from core.anomaly_classifier import AnomalyClassifier classifier = AnomalyClassifier() all_files = [] for files in device_files.values(): all_files.extend(files) if not all_files: logger.warning("无音频文件,跳过基线更新") return sample_files = random.sample(all_files, min(50, len(all_files))) all_features = [] for wav_file in sample_files: try: y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True) if len(y) < CFG.SR: continue features = classifier.extract_features(y, sr=CFG.SR) if features: all_features.append(features) except Exception: continue if not all_features: logger.warning("无法提取特征,跳过基线更新") return baseline = {} keys = all_features[0].keys() for key in keys: if key == 'has_periodic': values = [f[key] for f in all_features] baseline[key] = sum(values) > len(values) / 2 else: values = [f[key] for f in all_features] baseline[key] = float(np.mean(values)) classifier.save_baseline(baseline) logger.info(f" 基线已更新 (样本数: {len(all_features)})") except Exception as e: logger.warning(f"更新基线失败: {e}") def main(): # 命令行入口(增量训练) logging.basicConfig( level=logging.INFO, format='%(asctime)s | %(levelname)-8s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml" trainer = IncrementalTrainer(config_file) success = trainer.run_daily_training() sys.exit(0 if success else 1) if __name__ == "__main__": main()