#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ incremental_trainer.py ---------------------- 模型训练模块 功能: 1. 支持全量训练(指定外部数据目录,每设备独立模型) 2. 支持增量训练(每日自动训练) 3. 滑动窗口提取Mel特征(8秒patches) 4. 每设备独立计算标准化参数和阈值 5. 产出目录结构与推理端 MultiModelPredictor 对齐 产出目录结构: models/ ├── {device_code_1}/ │ ├── ae_model.pth │ ├── global_scale.npy │ └── thresholds/ │ └── threshold_{device_code_1}.npy └── {device_code_2}/ ├── ae_model.pth ├── global_scale.npy └── thresholds/ └── threshold_{device_code_2}.npy """ import sys import random import shutil import logging import numpy as np import torch import torch.nn as nn from pathlib import Path from datetime import datetime, timedelta from typing import List, Dict, Tuple, Optional import yaml # 添加父目录到路径 sys.path.insert(0, str(Path(__file__).parent.parent)) from predictor import CFG from predictor.model_def import ConvAutoencoder from predictor.datasets import MelNPYDataset from predictor.utils import align_to_target logger = logging.getLogger('IncrementalTrainer') class IncrementalTrainer: """ 模型训练器 支持两种训练模式: 1. 全量训练:指定外部数据目录,每设备从零训练独立模型 2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑) """ def __init__(self, config_file: Path = None, config: dict = None): # 支持两种初始化方式: # 1. 传 config dict(从数据库读取后直接传入,主程序使用) # 2. 传 config_file YAML 路径(standalone_train.py 等独立工具使用) if config is not None: self.config = config self.config_file = None elif config_file is not None: self.config_file = config_file self.config = self._load_config() else: raise ValueError("必须提供 config_file 或 config 之一") # 路径配置 self.deploy_root = Path(__file__).parent.parent self.audio_root = self.deploy_root / "data" / "audio" # 模型根目录(所有设备子目录的父目录) self.model_root = self.deploy_root / "models" self.backup_dir = self.model_root / "backups" # 临时目录 self.temp_mel_dir = self.deploy_root / "data" / "temp_mels" # 确保目录存在 self.backup_dir.mkdir(parents=True, exist_ok=True) # 运行时可覆盖的配置(用于冷启动) self.use_days_ago = None self.sample_hours = None self.epochs = None self.learning_rate = None self.cold_start_mode = False def _load_config(self) -> Dict: # 从 YAML 文件加载配置(仅 config_file 模式使用) with open(self.config_file, 'r', encoding='utf-8') as f: return yaml.safe_load(f) # ======================================== # 数据收集 # ======================================== def collect_from_external_dir(self, data_dir: Path, device_filter: Optional[List[str]] = None ) -> Dict[str, List[Path]]: """ 从外部数据目录收集训练数据 目录结构约定: data_dir/ ├── LT-2/ <-- 子文件夹名 = device_code │ ├── xxx.wav │ └── ... └── LT-5/ ├── yyy.wav └── ... 支持两种子目录结构: 1. 扁平结构:data_dir/{device_code}/*.wav 2. 日期嵌套:data_dir/{device_code}/{YYYYMMDD}/*.wav Args: data_dir: 外部数据目录路径 device_filter: 只训练指定设备(None=全部) Returns: {device_code: [wav_files]} 按设备分组的音频文件列表 """ data_dir = Path(data_dir) if not data_dir.exists(): raise FileNotFoundError(f"数据目录不存在: {data_dir}") logger.info(f"扫描外部数据目录: {data_dir}") device_files = {} for sub_dir in sorted(data_dir.iterdir()): # 跳过非目录 if not sub_dir.is_dir(): continue device_code = sub_dir.name # 应用设备过滤 if device_filter and device_code not in device_filter: logger.info(f" 跳过设备(不在过滤列表中): {device_code}") continue audio_files = [] # 扁平结构:直接查找 wav 文件 audio_files.extend(list(sub_dir.glob("*.wav"))) audio_files.extend(list(sub_dir.glob("*.mp4"))) # 日期嵌套结构:查找子目录中的 wav 文件 for nested_dir in sub_dir.iterdir(): if nested_dir.is_dir(): audio_files.extend(list(nested_dir.glob("*.wav"))) audio_files.extend(list(nested_dir.glob("*.mp4"))) # 去重 audio_files = list(set(audio_files)) if audio_files: device_files[device_code] = audio_files logger.info(f" {device_code}: {len(audio_files)} 个音频文件") else: logger.warning(f" {device_code}: 无音频文件,跳过") total = sum(len(f) for f in device_files.values()) logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备") return device_files def collect_training_data(self, target_date: str) -> Dict[str, List[Path]]: """ 从内部数据目录收集训练数据(增量训练用) Args: target_date: 日期字符串 'YYYYMMDD' Returns: {device_code: [wav_files]} 按设备分组的音频文件 """ logger.info(f"收集 {target_date} 的训练数据") sample_hours = self.config['auto_training']['incremental'].get('sample_hours', 0) device_files = {} if not self.audio_root.exists(): logger.warning(f"音频目录不存在: {self.audio_root}") return device_files for device_dir in self.audio_root.iterdir(): if not device_dir.is_dir(): continue device_code = device_dir.name audio_files = [] # 冷启动模式:收集所有已归档日期目录的正常音频(跳过 current/) if self.cold_start_mode: # 注意:跳过 current/ 目录,因其中可能包含 FFmpeg 正在写入的不完整文件 for sub_dir in device_dir.iterdir(): if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8: # 新结构:从 {date}/normal/ 子目录读取 normal_dir = sub_dir / "normal" if normal_dir.exists(): audio_files.extend(list(normal_dir.glob("*.wav"))) audio_files.extend(list(normal_dir.glob("*.mp4"))) # 兼容旧结构:日期目录下直接存放的音频文件 audio_files.extend(list(sub_dir.glob("*.wav"))) audio_files.extend(list(sub_dir.glob("*.mp4"))) else: # 正常模式:只收集指定日期的正常音频 date_dir = device_dir / target_date if date_dir.exists(): # 新结构:从 {date}/normal/ 子目录读取 normal_dir = date_dir / "normal" if normal_dir.exists(): audio_files.extend(list(normal_dir.glob("*.wav"))) audio_files.extend(list(normal_dir.glob("*.mp4"))) # 兼容旧结构:日期目录下直接存放的音频文件 audio_files.extend(list(date_dir.glob("*.wav"))) audio_files.extend(list(date_dir.glob("*.mp4"))) # 加上 verified_normal 目录(单独收集,不参与采样和质量预筛) verified_dir = device_dir / "verified_normal" verified_files = [] if verified_dir.exists(): verified_files.extend(list(verified_dir.glob("*.wav"))) verified_files.extend(list(verified_dir.glob("*.mp4"))) # 去重(仅日期目录音频) audio_files = list(set(audio_files)) # 数据质量预筛:仅对日期目录音频过滤,verified_normal 已经人工确认,跳过 if audio_files and not self.cold_start_mode: before_count = len(audio_files) audio_files = self._filter_audio_quality(audio_files, device_code) filtered = before_count - len(audio_files) if filtered > 0: logger.info(f" {device_code}: 质量预筛过滤 {filtered} 个异常音频") # 随机采样(仅对日期目录音频采样,verified_normal 不参与) if sample_hours > 0 and audio_files: files_needed = int(sample_hours * 3600 / 60) if len(audio_files) > files_needed: audio_files = random.sample(audio_files, files_needed) logger.info(f" {device_code}: 随机采样 {len(audio_files)} 个音频") else: logger.info(f" {device_code}: {len(audio_files)} 个音频(全部使用)") else: logger.info(f" {device_code}: {len(audio_files)} 个音频") # 合并 verified_normal(采样后追加,保证全量参与训练) if verified_files: audio_files.extend(verified_files) logger.info(f" {device_code}: +{len(verified_files)} 个核查确认音频(verified_normal)") if audio_files: device_files[device_code] = audio_files total = sum(len(f) for f in device_files.values()) logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备") return device_files def _filter_audio_quality(self, audio_files: List[Path], device_code: str) -> List[Path]: """ 音频质量预筛:基于 RMS 能量和频谱质心过滤明显异常的样本 使用 IQR (四分位距) 方法检测离群值: - 计算所有文件的 RMS 能量和频谱质心 - 过滤超出 [Q1 - 2*IQR, Q3 + 2*IQR] 范围的文件 需要至少 10 个文件才执行过滤,否则样本太少无统计意义。 Args: audio_files: 待过滤的音频文件列表 device_code: 设备编码(用于日志) Returns: 过滤后的文件列表 """ if len(audio_files) < 10: return audio_files import librosa # 快速计算每个文件的 RMS 能量 rms_values = [] valid_files = [] for wav_file in audio_files: try: y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True, duration=10) # 只读前10秒加速 if len(y) < CFG.SR: continue rms = float(np.sqrt(np.mean(y ** 2))) rms_values.append(rms) valid_files.append(wav_file) except Exception: continue if len(rms_values) < 10: return audio_files # IQR 离群值检测 rms_arr = np.array(rms_values) q1, q3 = np.percentile(rms_arr, [25, 75]) iqr = q3 - q1 lower_bound = q1 - 2 * iqr upper_bound = q3 + 2 * iqr filtered = [] for f, rms in zip(valid_files, rms_values): if lower_bound <= rms <= upper_bound: filtered.append(f) return filtered # ======================================== # 特征提取(每设备独立标准化参数) # ======================================== def _extract_mel_for_device(self, device_code: str, wav_files: List[Path] ) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]: """ 为单个设备提取 Mel 特征并计算独立的 Min-Max 标准化参数 流式两遍扫描(内存优化): 1. 第一遍:只计算 running min/max(O(1) 内存),不保存 mel_db 2. 第二遍:用第一遍的 min/max 标准化后直接写 npy 文件 Args: device_code: 设备编码 wav_files: 该设备的音频文件列表 Returns: (mel_dir, (global_min, global_max)),失败返回 (None, None) """ import librosa # 滑动窗口参数 win_samples = int(CFG.WIN_SEC * CFG.SR) hop_samples = int(CFG.HOP_SEC * CFG.SR) def _iter_mel_patches(files): """生成器:逐文件逐 patch 产出 mel_db,避免全量加载到内存""" for wav_file in files: try: y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True) if len(y) < CFG.SR: continue for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)): segment = y[start:start + win_samples] mel_spec = librosa.feature.melspectrogram( y=segment, sr=CFG.SR, n_fft=CFG.N_FFT, hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0 ) mel_db = librosa.power_to_db(mel_spec, ref=np.max) # 对齐帧数 if mel_db.shape[1] < CFG.TARGET_FRAMES: pad = CFG.TARGET_FRAMES - mel_db.shape[1] mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant") else: mel_db = mel_db[:, :CFG.TARGET_FRAMES] yield wav_file, idx, mel_db except Exception as e: logger.warning(f"跳过文件 {wav_file.name}: {e}") continue # ── 第一遍:流式计算 running min/max(O(1) 内存) ── global_min = float('inf') global_max = float('-inf') patch_count = 0 for _, _, mel_db in _iter_mel_patches(wav_files): local_min = float(mel_db.min()) local_max = float(mel_db.max()) if local_min < global_min: global_min = local_min if local_max > global_max: global_max = local_max patch_count += 1 if patch_count == 0: logger.warning(f" {device_code}: 无有效数据") return None, None logger.info(f" {device_code}: {patch_count} patches | " f"min={global_min:.4f} max={global_max:.4f}") # ── 第二遍:Min-Max 标准化并保存 ── device_mel_dir = self.temp_mel_dir / device_code device_mel_dir.mkdir(parents=True, exist_ok=True) scale_range = global_max - global_min + 1e-6 for wav_file, idx, mel_db in _iter_mel_patches(wav_files): mel_norm = (mel_db - global_min) / scale_range npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy" np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32)) return device_mel_dir, (global_min, global_max) def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]] ) -> Dict[str, Tuple[Path, Tuple[float, float]]]: """ 为每个设备独立提取 Mel 特征 每设备分别计算自己的 Min-Max 标准化参数 (global_min, global_max) Args: device_files: {device_code: [wav_files]} Returns: {device_code: (mel_dir, (global_min, global_max))} """ logger.info("提取 Mel 特征(每设备独立标准化)") # 清空临时目录 if self.temp_mel_dir.exists(): shutil.rmtree(self.temp_mel_dir) self.temp_mel_dir.mkdir(parents=True, exist_ok=True) device_results = {} for device_code, wav_files in device_files.items(): mel_dir, scale = self._extract_mel_for_device(device_code, wav_files) if mel_dir is not None: device_results[device_code] = (mel_dir, scale) total_patches = sum( len(list(mel_dir.glob("*.npy"))) for mel_dir, _ in device_results.values() ) logger.info(f"提取完成: {total_patches} patches,{len(device_results)} 个设备") return device_results # ======================================== # 模型训练(每设备独立) # ======================================== def _select_training_device(self) -> torch.device: # 智能选择训练设备:GPU 显存充足则使用,否则回退 CPU # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda) training_cfg = self.config['auto_training']['incremental'] forced_device = training_cfg.get('training_device', 'auto') # 强制指定设备时直接返回 if forced_device == 'cpu': logger.info("训练设备: CPU(配置强制指定)") return torch.device('cpu') if forced_device == 'cuda': if torch.cuda.is_available(): return torch.device('cuda') logger.warning("配置指定 CUDA 但不可用,回退 CPU") return torch.device('cpu') # auto 模式:检测 CUDA → CPU if torch.cuda.is_available(): try: free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024) min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512) if free_mem >= min_gpu_mem_mb: logger.info(f"训练设备: CUDA(空闲显存 {free_mem:.0f}MB)") return torch.device('cuda') logger.info( f"CUDA 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)" ) except Exception as e: logger.warning(f"CUDA 显存检测失败: {e}") logger.info("训练设备: CPU") return torch.device('cpu') def _run_training_loop(self, device_code: str, model: nn.Module, train_loader, val_loader, epochs: int, lr: float, device: torch.device) -> Tuple[nn.Module, float]: # 执行实际的训练循环,与设备选择解耦 # 早停基于验证集损失(如有),否则基于训练损失 model = model.to(device) model.train() optimizer = torch.optim.Adam(model.parameters(), lr=lr) criterion = nn.MSELoss() # AMP 混合精度(GPU 生效,减少约 40% 显存占用) use_amp = device.type == 'cuda' scaler = torch.amp.GradScaler(device.type) if use_amp else None # 早停配置 early_stop_cfg = self.config['auto_training']['incremental'] patience = early_stop_cfg.get('early_stop_patience', 5) best_loss = float('inf') no_improve_count = 0 avg_loss = 0.0 actual_epochs = 0 for epoch in range(epochs): # ── 训练阶段 ── model.train() epoch_loss = 0.0 batch_count = 0 for batch in train_loader: batch = batch.to(device) optimizer.zero_grad() if use_amp: with torch.amp.autocast(device.type): output = model(batch) output = align_to_target(output, batch) loss = criterion(output, batch) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: output = model(batch) output = align_to_target(output, batch) loss = criterion(output, batch) loss.backward() optimizer.step() epoch_loss += loss.item() batch_count += 1 avg_loss = epoch_loss / batch_count actual_epochs = epoch + 1 # ── 验证阶段(如有验证集) ── if val_loader is not None: model.eval() val_loss = 0.0 val_count = 0 with torch.no_grad(): for batch in val_loader: batch = batch.to(device) if use_amp: with torch.amp.autocast(device.type): output = model(batch) output = align_to_target(output, batch) loss = criterion(output, batch) else: output = model(batch) output = align_to_target(output, batch) loss = criterion(output, batch) val_loss += loss.item() val_count += 1 avg_val_loss = val_loss / val_count monitor_loss = avg_val_loss # 早停监控验证损失 else: avg_val_loss = None monitor_loss = avg_loss # 无验证集时回退训练损失 if actual_epochs % 10 == 0 or epoch == epochs - 1: val_str = f" | ValLoss: {avg_val_loss:.6f}" if avg_val_loss is not None else "" logger.info(f" [{device_code}] Epoch {actual_epochs}/{epochs} | " f"Loss: {avg_loss:.6f}{val_str} | device={device.type}") # 早停检测:连续 patience 轮无改善则提前终止 if monitor_loss < best_loss: best_loss = monitor_loss no_improve_count = 0 else: no_improve_count += 1 if no_improve_count >= patience and actual_epochs >= 10: logger.info(f" [{device_code}] 早停触发: 连续{patience}轮无改善 | " f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}") break # 训练后清理 GPU 缓存 if device.type == 'cuda': torch.cuda.empty_cache() if actual_epochs < epochs: logger.info(f" [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练") return model, avg_loss def train_single_device(self, device_code: str, mel_dir: Path, epochs: int, lr: float, from_scratch: bool = True ) -> Tuple[nn.Module, float]: # 训练单个设备的独立模型 # 策略:优先 GPU 训练,显存不足自动回退 CPU;训练中 OOM 也会捕获并用 CPU 重试 logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, " f"mode={'全量' if from_scratch else '增量'}") # 智能选择训练设备 device = self._select_training_device() # 训练前清理 GPU 缓存,释放推理残留的显存碎片 if device.type == 'cuda': torch.cuda.empty_cache() import gc gc.collect() model = ConvAutoencoder() # 增量模式下加载已有模型 if not from_scratch: model_path = self.model_root / device_code / "ae_model.pth" if model_path.exists(): model.load_state_dict(torch.load(model_path, map_location='cpu')) logger.info(f" 已加载已有模型: {model_path}") else: logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}") # 加载数据并按 80/20 划分训练集/验证集 dataset = MelNPYDataset(mel_dir) if len(dataset) == 0: raise ValueError(f"设备 {device_code} 无训练数据") batch_size = self.config['auto_training']['incremental']['batch_size'] # 验证集划分:数据量 >= 20 时才划分(否则太少无统计意义) val_loader = None if len(dataset) >= 20: val_size = max(1, int(len(dataset) * 0.2)) train_size = len(dataset) - val_size train_dataset, val_dataset = torch.utils.data.random_split( dataset, [train_size, val_size] ) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False ) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False ) logger.info(f" 数据集划分: 训练={train_size}, 验证={val_size}") else: train_loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False ) logger.info(f" 数据量不足20,跳过验证集划分(共{len(dataset)}样本)") # 尝试在选定设备上训练 if device.type == 'cuda': try: return self._run_training_loop( device_code, model, train_loader, val_loader, epochs, lr, device ) except (torch.cuda.OutOfMemoryError, RuntimeError) as e: # GPU OOM -> 清理显存后回退 CPU 重试 if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError): raise # 非 OOM 的 RuntimeError 不拦截 logger.warning( f" [{device_code}] CUDA OOM,清理显存后回退 CPU 训练" ) import gc gc.collect() torch.cuda.empty_cache() # 模型可能处于脏状态,重新初始化 model = ConvAutoencoder() if not from_scratch: model_path = self.model_root / device_code / "ae_model.pth" if model_path.exists(): model.load_state_dict( torch.load(model_path, map_location='cpu') ) return self._run_training_loop( device_code, model, train_loader, val_loader, epochs, lr, torch.device('cpu') ) else: # CPU 训练,无需 OOM 保护 return self._run_training_loop( device_code, model, train_loader, val_loader, epochs, lr, device ) # ======================================== # 产出部署(每设备独立目录) # ======================================== def deploy_device_model(self, device_code: str, model: nn.Module, scale: Tuple[float, float], mel_dir: Path): """ 部署单个设备的模型到 models/{device_code}/ 目录 产出文件: - models/{device_code}/ae_model.pth - models/{device_code}/global_scale.npy → [mean, std] - models/{device_code}/thresholds/threshold_{device_code}.npy Args: device_code: 设备编码 model: 训练后的模型 scale: (global_min, global_max) Min-Max 标准化参数 mel_dir: 该设备的 Mel 特征目录(用于计算阈值) """ # 创建设备模型目录 device_model_dir = self.model_root / device_code device_model_dir.mkdir(parents=True, exist_ok=True) # 1. 保存模型权重 model_path = device_model_dir / "ae_model.pth" torch.save(model.state_dict(), model_path) logger.info(f" 模型已保存: {model_path}") # 2. 保存 Min-Max 标准化参数 [min, max] scale_path = device_model_dir / "global_scale.npy" np.save(scale_path, np.array([scale[0], scale[1]])) logger.info(f" 标准化参数已保存: {scale_path}") # 3. 计算并保存阈值 threshold_dir = device_model_dir / "thresholds" threshold_dir.mkdir(parents=True, exist_ok=True) threshold = self._compute_threshold(model, mel_dir) threshold_file = threshold_dir / f"threshold_{device_code}.npy" np.save(threshold_file, threshold) logger.info(f" 阈值已保存: {threshold_file} (值={threshold:.6f})") def _compute_threshold(self, model: nn.Module, mel_dir: Path) -> float: """ 计算单个设备的阈值 使用 3σ 法则:threshold = mean + 3 * std Args: model: 训练后的模型 mel_dir: 该设备的 Mel 特征目录 Returns: 阈值(标量) """ device = next(model.parameters()).device model.eval() dataset = MelNPYDataset(mel_dir) if len(dataset) == 0: logger.warning("无数据计算阈值,使用默认值 0.01") return 0.01 dataloader = torch.utils.data.DataLoader( dataset, batch_size=64, shuffle=False ) all_errors = [] with torch.no_grad(): for batch in dataloader: batch = batch.to(device) output = model(batch) output = align_to_target(output, batch) mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3]) all_errors.append(mse.cpu().numpy()) errors = np.concatenate(all_errors) # 3σ 法则 mean_err = float(np.mean(errors)) std_err = float(np.std(errors)) threshold = mean_err + 3 * std_err logger.info(f" 阈值统计: 3σ={threshold:.6f} | " f"mean={mean_err:.6f} std={std_err:.6f} | " f"样本数={len(errors)}") return threshold def _eval_model_error(self, model: nn.Module, mel_dir: Path) -> float: """在验证数据上计算模型的平均重建误差,用于新旧模型对比""" device = next(model.parameters()).device model.eval() dataset = MelNPYDataset(mel_dir) if len(dataset) == 0: return float('inf') dataloader = torch.utils.data.DataLoader( dataset, batch_size=64, shuffle=False ) all_errors = [] with torch.no_grad(): for batch in dataloader: batch = batch.to(device) output = model(batch) output = align_to_target(output, batch) mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3]) all_errors.append(mse.cpu().numpy()) errors = np.concatenate(all_errors) return float(np.mean(errors)) # ======================================== # 全量训练入口 # ======================================== def run_full_training(self, data_dir: Path, epochs: Optional[int] = None, lr: Optional[float] = None, device_filter: Optional[List[str]] = None) -> bool: """ 全量训练入口:每设备从零训练独立模型 流程: 1. 扫描外部数据目录 2. 每设备独立提取 Mel 特征+标准化参数 3. 每设备独立训练模型 4. 每设备独立部署(模型+标准化+阈值) Args: data_dir: 数据目录路径 epochs: 训练轮数(None=使用配置文件值) lr: 学习率(None=使用配置文件值) device_filter: 只训练指定设备(None=全部) Returns: bool: 是否成功 """ try: epochs = epochs or self.config['auto_training']['incremental']['epochs'] lr = lr or self.config['auto_training']['incremental']['learning_rate'] logger.info("=" * 60) logger.info(f"全量训练 | 数据目录: {data_dir}") logger.info(f"参数: epochs={epochs}, lr={lr}") if device_filter: logger.info(f"设备过滤: {device_filter}") logger.info("=" * 60) # 1. 收集数据 device_files = self.collect_from_external_dir(data_dir, device_filter) if not device_files: logger.error("无可用训练数据") return False # 2. 每设备提取特征 device_results = self.prepare_mel_features_per_device(device_files) if not device_results: logger.error("特征提取失败") return False # 3. 每设备独立训练+部署 success_count = 0 fail_count = 0 for device_code, (mel_dir, scale) in device_results.items(): try: logger.info(f"\n--- 训练设备: {device_code} ---") # 训练 model, final_loss = self.train_single_device( device_code, mel_dir, epochs, lr, from_scratch=True ) # 验证 if not self._validate_model(model): logger.error(f" {device_code}: 模型验证失败,跳过部署") fail_count += 1 continue # 部署到 models/{device_code}/ self.deploy_device_model(device_code, model, scale, mel_dir) success_count += 1 logger.info(f" {device_code}: 训练+部署完成 | loss={final_loss:.6f}") except Exception as e: logger.error(f" {device_code}: 训练失败 | {e}", exc_info=True) fail_count += 1 logger.info("=" * 60) logger.info(f"全量训练完成: 成功={success_count}, 失败={fail_count}") logger.info("=" * 60) return fail_count == 0 except Exception as e: logger.error(f"全量训练异常: {e}", exc_info=True) return False finally: # 清理临时文件 if self.temp_mel_dir.exists(): shutil.rmtree(self.temp_mel_dir) # ======================================== # 增量训练入口(保留兼容) # ======================================== def run_daily_training(self, on_device_trained=None) -> bool: """ 执行每日增量训练 — 逐设备串行处理 流程(每个设备完整走完再处理下一个,降低内存/CPU 峰值): 1. 收集所有设备的文件列表(仅路径,开销极低) 2. 备份模型 3. 逐设备:提取特征 → 训练 → 验证 → 部署 → 清理临时文件 → 回调通知 4. 更新分类器基线 Args: on_device_trained: 可选回调 fn(device_code: str), 每个设备训练+部署成功后调用, 用于即时触发该设备的模型热重载 Returns: bool: 是否至少有一个设备成功 """ try: days_ago = (self.use_days_ago if self.use_days_ago is not None else self.config['auto_training']['incremental']['use_days_ago']) target_date = (datetime.now() - timedelta(days=days_ago)).strftime('%Y%m%d') mode_str = "冷启动训练" if self.cold_start_mode else "增量训练" logger.info("=" * 60) logger.info(f"{mode_str} - {target_date}") logger.info("=" * 60) # 1. 收集数据(仅文件路径,不加载音频,开销极低) device_files = self.collect_training_data(target_date) total = sum(len(f) for f in device_files.values()) min_samples = self.config['auto_training']['incremental']['min_samples'] if total < min_samples: logger.warning(f"数据不足 ({total} < {min_samples}),跳过") return False # 2. 备份模型 if self.config['auto_training']['model']['backup_before_train']: self.backup_model(target_date) # 3. 训练参数 epochs = (self.epochs if self.epochs is not None else self.config['auto_training']['incremental']['epochs']) lr = (self.learning_rate if self.learning_rate is not None else self.config['auto_training']['incremental']['learning_rate']) from_scratch = self.cold_start_mode model_cfg = self.config['auto_training']['model'] rollback_enabled = model_cfg.get('rollback_on_degradation', True) rollback_factor = model_cfg.get('rollback_factor', 2.0) # 4. 逐设备串行处理:提取 → 训练 → 部署 → 清理 success_count = 0 degraded_count = 0 device_count = len(device_files) for idx, (device_code, wav_files) in enumerate(device_files.items(), 1): logger.info(f"\n{'='*40}") logger.info(f"[{idx}/{device_count}] 设备: {device_code}") logger.info(f"{'='*40}") try: # ── 4a. 单设备特征提取 ── mel_dir, scale = self._extract_mel_for_device( device_code, wav_files ) if mel_dir is None: logger.warning(f"{device_code}: 特征提取无有效数据,跳过") continue # ── 4b. 训练 ── model, final_loss = self.train_single_device( device_code, mel_dir, epochs, lr, from_scratch=from_scratch ) # ── 4c. 形状验证 ── if not self._validate_model(model): logger.error(f"{device_code}: 形状验证失败,跳过部署") continue # ── 4d. 新旧模型对比(增量训练时生效) ── # 在相同验证数据上比较新旧模型的重建误差,新模型更差则跳过部署 if rollback_enabled and not self.cold_start_mode: old_model_path = self.model_root / device_code / "ae_model.pth" if old_model_path.exists(): new_avg_err = self._eval_model_error(model, mel_dir) old_model = ConvAutoencoder() old_model.load_state_dict( torch.load(old_model_path, map_location='cpu') ) old_avg_err = self._eval_model_error(old_model, mel_dir) logger.info( f" {device_code}: 新旧模型对比 | " f"旧模型误差={old_avg_err:.6f} 新模型误差={new_avg_err:.6f}" ) if new_avg_err > old_avg_err * rollback_factor: logger.warning( f"{device_code}: 新模型退化 | " f"新={new_avg_err:.6f} > 旧={old_avg_err:.6f} × {rollback_factor},跳过部署" ) degraded_count += 1 continue # ── 4e. 部署 ── if model_cfg.get('auto_deploy', True): self.deploy_device_model(device_code, model, scale, mel_dir) success_count += 1 logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}") # ── 4f. 清理已参与训练的 verified_normal 目录 ── # 核查确认的音频已被模型吸收,训练后清空释放磁盘空间 verified_dir = self.audio_root / device_code / "verified_normal" if verified_dir.exists(): v_count = len(list(verified_dir.glob("*"))) if v_count > 0: shutil.rmtree(verified_dir) verified_dir.mkdir(parents=True, exist_ok=True) logger.info(f"{device_code}: 已清理 verified_normal ({v_count} 个文件)") # ── 4g. 即时通知该设备模型重载 ── if on_device_trained: try: on_device_trained(device_code) except Exception as e: logger.warning(f"{device_code}: 模型重载回调失败 | {e}") except Exception as e: logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True) finally: # ── 4h. 清理该设备的临时 Mel 文件,释放磁盘空间 ── device_mel_dir = self.temp_mel_dir / device_code if device_mel_dir.exists(): shutil.rmtree(device_mel_dir) # 5. 如果所有设备都退化,整体回滚到训练前备份 if degraded_count > 0 and success_count == 0: logger.error( f"所有设备训练后损失退化({degraded_count}个),执行整体回滚" ) self.restore_backup(target_date) return False if degraded_count > 0: logger.warning( f"{degraded_count} 个设备因损失退化跳过部署," f"{success_count} 个设备部署成功" ) logger.info("=" * 60) logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功") if degraded_count > 0: logger.info(f" 其中 {degraded_count} 个设备因损失退化跳过") logger.info("=" * 60) return success_count > 0 except Exception as e: logger.error(f"训练失败: {e}", exc_info=True) return False finally: if self.temp_mel_dir.exists(): shutil.rmtree(self.temp_mel_dir) # ======================================== # 辅助方法 # ======================================== def _validate_model(self, model: nn.Module) -> bool: # 验证模型输出形状是否合理 if not self.config['auto_training']['validation']['enabled']: return True try: device = next(model.parameters()).device test_input = torch.randn(1, 1, CFG.N_MELS, CFG.TARGET_FRAMES).to(device) with torch.no_grad(): output = model(test_input) h_diff = abs(output.shape[2] - test_input.shape[2]) w_diff = abs(output.shape[3] - test_input.shape[3]) if h_diff > 8 or w_diff > 8: logger.error(f"形状差异过大: {output.shape} vs {test_input.shape}") return False return True except Exception as e: logger.error(f"验证失败: {e}") return False def backup_model(self, date_tag: str): """ 完整备份当前所有设备的模型 备份目录结构: backups/{date_tag}/{device_code}/ ├── ae_model.pth ├── global_scale.npy └── thresholds/ """ backup_date_dir = self.backup_dir / date_tag backup_date_dir.mkdir(parents=True, exist_ok=True) backed_up = 0 # 遍历 models/ 下的所有设备子目录 for device_dir in self.model_root.iterdir(): if not device_dir.is_dir(): continue # 跳过 backups 目录本身 if device_dir.name == "backups": continue # 检查是否包含模型文件(判断是否为设备目录) if not (device_dir / "ae_model.pth").exists(): continue device_backup = backup_date_dir / device_dir.name # 递归复制整个设备目录 shutil.copytree(device_dir, device_backup, dirs_exist_ok=True) backed_up += 1 logger.info(f"备份完成: {backed_up} 个设备 -> {backup_date_dir}") # 清理旧备份 keep = self.config['auto_training']['model']['keep_backups'] backup_dirs = sorted( [d for d in self.backup_dir.iterdir() if d.is_dir() and d.name.isdigit()], reverse=True ) for old_dir in backup_dirs[keep:]: shutil.rmtree(old_dir) logger.info(f"已删除旧备份: {old_dir.name}") def restore_backup(self, date_tag: str) -> bool: """ 从备份恢复所有设备的模型 Args: date_tag: 备份日期标签 'YYYYMMDD' Returns: bool: 是否恢复成功 """ backup_date_dir = self.backup_dir / date_tag if not backup_date_dir.exists(): logger.error(f"备份目录不存在: {backup_date_dir}") return False logger.info(f"从备份恢复: {date_tag}") restored = 0 for device_backup in backup_date_dir.iterdir(): if not device_backup.is_dir(): continue target_dir = self.model_root / device_backup.name # 递归复制恢复 shutil.copytree(device_backup, target_dir, dirs_exist_ok=True) restored += 1 logger.info(f"恢复完成: {restored} 个设备") return restored > 0 def main(): # 命令行入口(增量训练) logging.basicConfig( level=logging.INFO, format='%(asctime)s | %(levelname)-8s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml" trainer = IncrementalTrainer(config_file) success = trainer.run_daily_training() sys.exit(0 if success else 1) if __name__ == "__main__": main()