| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- incremental_trainer.py
- ----------------------
- 模型训练模块
- 功能:
- 1. 支持全量训练(指定外部数据目录,每设备独立模型)
- 2. 支持增量训练(每日自动训练)
- 3. 滑动窗口提取Mel特征(8秒patches)
- 4. 每设备独立计算标准化参数和阈值
- 5. 产出目录结构与推理端 MultiModelPredictor 对齐
- 产出目录结构:
- models/
- ├── {device_code_1}/
- │ ├── ae_model.pth
- │ ├── global_scale.npy
- │ └── thresholds/
- │ └── threshold_{device_code_1}.npy
- └── {device_code_2}/
- ├── ae_model.pth
- ├── global_scale.npy
- └── thresholds/
- └── threshold_{device_code_2}.npy
- """
- import sys
- import random
- import shutil
- import logging
- import numpy as np
- import torch
- import torch.nn as nn
- from pathlib import Path
- from datetime import datetime, timedelta
- from typing import List, Dict, Tuple, Optional
- import yaml
- # 添加父目录到路径
- sys.path.insert(0, str(Path(__file__).parent.parent))
- from predictor import CFG
- from predictor.model_def import ConvAutoencoder
- from predictor.datasets import MelNPYDataset
- from predictor.utils import align_to_target
- logger = logging.getLogger('IncrementalTrainer')
class IncrementalTrainer:
    """
    Model trainer.

    Supports two training modes:
    1. Full training: external data directory, one model trained from
       scratch per device.
    2. Incremental training: fine-tunes existing models with data
       collected at runtime (legacy-compatible).
    """

    def __init__(self, config_file: Path = None, config: dict = None):
        """Initialize from either a config dict or a YAML config file.

        Two initialization styles are supported:
        1. pass ``config`` (dict loaded from the database; main program);
        2. pass ``config_file`` (YAML path; standalone tools such as
           standalone_train.py).

        Raises:
            ValueError: when neither argument is supplied.
        """
        if config is None and config_file is None:
            raise ValueError("必须提供 config_file 或 config 之一")
        if config is not None:
            self.config = config
            self.config_file = None
        else:
            self.config_file = config_file
            self.config = self._load_config()

        # Path layout, anchored at the deployment root.
        self.deploy_root = Path(__file__).parent.parent
        self.audio_root = self.deploy_root / "data" / "audio"
        # Parent directory of all per-device model sub-directories.
        self.model_root = self.deploy_root / "models"
        self.backup_dir = self.model_root / "backups"
        # Scratch space for extracted Mel features.
        self.temp_mel_dir = self.deploy_root / "data" / "temp_mels"
        # Make sure the backup directory exists up front.
        self.backup_dir.mkdir(parents=True, exist_ok=True)

        # Runtime-overridable knobs (used by cold start).
        self.use_days_ago = None
        self.sample_hours = None
        self.epochs = None
        self.learning_rate = None
        self.cold_start_mode = False
- def _load_config(self) -> Dict:
- # 从 YAML 文件加载配置(仅 config_file 模式使用)
- with open(self.config_file, 'r', encoding='utf-8') as f:
- return yaml.safe_load(f)
- # ========================================
- # 数据收集
- # ========================================
- def collect_from_external_dir(self, data_dir: Path,
- device_filter: Optional[List[str]] = None
- ) -> Dict[str, List[Path]]:
- """
- 从外部数据目录收集训练数据
- 目录结构约定:
- data_dir/
- ├── LT-2/ <-- 子文件夹名 = device_code
- │ ├── xxx.wav
- │ └── ...
- └── LT-5/
- ├── yyy.wav
- └── ...
- 支持两种子目录结构:
- 1. 扁平结构:data_dir/{device_code}/*.wav
- 2. 日期嵌套:data_dir/{device_code}/{YYYYMMDD}/*.wav
- Args:
- data_dir: 外部数据目录路径
- device_filter: 只训练指定设备(None=全部)
- Returns:
- {device_code: [wav_files]} 按设备分组的音频文件列表
- """
- data_dir = Path(data_dir)
- if not data_dir.exists():
- raise FileNotFoundError(f"数据目录不存在: {data_dir}")
- logger.info(f"扫描外部数据目录: {data_dir}")
- device_files = {}
- for sub_dir in sorted(data_dir.iterdir()):
- # 跳过非目录
- if not sub_dir.is_dir():
- continue
- device_code = sub_dir.name
- # 应用设备过滤
- if device_filter and device_code not in device_filter:
- logger.info(f" 跳过设备(不在过滤列表中): {device_code}")
- continue
- audio_files = []
- # 扁平结构:直接查找 wav 文件
- audio_files.extend(list(sub_dir.glob("*.wav")))
- audio_files.extend(list(sub_dir.glob("*.mp4")))
- # 日期嵌套结构:查找子目录中的 wav 文件
- for nested_dir in sub_dir.iterdir():
- if nested_dir.is_dir():
- audio_files.extend(list(nested_dir.glob("*.wav")))
- audio_files.extend(list(nested_dir.glob("*.mp4")))
- # 去重
- audio_files = list(set(audio_files))
- if audio_files:
- device_files[device_code] = audio_files
- logger.info(f" {device_code}: {len(audio_files)} 个音频文件")
- else:
- logger.warning(f" {device_code}: 无音频文件,跳过")
- total = sum(len(f) for f in device_files.values())
- logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
- return device_files
    def collect_training_data(self, target_date: str) -> Dict[str, List[Path]]:
        """
        Collect training data from the internal data directory (incremental training).

        Args:
            target_date: date string 'YYYYMMDD'

        Returns:
            {device_code: [wav_files]} audio files grouped by device
        """
        logger.info(f"收集 {target_date} 的训练数据")
        sample_hours = self.config['auto_training']['incremental'].get('sample_hours', 0)
        device_files = {}
        if not self.audio_root.exists():
            logger.warning(f"音频目录不存在: {self.audio_root}")
            return device_files
        for device_dir in self.audio_root.iterdir():
            if not device_dir.is_dir():
                continue
            device_code = device_dir.name
            audio_files = []
            # Cold-start mode: gather normal audio from every archived date
            # directory (skips current/).
            if self.cold_start_mode:
                # NOTE: current/ is skipped because it may contain incomplete
                # files that FFmpeg is still writing.
                for sub_dir in device_dir.iterdir():
                    if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
                        # New layout: read from the {date}/normal/ sub-directory.
                        normal_dir = sub_dir / "normal"
                        if normal_dir.exists():
                            audio_files.extend(list(normal_dir.glob("*.wav")))
                            audio_files.extend(list(normal_dir.glob("*.mp4")))
                        # Legacy layout: audio stored directly in the date directory.
                        audio_files.extend(list(sub_dir.glob("*.wav")))
                        audio_files.extend(list(sub_dir.glob("*.mp4")))
            else:
                # Normal mode: only collect normal audio for the target date.
                date_dir = device_dir / target_date
                if date_dir.exists():
                    # New layout: read from the {date}/normal/ sub-directory.
                    normal_dir = date_dir / "normal"
                    if normal_dir.exists():
                        audio_files.extend(list(normal_dir.glob("*.wav")))
                        audio_files.extend(list(normal_dir.glob("*.mp4")))
                    # Legacy layout: audio stored directly in the date directory.
                    audio_files.extend(list(date_dir.glob("*.wav")))
                    audio_files.extend(list(date_dir.glob("*.mp4")))
            # Also pick up the verified_normal directory (kept separate: it is
            # excluded from sampling and the quality pre-filter below).
            verified_dir = device_dir / "verified_normal"
            verified_files = []
            if verified_dir.exists():
                verified_files.extend(list(verified_dir.glob("*.wav")))
                verified_files.extend(list(verified_dir.glob("*.mp4")))
            # De-duplicate (date-directory audio only).
            audio_files = list(set(audio_files))
            # Quality pre-filter: date-directory audio only; verified_normal is
            # human-confirmed and therefore skipped.
            if audio_files and not self.cold_start_mode:
                before_count = len(audio_files)
                audio_files = self._filter_audio_quality(audio_files, device_code)
                filtered = before_count - len(audio_files)
                if filtered > 0:
                    logger.info(f" {device_code}: 质量预筛过滤 {filtered} 个异常音频")
            # Random sampling (date-directory audio only; verified_normal excluded).
            if sample_hours > 0 and audio_files:
                # Assumes roughly one file per minute — TODO confirm recording cadence.
                files_needed = int(sample_hours * 3600 / 60)
                if len(audio_files) > files_needed:
                    audio_files = random.sample(audio_files, files_needed)
                    logger.info(f" {device_code}: 随机采样 {len(audio_files)} 个音频")
                else:
                    logger.info(f" {device_code}: {len(audio_files)} 个音频(全部使用)")
            else:
                logger.info(f" {device_code}: {len(audio_files)} 个音频")
            # Merge verified_normal back in (appended after sampling so all of
            # it always participates in training).
            if verified_files:
                audio_files.extend(verified_files)
                logger.info(f" {device_code}: +{len(verified_files)} 个核查确认音频(verified_normal)")
            if audio_files:
                device_files[device_code] = audio_files
        total = sum(len(f) for f in device_files.values())
        logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
        return device_files
- def _filter_audio_quality(self, audio_files: List[Path],
- device_code: str) -> List[Path]:
- """
- 音频质量预筛:基于 RMS 能量和频谱质心过滤明显异常的样本
- 使用 IQR (四分位距) 方法检测离群值:
- - 计算所有文件的 RMS 能量和频谱质心
- - 过滤超出 [Q1 - 2*IQR, Q3 + 2*IQR] 范围的文件
- 需要至少 10 个文件才执行过滤,否则样本太少无统计意义。
- Args:
- audio_files: 待过滤的音频文件列表
- device_code: 设备编码(用于日志)
- Returns:
- 过滤后的文件列表
- """
- if len(audio_files) < 10:
- return audio_files
- import librosa
- # 快速计算每个文件的 RMS 能量
- rms_values = []
- valid_files = []
- for wav_file in audio_files:
- try:
- y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True,
- duration=10) # 只读前10秒加速
- if len(y) < CFG.SR:
- continue
- rms = float(np.sqrt(np.mean(y ** 2)))
- rms_values.append(rms)
- valid_files.append(wav_file)
- except Exception:
- continue
- if len(rms_values) < 10:
- return audio_files
- # IQR 离群值检测
- rms_arr = np.array(rms_values)
- q1, q3 = np.percentile(rms_arr, [25, 75])
- iqr = q3 - q1
- lower_bound = q1 - 2 * iqr
- upper_bound = q3 + 2 * iqr
- filtered = []
- for f, rms in zip(valid_files, rms_values):
- if lower_bound <= rms <= upper_bound:
- filtered.append(f)
- return filtered
- # ========================================
- # 特征提取(每设备独立标准化参数)
- # ========================================
    def _extract_mel_for_device(self, device_code: str,
                                wav_files: List[Path]
                                ) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
        """
        Extract Mel features for a single device and compute its own
        Min-Max normalization parameters.

        Streaming two-pass scan (memory optimization):
        1. First pass: only track the running min/max (O(1) memory),
           without keeping the mel_db patches.
        2. Second pass: normalize with the pass-1 min/max and write the
           .npy files directly.

        Note: pass 2 recomputes every Mel spectrogram — this deliberately
        trades CPU time for a flat memory profile.

        Args:
            device_code: device code
            wav_files: audio files belonging to this device

        Returns:
            (mel_dir, (global_min, global_max)); (None, None) on failure.
        """
        import librosa
        # Sliding-window parameters, in samples.
        win_samples = int(CFG.WIN_SEC * CFG.SR)
        hop_samples = int(CFG.HOP_SEC * CFG.SR)

        def _iter_mel_patches(files):
            """Generator: yield mel_db file-by-file, patch-by-patch, to avoid
            loading everything into memory at once."""
            for wav_file in files:
                try:
                    y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
                    if len(y) < CFG.SR:
                        # Shorter than 1 second: skip.
                        continue
                    for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)):
                        segment = y[start:start + win_samples]
                        mel_spec = librosa.feature.melspectrogram(
                            y=segment, sr=CFG.SR, n_fft=CFG.N_FFT,
                            hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0
                        )
                        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
                        # Align the frame count to the model's expected width.
                        if mel_db.shape[1] < CFG.TARGET_FRAMES:
                            pad = CFG.TARGET_FRAMES - mel_db.shape[1]
                            mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant")
                        else:
                            mel_db = mel_db[:, :CFG.TARGET_FRAMES]
                        yield wav_file, idx, mel_db
                except Exception as e:
                    logger.warning(f"跳过文件 {wav_file.name}: {e}")
                    continue

        # ── Pass 1: streaming running min/max (O(1) memory) ──
        global_min = float('inf')
        global_max = float('-inf')
        patch_count = 0
        for _, _, mel_db in _iter_mel_patches(wav_files):
            local_min = float(mel_db.min())
            local_max = float(mel_db.max())
            if local_min < global_min:
                global_min = local_min
            if local_max > global_max:
                global_max = local_max
            patch_count += 1
        if patch_count == 0:
            logger.warning(f" {device_code}: 无有效数据")
            return None, None
        logger.info(f" {device_code}: {patch_count} patches | "
                    f"min={global_min:.4f} max={global_max:.4f}")

        # ── Pass 2: Min-Max normalize and save ──
        device_mel_dir = self.temp_mel_dir / device_code
        device_mel_dir.mkdir(parents=True, exist_ok=True)
        # Epsilon guards against division by zero when min == max.
        scale_range = global_max - global_min + 1e-6
        for wav_file, idx, mel_db in _iter_mel_patches(wav_files):
            mel_norm = (mel_db - global_min) / scale_range
            npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy"
            np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32))
        return device_mel_dir, (global_min, global_max)
- def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]]
- ) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
- """
- 为每个设备独立提取 Mel 特征
- 每设备分别计算自己的 Min-Max 标准化参数 (global_min, global_max)
- Args:
- device_files: {device_code: [wav_files]}
- Returns:
- {device_code: (mel_dir, (global_min, global_max))}
- """
- logger.info("提取 Mel 特征(每设备独立标准化)")
- # 清空临时目录
- if self.temp_mel_dir.exists():
- shutil.rmtree(self.temp_mel_dir)
- self.temp_mel_dir.mkdir(parents=True, exist_ok=True)
- device_results = {}
- for device_code, wav_files in device_files.items():
- mel_dir, scale = self._extract_mel_for_device(device_code, wav_files)
- if mel_dir is not None:
- device_results[device_code] = (mel_dir, scale)
- total_patches = sum(
- len(list(mel_dir.glob("*.npy")))
- for mel_dir, _ in device_results.values()
- )
- logger.info(f"提取完成: {total_patches} patches,{len(device_results)} 个设备")
- return device_results
- # ========================================
- # 模型训练(每设备独立)
- # ========================================
- def _select_training_device(self) -> torch.device:
- # 智能选择训练设备:GPU 显存充足则使用,否则回退 CPU
- # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda)
- training_cfg = self.config['auto_training']['incremental']
- forced_device = training_cfg.get('training_device', 'auto')
- # 强制指定设备时直接返回
- if forced_device == 'cpu':
- logger.info("训练设备: CPU(配置强制指定)")
- return torch.device('cpu')
- if forced_device == 'cuda':
- if torch.cuda.is_available():
- return torch.device('cuda')
- logger.warning("配置指定 CUDA 但不可用,回退 CPU")
- return torch.device('cpu')
- # auto 模式:检测 CUDA → CPU
- if torch.cuda.is_available():
- try:
- free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024)
- min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
- if free_mem >= min_gpu_mem_mb:
- logger.info(f"训练设备: CUDA(空闲显存 {free_mem:.0f}MB)")
- return torch.device('cuda')
- logger.info(
- f"CUDA 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
- )
- except Exception as e:
- logger.warning(f"CUDA 显存检测失败: {e}")
- logger.info("训练设备: CPU")
- return torch.device('cpu')
    def _run_training_loop(self, device_code: str, model: nn.Module,
                           train_loader, val_loader, epochs: int, lr: float,
                           device: torch.device) -> Tuple[nn.Module, float]:
        """Run the actual training loop, decoupled from device selection.

        Early stopping monitors the validation loss when a validation
        loader is provided, otherwise the training loss.

        Args:
            device_code: device code (used for logging only)
            model: autoencoder to train
            train_loader: training DataLoader
            val_loader: validation DataLoader, or None
            epochs: maximum number of epochs
            lr: learning rate
            device: torch device to train on

        Returns:
            (trained model, average training loss of the last epoch run)
        """
        model = model.to(device)
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()
        # AMP mixed precision (GPU only; cuts memory use by roughly 40%).
        use_amp = device.type == 'cuda'
        scaler = torch.amp.GradScaler(device.type) if use_amp else None
        # Early-stopping configuration.
        early_stop_cfg = self.config['auto_training']['incremental']
        patience = early_stop_cfg.get('early_stop_patience', 5)
        best_loss = float('inf')
        no_improve_count = 0
        avg_loss = 0.0
        actual_epochs = 0
        for epoch in range(epochs):
            # ── Training phase ──
            model.train()
            epoch_loss = 0.0
            batch_count = 0
            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()
                if use_amp:
                    with torch.amp.autocast(device.type):
                        output = model(batch)
                        output = align_to_target(output, batch)
                        loss = criterion(output, batch)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    output = model(batch)
                    output = align_to_target(output, batch)
                    loss = criterion(output, batch)
                    loss.backward()
                    optimizer.step()
                epoch_loss += loss.item()
                batch_count += 1
            # NOTE(review): assumes train_loader yields at least one batch;
            # an empty loader would raise ZeroDivisionError here — confirm
            # callers guarantee non-empty datasets.
            avg_loss = epoch_loss / batch_count
            actual_epochs = epoch + 1
            # ── Validation phase (when a validation set exists) ──
            if val_loader is not None:
                model.eval()
                val_loss = 0.0
                val_count = 0
                with torch.no_grad():
                    for batch in val_loader:
                        batch = batch.to(device)
                        if use_amp:
                            with torch.amp.autocast(device.type):
                                output = model(batch)
                                output = align_to_target(output, batch)
                                loss = criterion(output, batch)
                        else:
                            output = model(batch)
                            output = align_to_target(output, batch)
                            loss = criterion(output, batch)
                        val_loss += loss.item()
                        val_count += 1
                avg_val_loss = val_loss / val_count
                monitor_loss = avg_val_loss  # early stopping tracks validation loss
            else:
                avg_val_loss = None
                monitor_loss = avg_loss  # no validation set: fall back to training loss
            if actual_epochs % 10 == 0 or epoch == epochs - 1:
                val_str = f" | ValLoss: {avg_val_loss:.6f}" if avg_val_loss is not None else ""
                logger.info(f" [{device_code}] Epoch {actual_epochs}/{epochs} | "
                            f"Loss: {avg_loss:.6f}{val_str} | device={device.type}")
            # Early stopping: abort after `patience` consecutive epochs with
            # no improvement (but never before epoch 10).
            if monitor_loss < best_loss:
                best_loss = monitor_loss
                no_improve_count = 0
            else:
                no_improve_count += 1
                if no_improve_count >= patience and actual_epochs >= 10:
                    logger.info(f" [{device_code}] 早停触发: 连续{patience}轮无改善 | "
                                f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}")
                    break
        # Release cached GPU memory after training.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        if actual_epochs < epochs:
            logger.info(f" [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练")
        return model, avg_loss
    def train_single_device(self, device_code: str, mel_dir: Path,
                            epochs: int, lr: float,
                            from_scratch: bool = True
                            ) -> Tuple[nn.Module, float]:
        """Train the independent model for a single device.

        Strategy: prefer GPU training, fall back to CPU when GPU memory is
        insufficient; an OOM raised mid-training is also caught and the run
        is retried on CPU.

        Args:
            device_code: device code
            mel_dir: directory with this device's Mel feature .npy files
            epochs: number of training epochs
            lr: learning rate
            from_scratch: True = full training, False = fine-tune existing model

        Returns:
            (trained model, final average loss)

        Raises:
            ValueError: when the device has no training data.
        """
        logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, "
                    f"mode={'全量' if from_scratch else '增量'}")
        # Smart device selection (GPU when memory allows, else CPU).
        device = self._select_training_device()
        # Clear the CUDA cache first to release fragments left by inference.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            import gc
            gc.collect()
        model = ConvAutoencoder()
        # Incremental mode: start from the existing weights when present.
        if not from_scratch:
            model_path = self.model_root / device_code / "ae_model.pth"
            if model_path.exists():
                model.load_state_dict(torch.load(model_path, map_location='cpu'))
                logger.info(f" 已加载已有模型: {model_path}")
            else:
                logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}")
        # Load the data and split 80/20 into train/validation sets.
        dataset = MelNPYDataset(mel_dir)
        if len(dataset) == 0:
            raise ValueError(f"设备 {device_code} 无训练数据")
        batch_size = self.config['auto_training']['incremental']['batch_size']
        # Validation split only with >= 20 samples (fewer is statistically
        # meaningless).
        val_loader = None
        if len(dataset) >= 20:
            val_size = max(1, int(len(dataset) * 0.2))
            train_size = len(dataset) - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True,
                num_workers=0, pin_memory=False
            )
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=batch_size, shuffle=False,
                num_workers=0, pin_memory=False
            )
            logger.info(f" 数据集划分: 训练={train_size}, 验证={val_size}")
        else:
            train_loader = torch.utils.data.DataLoader(
                dataset, batch_size=batch_size, shuffle=True,
                num_workers=0, pin_memory=False
            )
            logger.info(f" 数据量不足20,跳过验证集划分(共{len(dataset)}样本)")
        # Attempt training on the selected device.
        if device.type == 'cuda':
            try:
                return self._run_training_loop(
                    device_code, model, train_loader, val_loader,
                    epochs, lr, device
                )
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                # GPU OOM -> clear the cache and retry on CPU.
                if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError):
                    raise  # re-raise non-OOM RuntimeErrors untouched
                logger.warning(
                    f" [{device_code}] CUDA OOM,清理显存后回退 CPU 训练"
                )
                import gc
                gc.collect()
                torch.cuda.empty_cache()
                # The model may be in a dirty state: re-initialize it.
                model = ConvAutoencoder()
                if not from_scratch:
                    model_path = self.model_root / device_code / "ae_model.pth"
                    if model_path.exists():
                        model.load_state_dict(
                            torch.load(model_path, map_location='cpu')
                        )
                return self._run_training_loop(
                    device_code, model, train_loader, val_loader,
                    epochs, lr, torch.device('cpu')
                )
        else:
            # CPU training needs no OOM protection.
            return self._run_training_loop(
                device_code, model, train_loader, val_loader,
                epochs, lr, device
            )
- # ========================================
- # 产出部署(每设备独立目录)
- # ========================================
    def deploy_device_model(self, device_code: str, model: nn.Module,
                            scale: Tuple[float, float], mel_dir: Path):
        """
        Deploy one device's model into the models/{device_code}/ directory.

        Produced files:
        - models/{device_code}/ae_model.pth
        - models/{device_code}/global_scale.npy  → [min, max]
          (the original comment said [mean, std], but the code saves the
          Min-Max pair — see the np.save call below)
        - models/{device_code}/thresholds/threshold_{device_code}.npy

        Args:
            device_code: device code
            model: trained model
            scale: (global_min, global_max) Min-Max normalization parameters
            mel_dir: this device's Mel feature directory (used for the threshold)
        """
        # Create the device's model directory.
        device_model_dir = self.model_root / device_code
        device_model_dir.mkdir(parents=True, exist_ok=True)
        # 1. Save the model weights.
        model_path = device_model_dir / "ae_model.pth"
        torch.save(model.state_dict(), model_path)
        logger.info(f" 模型已保存: {model_path}")
        # 2. Save the Min-Max normalization parameters as [min, max].
        scale_path = device_model_dir / "global_scale.npy"
        np.save(scale_path, np.array([scale[0], scale[1]]))
        logger.info(f" 标准化参数已保存: {scale_path}")
        # 3. Compute and save the anomaly threshold.
        threshold_dir = device_model_dir / "thresholds"
        threshold_dir.mkdir(parents=True, exist_ok=True)
        threshold = self._compute_threshold(model, mel_dir)
        threshold_file = threshold_dir / f"threshold_{device_code}.npy"
        np.save(threshold_file, threshold)
        logger.info(f" 阈值已保存: {threshold_file} (值={threshold:.6f})")
- def _compute_threshold(self, model: nn.Module, mel_dir: Path) -> float:
- """
- 计算单个设备的阈值
- 使用 3σ 法则:threshold = mean + 3 * std
- Args:
- model: 训练后的模型
- mel_dir: 该设备的 Mel 特征目录
- Returns:
- 阈值(标量)
- """
- device = next(model.parameters()).device
- model.eval()
- dataset = MelNPYDataset(mel_dir)
- if len(dataset) == 0:
- logger.warning("无数据计算阈值,使用默认值 0.01")
- return 0.01
- dataloader = torch.utils.data.DataLoader(
- dataset, batch_size=64, shuffle=False
- )
- all_errors = []
- with torch.no_grad():
- for batch in dataloader:
- batch = batch.to(device)
- output = model(batch)
- output = align_to_target(output, batch)
- mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
- all_errors.append(mse.cpu().numpy())
- errors = np.concatenate(all_errors)
- # 3σ 法则
- mean_err = float(np.mean(errors))
- std_err = float(np.std(errors))
- threshold = mean_err + 3 * std_err
- logger.info(f" 阈值统计: 3σ={threshold:.6f} | "
- f"mean={mean_err:.6f} std={std_err:.6f} | "
- f"样本数={len(errors)}")
- return threshold
- def _eval_model_error(self, model: nn.Module, mel_dir: Path) -> float:
- """在验证数据上计算模型的平均重建误差,用于新旧模型对比"""
- device = next(model.parameters()).device
- model.eval()
- dataset = MelNPYDataset(mel_dir)
- if len(dataset) == 0:
- return float('inf')
- dataloader = torch.utils.data.DataLoader(
- dataset, batch_size=64, shuffle=False
- )
- all_errors = []
- with torch.no_grad():
- for batch in dataloader:
- batch = batch.to(device)
- output = model(batch)
- output = align_to_target(output, batch)
- mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
- all_errors.append(mse.cpu().numpy())
- errors = np.concatenate(all_errors)
- return float(np.mean(errors))
- # ========================================
- # 全量训练入口
- # ========================================
    def run_full_training(self, data_dir: Path,
                          epochs: Optional[int] = None,
                          lr: Optional[float] = None,
                          device_filter: Optional[List[str]] = None) -> bool:
        """
        Full-training entry point: train an independent model per device,
        from scratch.

        Pipeline:
        1. Scan the external data directory.
        2. Per device, extract Mel features + normalization parameters.
        3. Per device, train a model.
        4. Per device, deploy (model + normalization + threshold).

        Args:
            data_dir: data directory path
            epochs: number of epochs (None = use the config value)
            lr: learning rate (None = use the config value)
            device_filter: only train these devices (None = all)

        Returns:
            bool: True only when every device succeeded.
        """
        try:
            epochs = epochs or self.config['auto_training']['incremental']['epochs']
            lr = lr or self.config['auto_training']['incremental']['learning_rate']
            logger.info("=" * 60)
            logger.info(f"全量训练 | 数据目录: {data_dir}")
            logger.info(f"参数: epochs={epochs}, lr={lr}")
            if device_filter:
                logger.info(f"设备过滤: {device_filter}")
            logger.info("=" * 60)
            # 1. Collect the data.
            device_files = self.collect_from_external_dir(data_dir, device_filter)
            if not device_files:
                logger.error("无可用训练数据")
                return False
            # 2. Extract features per device.
            device_results = self.prepare_mel_features_per_device(device_files)
            if not device_results:
                logger.error("特征提取失败")
                return False
            # 3. Train + deploy each device independently.
            success_count = 0
            fail_count = 0
            for device_code, (mel_dir, scale) in device_results.items():
                try:
                    logger.info(f"\n--- 训练设备: {device_code} ---")
                    # Train.
                    model, final_loss = self.train_single_device(
                        device_code, mel_dir, epochs, lr, from_scratch=True
                    )
                    # Validate the output shape before deploying.
                    if not self._validate_model(model):
                        logger.error(f" {device_code}: 模型验证失败,跳过部署")
                        fail_count += 1
                        continue
                    # Deploy into models/{device_code}/.
                    self.deploy_device_model(device_code, model, scale, mel_dir)
                    success_count += 1
                    logger.info(f" {device_code}: 训练+部署完成 | loss={final_loss:.6f}")
                except Exception as e:
                    logger.error(f" {device_code}: 训练失败 | {e}", exc_info=True)
                    fail_count += 1
            logger.info("=" * 60)
            logger.info(f"全量训练完成: 成功={success_count}, 失败={fail_count}")
            logger.info("=" * 60)
            return fail_count == 0
        except Exception as e:
            logger.error(f"全量训练异常: {e}", exc_info=True)
            return False
        finally:
            # Always clean up the temporary feature files.
            if self.temp_mel_dir.exists():
                shutil.rmtree(self.temp_mel_dir)
- # ========================================
- # 增量训练入口(保留兼容)
- # ========================================
    def run_daily_training(self, on_device_trained=None) -> bool:
        """
        Run the daily incremental training, device by device (serial).

        Pipeline (each device finishes completely before the next starts,
        which keeps memory/CPU peaks low):
        1. Collect every device's file list (paths only, very cheap).
        2. Back up the current models.
        3. Per device: extract features → train → validate → deploy →
           clean temp files → notify via callback.
        4. Update the classifier baseline.

        Args:
            on_device_trained: optional callback fn(device_code: str),
                invoked after each device trains + deploys successfully;
                used to hot-reload that device's model immediately.

        Returns:
            bool: True when at least one device succeeded.
        """
        try:
            days_ago = (self.use_days_ago if self.use_days_ago is not None
                        else self.config['auto_training']['incremental']['use_days_ago'])
            target_date = (datetime.now() - timedelta(days=days_ago)).strftime('%Y%m%d')
            mode_str = "冷启动训练" if self.cold_start_mode else "增量训练"
            logger.info("=" * 60)
            logger.info(f"{mode_str} - {target_date}")
            logger.info("=" * 60)
            # 1. Collect the data (file paths only; no audio is loaded yet).
            device_files = self.collect_training_data(target_date)
            total = sum(len(f) for f in device_files.values())
            min_samples = self.config['auto_training']['incremental']['min_samples']
            if total < min_samples:
                logger.warning(f"数据不足 ({total} < {min_samples}),跳过")
                return False
            # 2. Back up the models.
            if self.config['auto_training']['model']['backup_before_train']:
                self.backup_model(target_date)
            # 3. Resolve training parameters (runtime overrides win).
            epochs = (self.epochs if self.epochs is not None
                      else self.config['auto_training']['incremental']['epochs'])
            lr = (self.learning_rate if self.learning_rate is not None
                  else self.config['auto_training']['incremental']['learning_rate'])
            from_scratch = self.cold_start_mode
            model_cfg = self.config['auto_training']['model']
            rollback_enabled = model_cfg.get('rollback_on_degradation', True)
            rollback_factor = model_cfg.get('rollback_factor', 2.0)
            # 4. Serial per-device processing: extract → train → deploy → clean.
            success_count = 0
            degraded_count = 0
            device_count = len(device_files)
            for idx, (device_code, wav_files) in enumerate(device_files.items(), 1):
                logger.info(f"\n{'='*40}")
                logger.info(f"[{idx}/{device_count}] 设备: {device_code}")
                logger.info(f"{'='*40}")
                try:
                    # ── 4a. Feature extraction for this device ──
                    mel_dir, scale = self._extract_mel_for_device(
                        device_code, wav_files
                    )
                    if mel_dir is None:
                        logger.warning(f"{device_code}: 特征提取无有效数据,跳过")
                        continue
                    # ── 4b. Training ──
                    model, final_loss = self.train_single_device(
                        device_code, mel_dir, epochs, lr, from_scratch=from_scratch
                    )
                    # ── 4c. Shape validation ──
                    if not self._validate_model(model):
                        logger.error(f"{device_code}: 形状验证失败,跳过部署")
                        continue
                    # ── 4d. Old-vs-new model comparison (incremental only) ──
                    # Compare reconstruction error on the same validation data;
                    # skip deployment when the new model is clearly worse.
                    if rollback_enabled and not self.cold_start_mode:
                        old_model_path = self.model_root / device_code / "ae_model.pth"
                        if old_model_path.exists():
                            new_avg_err = self._eval_model_error(model, mel_dir)
                            old_model = ConvAutoencoder()
                            old_model.load_state_dict(
                                torch.load(old_model_path, map_location='cpu')
                            )
                            old_avg_err = self._eval_model_error(old_model, mel_dir)
                            logger.info(
                                f" {device_code}: 新旧模型对比 | "
                                f"旧模型误差={old_avg_err:.6f} 新模型误差={new_avg_err:.6f}"
                            )
                            if new_avg_err > old_avg_err * rollback_factor:
                                logger.warning(
                                    f"{device_code}: 新模型退化 | "
                                    f"新={new_avg_err:.6f} > 旧={old_avg_err:.6f} × {rollback_factor},跳过部署"
                                )
                                degraded_count += 1
                                continue
                    # ── 4e. Deployment ──
                    if model_cfg.get('auto_deploy', True):
                        self.deploy_device_model(device_code, model, scale, mel_dir)
                        success_count += 1
                        logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}")
                    # ── 4f. Clear the verified_normal directory we trained on ──
                    # The human-confirmed audio has been absorbed by the model;
                    # clearing it after training frees disk space.
                    verified_dir = self.audio_root / device_code / "verified_normal"
                    if verified_dir.exists():
                        v_count = len(list(verified_dir.glob("*")))
                        if v_count > 0:
                            shutil.rmtree(verified_dir)
                            verified_dir.mkdir(parents=True, exist_ok=True)
                            logger.info(f"{device_code}: 已清理 verified_normal ({v_count} 个文件)")
                    # ── 4g. Immediately notify so this device's model reloads ──
                    if on_device_trained:
                        try:
                            on_device_trained(device_code)
                        except Exception as e:
                            logger.warning(f"{device_code}: 模型重载回调失败 | {e}")
                except Exception as e:
                    logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True)
                finally:
                    # ── 4h. Remove this device's temp Mel files (disk space) ──
                    device_mel_dir = self.temp_mel_dir / device_code
                    if device_mel_dir.exists():
                        shutil.rmtree(device_mel_dir)
            # 5. If every device degraded, roll back to the pre-training backup.
            if degraded_count > 0 and success_count == 0:
                logger.error(
                    f"所有设备训练后损失退化({degraded_count}个),执行整体回滚"
                )
                self.restore_backup(target_date)
                return False
            if degraded_count > 0:
                logger.warning(
                    f"{degraded_count} 个设备因损失退化跳过部署,"
                    f"{success_count} 个设备部署成功"
                )
            logger.info("=" * 60)
            logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功")
            if degraded_count > 0:
                logger.info(f" 其中 {degraded_count} 个设备因损失退化跳过")
            logger.info("=" * 60)
            return success_count > 0
        except Exception as e:
            logger.error(f"训练失败: {e}", exc_info=True)
            return False
        finally:
            # Always clean up the scratch feature directory.
            if self.temp_mel_dir.exists():
                shutil.rmtree(self.temp_mel_dir)
- # ========================================
- # 辅助方法
- # ========================================
- def _validate_model(self, model: nn.Module) -> bool:
- # 验证模型输出形状是否合理
- if not self.config['auto_training']['validation']['enabled']:
- return True
- try:
- device = next(model.parameters()).device
- test_input = torch.randn(1, 1, CFG.N_MELS, CFG.TARGET_FRAMES).to(device)
- with torch.no_grad():
- output = model(test_input)
- h_diff = abs(output.shape[2] - test_input.shape[2])
- w_diff = abs(output.shape[3] - test_input.shape[3])
- if h_diff > 8 or w_diff > 8:
- logger.error(f"形状差异过大: {output.shape} vs {test_input.shape}")
- return False
- return True
- except Exception as e:
- logger.error(f"验证失败: {e}")
- return False
    def backup_model(self, date_tag: str):
        """
        Fully back up every device's current model.

        Backup layout:
            backups/{date_tag}/{device_code}/
            ├── ae_model.pth
            ├── global_scale.npy
            └── thresholds/

        Args:
            date_tag: backup label (a 'YYYYMMDD' string — the retention
                pruning below only matches all-digit directory names).
        """
        backup_date_dir = self.backup_dir / date_tag
        backup_date_dir.mkdir(parents=True, exist_ok=True)
        backed_up = 0
        # Walk every device sub-directory under models/.
        for device_dir in self.model_root.iterdir():
            if not device_dir.is_dir():
                continue
            # Skip the backups directory itself.
            if device_dir.name == "backups":
                continue
            # Only directories containing a model file count as device dirs.
            if not (device_dir / "ae_model.pth").exists():
                continue
            device_backup = backup_date_dir / device_dir.name
            # Recursively copy the whole device directory.
            shutil.copytree(device_dir, device_backup, dirs_exist_ok=True)
            backed_up += 1
        logger.info(f"备份完成: {backed_up} 个设备 -> {backup_date_dir}")
        # Prune old backups beyond the configured retention count.
        keep = self.config['auto_training']['model']['keep_backups']
        backup_dirs = sorted(
            [d for d in self.backup_dir.iterdir() if d.is_dir() and d.name.isdigit()],
            reverse=True
        )
        for old_dir in backup_dirs[keep:]:
            shutil.rmtree(old_dir)
            logger.info(f"已删除旧备份: {old_dir.name}")
- def restore_backup(self, date_tag: str) -> bool:
- """
- 从备份恢复所有设备的模型
- Args:
- date_tag: 备份日期标签 'YYYYMMDD'
- Returns:
- bool: 是否恢复成功
- """
- backup_date_dir = self.backup_dir / date_tag
- if not backup_date_dir.exists():
- logger.error(f"备份目录不存在: {backup_date_dir}")
- return False
- logger.info(f"从备份恢复: {date_tag}")
- restored = 0
- for device_backup in backup_date_dir.iterdir():
- if not device_backup.is_dir():
- continue
- target_dir = self.model_root / device_backup.name
- # 递归复制恢复
- shutil.copytree(device_backup, target_dir, dirs_exist_ok=True)
- restored += 1
- logger.info(f"恢复完成: {restored} 个设备")
- return restored > 0
def main():
    """Command-line entry point: run the daily incremental training."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s | %(levelname)-8s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml"
    trainer = IncrementalTrainer(config_file)
    # Exit code 0 on success, 1 on failure (for cron/systemd supervision).
    sys.exit(0 if trainer.run_daily_training() else 1)


if __name__ == "__main__":
    main()
|