| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- incremental_trainer.py
- ----------------------
- 模型训练模块
- 功能:
- 1. 支持全量训练(指定外部数据目录,每设备独立模型)
- 2. 支持增量训练(每日自动训练)
- 3. 滑动窗口提取Mel特征(8秒patches)
- 4. 每设备独立计算标准化参数和阈值
- 5. 产出目录结构与推理端 MultiModelPredictor 对齐
- 产出目录结构:
- models/
- ├── {device_code_1}/
- │ ├── ae_model.pth
- │ ├── global_scale.npy
- │ └── thresholds/
- │ └── threshold_{device_code_1}.npy
- └── {device_code_2}/
- ├── ae_model.pth
- ├── global_scale.npy
- └── thresholds/
- └── threshold_{device_code_2}.npy
- """
import sys
import random
import shutil
import logging
from pathlib import Path
from datetime import datetime, timedelta
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import yaml

# Make the deployment root importable before pulling in project modules.
sys.path.insert(0, str(Path(__file__).parent.parent))

from predictor import CFG
from predictor.model_def import ConvAutoencoder
from predictor.datasets import MelNPYDataset
from predictor.utils import align_to_target
# Module-level logger shared by the whole training module.
logger = logging.getLogger('IncrementalTrainer')
class IncrementalTrainer:
    """
    Model trainer.

    Two training modes are supported:
    1. Full training: external data directory, one model trained from
       scratch per device.
    2. Incremental training: fine-tune existing models on data collected
       while the system runs (kept for backward compatibility).
    """

    def __init__(self, config_file: Path):
        """
        Initialize the trainer.

        Args:
            config_file: path to the auto_training.yaml configuration file.
        """
        self.config_file = config_file
        self.config = self._load_config()

        # Path layout: everything lives under the deployment root.
        self.deploy_root = Path(__file__).parent.parent
        self.audio_root = self.deploy_root / "data" / "audio"
        # Root directory holding one model sub-directory per device.
        self.model_root = self.deploy_root / "models"
        self.backup_dir = self.model_root / "backups"
        # Scratch directory for extracted Mel features.
        self.temp_mel_dir = self.deploy_root / "data" / "temp_mels"

        # Ensure the backup tree exists up front.
        self.backup_dir.mkdir(parents=True, exist_ok=True)

        # Runtime-overridable settings (used by cold start); None means
        # "fall back to the value from the config file".
        self.use_days_ago = None
        self.sample_hours = None
        self.epochs = None
        self.learning_rate = None
        self.cold_start_mode = False
- def _load_config(self) -> Dict:
- # 加载配置文件
- with open(self.config_file, 'r', encoding='utf-8') as f:
- return yaml.safe_load(f)
- # ========================================
- # 数据收集
- # ========================================
- def collect_from_external_dir(self, data_dir: Path,
- device_filter: Optional[List[str]] = None
- ) -> Dict[str, List[Path]]:
- """
- 从外部数据目录收集训练数据
- 目录结构约定:
- data_dir/
- ├── LT-2/ <-- 子文件夹名 = device_code
- │ ├── xxx.wav
- │ └── ...
- └── LT-5/
- ├── yyy.wav
- └── ...
- 支持两种子目录结构:
- 1. 扁平结构:data_dir/{device_code}/*.wav
- 2. 日期嵌套:data_dir/{device_code}/{YYYYMMDD}/*.wav
- Args:
- data_dir: 外部数据目录路径
- device_filter: 只训练指定设备(None=全部)
- Returns:
- {device_code: [wav_files]} 按设备分组的音频文件列表
- """
- data_dir = Path(data_dir)
- if not data_dir.exists():
- raise FileNotFoundError(f"数据目录不存在: {data_dir}")
- logger.info(f"扫描外部数据目录: {data_dir}")
- device_files = {}
- for sub_dir in sorted(data_dir.iterdir()):
- # 跳过非目录
- if not sub_dir.is_dir():
- continue
- device_code = sub_dir.name
- # 应用设备过滤
- if device_filter and device_code not in device_filter:
- logger.info(f" 跳过设备(不在过滤列表中): {device_code}")
- continue
- audio_files = []
- # 扁平结构:直接查找 wav 文件
- audio_files.extend(list(sub_dir.glob("*.wav")))
- audio_files.extend(list(sub_dir.glob("*.mp4")))
- # 日期嵌套结构:查找子目录中的 wav 文件
- for nested_dir in sub_dir.iterdir():
- if nested_dir.is_dir():
- audio_files.extend(list(nested_dir.glob("*.wav")))
- audio_files.extend(list(nested_dir.glob("*.mp4")))
- # 去重
- audio_files = list(set(audio_files))
- if audio_files:
- device_files[device_code] = audio_files
- logger.info(f" {device_code}: {len(audio_files)} 个音频文件")
- else:
- logger.warning(f" {device_code}: 无音频文件,跳过")
- total = sum(len(f) for f in device_files.values())
- logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
- return device_files
- def collect_training_data(self, target_date: str) -> Dict[str, List[Path]]:
- """
- 从内部数据目录收集训练数据(增量训练用)
- Args:
- target_date: 日期字符串 'YYYYMMDD'
- Returns:
- {device_code: [wav_files]} 按设备分组的音频文件
- """
- logger.info(f"收集 {target_date} 的训练数据")
- sample_hours = self.config['auto_training']['incremental'].get('sample_hours', 0)
- device_files = {}
- if not self.audio_root.exists():
- logger.warning(f"音频目录不存在: {self.audio_root}")
- return device_files
- for device_dir in self.audio_root.iterdir():
- if not device_dir.is_dir():
- continue
- device_code = device_dir.name
- audio_files = []
- # 冷启动模式:收集所有目录的数据
- if self.cold_start_mode:
- # 收集current目录
- current_dir = device_dir / "current"
- if current_dir.exists():
- audio_files.extend(list(current_dir.glob("*.wav")))
- audio_files.extend(list(current_dir.glob("*.mp4")))
- # 收集所有日期目录
- for sub_dir in device_dir.iterdir():
- if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
- audio_files.extend(list(sub_dir.glob("*.wav")))
- audio_files.extend(list(sub_dir.glob("*.mp4")))
- else:
- # 正常模式:只收集指定日期的目录
- date_dir = device_dir / target_date
- if date_dir.exists():
- audio_files.extend(list(date_dir.glob("*.wav")))
- audio_files.extend(list(date_dir.glob("*.mp4")))
- # 加上 verified_normal 目录
- verified_dir = device_dir / "verified_normal"
- if verified_dir.exists():
- audio_files.extend(list(verified_dir.glob("*.wav")))
- audio_files.extend(list(verified_dir.glob("*.mp4")))
- # 去重
- audio_files = list(set(audio_files))
- # 随机采样(如果配置了采样时长)
- if sample_hours > 0 and audio_files:
- files_needed = int(sample_hours * 3600 / 60)
- if len(audio_files) > files_needed:
- audio_files = random.sample(audio_files, files_needed)
- logger.info(f" {device_code}: 随机采样 {len(audio_files)} 个音频")
- else:
- logger.info(f" {device_code}: {len(audio_files)} 个音频(全部使用)")
- else:
- logger.info(f" {device_code}: {len(audio_files)} 个音频")
- if audio_files:
- device_files[device_code] = audio_files
- total = sum(len(f) for f in device_files.values())
- logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
- return device_files
- # ========================================
- # 特征提取(每设备独立标准化参数)
- # ========================================
- def _extract_mel_for_device(self, device_code: str,
- wav_files: List[Path]) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
- """
- 为单个设备提取 Mel 特征并计算独立的 Z-score 标准化参数
- 两遍扫描:
- 1. 第一遍:收集所有 mel_db 计算 mean/std
- 2. 第二遍:Z-score 标准化后保存 npy 文件
- Args:
- device_code: 设备编码
- wav_files: 该设备的音频文件列表
- Returns:
- (mel_dir, (global_mean, global_std)),失败返回 (None, None)
- """
- import librosa
- # 滑动窗口参数
- win_samples = int(CFG.WIN_SEC * CFG.SR)
- hop_samples = int(CFG.HOP_SEC * CFG.SR)
- # 第一遍:收集所有 mel_db 值,用于计算 mean/std
- all_mel_data = []
- all_values = [] # 收集所有像素值用于全局统计
- for wav_file in wav_files:
- try:
- y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
- # 跳过过短的音频
- if len(y) < CFG.SR:
- continue
- for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)):
- segment = y[start:start + win_samples]
- mel_spec = librosa.feature.melspectrogram(
- y=segment, sr=CFG.SR, n_fft=CFG.N_FFT,
- hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0
- )
- mel_db = librosa.power_to_db(mel_spec, ref=np.max)
- # 对齐帧数
- if mel_db.shape[1] < CFG.TARGET_FRAMES:
- pad = CFG.TARGET_FRAMES - mel_db.shape[1]
- mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant")
- else:
- mel_db = mel_db[:, :CFG.TARGET_FRAMES]
- # 收集所有值用于 min/max 计算
- all_values.append(mel_db.flatten())
- all_mel_data.append((wav_file, idx, mel_db))
- except Exception as e:
- logger.warning(f"跳过文件 {wav_file.name}: {e}")
- continue
- if not all_mel_data:
- logger.warning(f" {device_code}: 无有效数据")
- return None, None
- # 计算全局 min/max(Min-Max 标准化参数)
- all_values_concat = np.concatenate(all_values)
- global_min = float(np.min(all_values_concat))
- global_max = float(np.max(all_values_concat))
- logger.info(f" {device_code}: {len(all_mel_data)} patches | "
- f"min={global_min:.4f} max={global_max:.4f}")
- # 第二遍:Min-Max 标准化并保存
- device_mel_dir = self.temp_mel_dir / device_code
- device_mel_dir.mkdir(parents=True, exist_ok=True)
- for wav_file, idx, mel_db in all_mel_data:
- # Min-Max: (x - min) / (max - min)
- mel_norm = (mel_db - global_min) / (global_max - global_min + 1e-6)
- npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy"
- np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32))
- return device_mel_dir, (global_min, global_max)
- def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]]
- ) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
- """
- 为每个设备独立提取 Mel 特征
- 每设备分别计算自己的 Min-Max 标准化参数 (global_min, global_max)
- Args:
- device_files: {device_code: [wav_files]}
- Returns:
- {device_code: (mel_dir, (global_min, global_max))}
- """
- logger.info("提取 Mel 特征(每设备独立标准化)")
- # 清空临时目录
- if self.temp_mel_dir.exists():
- shutil.rmtree(self.temp_mel_dir)
- self.temp_mel_dir.mkdir(parents=True, exist_ok=True)
- device_results = {}
- for device_code, wav_files in device_files.items():
- mel_dir, scale = self._extract_mel_for_device(device_code, wav_files)
- if mel_dir is not None:
- device_results[device_code] = (mel_dir, scale)
- total_patches = sum(
- len(list(mel_dir.glob("*.npy")))
- for mel_dir, _ in device_results.values()
- )
- logger.info(f"提取完成: {total_patches} patches,{len(device_results)} 个设备")
- return device_results
- # ========================================
- # 模型训练(每设备独立)
- # ========================================
- def train_single_device(self, device_code: str, mel_dir: Path,
- epochs: int, lr: float,
- from_scratch: bool = True
- ) -> Tuple[nn.Module, float]:
- """
- 训练单个设备的独立模型
- Args:
- device_code: 设备编码
- mel_dir: 该设备的 Mel 特征目录
- epochs: 训练轮数
- lr: 学习率
- from_scratch: True=从零训练(全量),False=加载已有模型微调(增量)
- Returns:
- (model, final_loss)
- """
- logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, "
- f"mode={'全量' if from_scratch else '增量'}")
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model = ConvAutoencoder().to(device)
- # 增量模式下加载已有模型
- if not from_scratch:
- model_path = self.model_root / device_code / "ae_model.pth"
- if model_path.exists():
- model.load_state_dict(torch.load(model_path, map_location=device))
- logger.info(f" 已加载已有模型: {model_path}")
- else:
- logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}")
- # 加载数据
- dataset = MelNPYDataset(mel_dir)
- if len(dataset) == 0:
- raise ValueError(f"设备 {device_code} 无训练数据")
- batch_size = self.config['auto_training']['incremental']['batch_size']
- dataloader = torch.utils.data.DataLoader(
- dataset, batch_size=batch_size, shuffle=True, num_workers=0
- )
- # 训练
- model.train()
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
- criterion = nn.MSELoss()
- avg_loss = 0.0
- for epoch in range(epochs):
- epoch_loss = 0.0
- batch_count = 0
- for batch in dataloader:
- batch = batch.to(device)
- optimizer.zero_grad()
- output = model(batch)
- output = align_to_target(output, batch)
- loss = criterion(output, batch)
- loss.backward()
- optimizer.step()
- epoch_loss += loss.item()
- batch_count += 1
- avg_loss = epoch_loss / batch_count
- # 每10轮或最后一轮打印日志,避免日志刷屏
- if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
- logger.info(f" [{device_code}] Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.6f}")
- return model, avg_loss
- # ========================================
- # 产出部署(每设备独立目录)
- # ========================================
- def deploy_device_model(self, device_code: str, model: nn.Module,
- scale: Tuple[float, float], mel_dir: Path):
- """
- 部署单个设备的模型到 models/{device_code}/ 目录
- 产出文件:
- - models/{device_code}/ae_model.pth
- - models/{device_code}/global_scale.npy → [mean, std]
- - models/{device_code}/thresholds/threshold_{device_code}.npy
- Args:
- device_code: 设备编码
- model: 训练后的模型
- scale: (global_min, global_max) Min-Max 标准化参数
- mel_dir: 该设备的 Mel 特征目录(用于计算阈值)
- """
- # 创建设备模型目录
- device_model_dir = self.model_root / device_code
- device_model_dir.mkdir(parents=True, exist_ok=True)
- # 1. 保存模型权重
- model_path = device_model_dir / "ae_model.pth"
- torch.save(model.state_dict(), model_path)
- logger.info(f" 模型已保存: {model_path}")
- # 2. 保存 Min-Max 标准化参数 [min, max]
- scale_path = device_model_dir / "global_scale.npy"
- np.save(scale_path, np.array([scale[0], scale[1]]))
- logger.info(f" 标准化参数已保存: {scale_path}")
- # 3. 计算并保存阈值
- threshold_dir = device_model_dir / "thresholds"
- threshold_dir.mkdir(parents=True, exist_ok=True)
- threshold = self._compute_threshold(model, mel_dir)
- threshold_file = threshold_dir / f"threshold_{device_code}.npy"
- np.save(threshold_file, threshold)
- logger.info(f" 阈值已保存: {threshold_file} (值={threshold:.6f})")
- def _compute_threshold(self, model: nn.Module, mel_dir: Path) -> float:
- """
- 计算单个设备的阈值
- 使用 3σ 法则:threshold = mean + 3 * std
- Args:
- model: 训练后的模型
- mel_dir: 该设备的 Mel 特征目录
- Returns:
- 阈值(标量)
- """
- device = next(model.parameters()).device
- model.eval()
- dataset = MelNPYDataset(mel_dir)
- if len(dataset) == 0:
- logger.warning("无数据计算阈值,使用默认值 0.01")
- return 0.01
- dataloader = torch.utils.data.DataLoader(
- dataset, batch_size=64, shuffle=False
- )
- all_errors = []
- with torch.no_grad():
- for batch in dataloader:
- batch = batch.to(device)
- output = model(batch)
- output = align_to_target(output, batch)
- mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
- all_errors.append(mse.cpu().numpy())
- errors = np.concatenate(all_errors)
- # 3σ 法则
- mean_err = float(np.mean(errors))
- std_err = float(np.std(errors))
- threshold = mean_err + 3 * std_err
- logger.info(f" 阈值统计: 3σ={threshold:.6f} | "
- f"mean={mean_err:.6f} std={std_err:.6f} | "
- f"样本数={len(errors)}")
- return threshold
- # ========================================
- # 全量训练入口
- # ========================================
- def run_full_training(self, data_dir: Path,
- epochs: Optional[int] = None,
- lr: Optional[float] = None,
- device_filter: Optional[List[str]] = None) -> bool:
- """
- 全量训练入口:每设备从零训练独立模型
- 流程:
- 1. 扫描外部数据目录
- 2. 每设备独立提取 Mel 特征+标准化参数
- 3. 每设备独立训练模型
- 4. 每设备独立部署(模型+标准化+阈值)
- Args:
- data_dir: 数据目录路径
- epochs: 训练轮数(None=使用配置文件值)
- lr: 学习率(None=使用配置文件值)
- device_filter: 只训练指定设备(None=全部)
- Returns:
- bool: 是否成功
- """
- try:
- epochs = epochs or self.config['auto_training']['incremental']['epochs']
- lr = lr or self.config['auto_training']['incremental']['learning_rate']
- logger.info("=" * 60)
- logger.info(f"全量训练 | 数据目录: {data_dir}")
- logger.info(f"参数: epochs={epochs}, lr={lr}")
- if device_filter:
- logger.info(f"设备过滤: {device_filter}")
- logger.info("=" * 60)
- # 1. 收集数据
- device_files = self.collect_from_external_dir(data_dir, device_filter)
- if not device_files:
- logger.error("无可用训练数据")
- return False
- # 2. 每设备提取特征
- device_results = self.prepare_mel_features_per_device(device_files)
- if not device_results:
- logger.error("特征提取失败")
- return False
- # 3. 每设备独立训练+部署
- success_count = 0
- fail_count = 0
- for device_code, (mel_dir, scale) in device_results.items():
- try:
- logger.info(f"\n--- 训练设备: {device_code} ---")
- # 训练
- model, final_loss = self.train_single_device(
- device_code, mel_dir, epochs, lr, from_scratch=True
- )
- # 验证
- if not self._validate_model(model):
- logger.error(f" {device_code}: 模型验证失败,跳过部署")
- fail_count += 1
- continue
- # 部署到 models/{device_code}/
- self.deploy_device_model(device_code, model, scale, mel_dir)
- success_count += 1
- logger.info(f" {device_code}: 训练+部署完成 | loss={final_loss:.6f}")
- except Exception as e:
- logger.error(f" {device_code}: 训练失败 | {e}", exc_info=True)
- fail_count += 1
- logger.info("=" * 60)
- logger.info(f"全量训练完成: 成功={success_count}, 失败={fail_count}")
- logger.info("=" * 60)
- return fail_count == 0
- except Exception as e:
- logger.error(f"全量训练异常: {e}", exc_info=True)
- return False
- finally:
- # 清理临时文件
- if self.temp_mel_dir.exists():
- shutil.rmtree(self.temp_mel_dir)
- # ========================================
- # 增量训练入口(保留兼容)
- # ========================================
- def run_daily_training(self) -> bool:
- """
- 执行每日增量训练(保留原有逻辑)
- 改造点:每设备独立训练+部署,不再共享模型
- Returns:
- bool: 是否成功
- """
- try:
- days_ago = (self.use_days_ago if self.use_days_ago is not None
- else self.config['auto_training']['incremental']['use_days_ago'])
- target_date = (datetime.now() - timedelta(days=days_ago)).strftime('%Y%m%d')
- mode_str = "冷启动训练" if self.cold_start_mode else "增量训练"
- logger.info("=" * 60)
- logger.info(f"{mode_str} - {target_date}")
- logger.info("=" * 60)
- # 1. 收集数据
- device_files = self.collect_training_data(target_date)
- total = sum(len(f) for f in device_files.values())
- min_samples = self.config['auto_training']['incremental']['min_samples']
- if total < min_samples:
- logger.warning(f"数据不足 ({total} < {min_samples}),跳过")
- return False
- # 2. 备份模型
- if self.config['auto_training']['model']['backup_before_train']:
- self.backup_model(target_date)
- # 3. 每设备提取特征
- device_results = self.prepare_mel_features_per_device(device_files)
- if not device_results:
- logger.error("特征提取失败")
- return False
- # 4. 训练参数
- epochs = (self.epochs if self.epochs is not None
- else self.config['auto_training']['incremental']['epochs'])
- lr = (self.learning_rate if self.learning_rate is not None
- else self.config['auto_training']['incremental']['learning_rate'])
- # 5. 每设备独立训练+部署
- # 冷启动=全量训练(从零),增量=加载已有模型微调
- from_scratch = self.cold_start_mode
- success_count = 0
- for device_code, (mel_dir, scale) in device_results.items():
- try:
- model, _ = self.train_single_device(
- device_code, mel_dir, epochs, lr, from_scratch=from_scratch
- )
- if self._validate_model(model):
- if self.config['auto_training']['model']['auto_deploy']:
- self.deploy_device_model(device_code, model, scale, mel_dir)
- success_count += 1
- else:
- logger.error(f"{device_code}: 验证失败,跳过部署")
- except Exception as e:
- logger.error(f"{device_code}: 训练失败 | {e}")
- # 6. 更新分类器基线
- self._update_classifier_baseline(device_files)
- logger.info("=" * 60)
- logger.info(f"增量训练完成: {success_count}/{len(device_results)} 个设备成功")
- logger.info("=" * 60)
- return success_count > 0
- except Exception as e:
- logger.error(f"训练失败: {e}", exc_info=True)
- return False
- finally:
- if self.temp_mel_dir.exists():
- shutil.rmtree(self.temp_mel_dir)
- # ========================================
- # 辅助方法
- # ========================================
- def _validate_model(self, model: nn.Module) -> bool:
- # 验证模型输出形状是否合理
- if not self.config['auto_training']['validation']['enabled']:
- return True
- try:
- device = next(model.parameters()).device
- test_input = torch.randn(1, 1, CFG.N_MELS, CFG.TARGET_FRAMES).to(device)
- with torch.no_grad():
- output = model(test_input)
- h_diff = abs(output.shape[2] - test_input.shape[2])
- w_diff = abs(output.shape[3] - test_input.shape[3])
- if h_diff > 8 or w_diff > 8:
- logger.error(f"形状差异过大: {output.shape} vs {test_input.shape}")
- return False
- return True
- except Exception as e:
- logger.error(f"验证失败: {e}")
- return False
- def backup_model(self, date_tag: str):
- """
- 完整备份当前所有设备的模型
- 备份目录结构:
- backups/{date_tag}/{device_code}/
- ├── ae_model.pth
- ├── global_scale.npy
- └── thresholds/
- """
- backup_date_dir = self.backup_dir / date_tag
- backup_date_dir.mkdir(parents=True, exist_ok=True)
- backed_up = 0
- # 遍历 models/ 下的所有设备子目录
- for device_dir in self.model_root.iterdir():
- if not device_dir.is_dir():
- continue
- # 跳过 backups 目录本身
- if device_dir.name == "backups":
- continue
- # 检查是否包含模型文件(判断是否为设备目录)
- if not (device_dir / "ae_model.pth").exists():
- continue
- device_backup = backup_date_dir / device_dir.name
- # 递归复制整个设备目录
- shutil.copytree(device_dir, device_backup, dirs_exist_ok=True)
- backed_up += 1
- logger.info(f"备份完成: {backed_up} 个设备 -> {backup_date_dir}")
- # 清理旧备份
- keep = self.config['auto_training']['model']['keep_backups']
- backup_dirs = sorted(
- [d for d in self.backup_dir.iterdir() if d.is_dir() and d.name.isdigit()],
- reverse=True
- )
- for old_dir in backup_dirs[keep:]:
- shutil.rmtree(old_dir)
- logger.info(f"已删除旧备份: {old_dir.name}")
- def restore_backup(self, date_tag: str) -> bool:
- """
- 从备份恢复所有设备的模型
- Args:
- date_tag: 备份日期标签 'YYYYMMDD'
- Returns:
- bool: 是否恢复成功
- """
- backup_date_dir = self.backup_dir / date_tag
- if not backup_date_dir.exists():
- logger.error(f"备份目录不存在: {backup_date_dir}")
- return False
- logger.info(f"从备份恢复: {date_tag}")
- restored = 0
- for device_backup in backup_date_dir.iterdir():
- if not device_backup.is_dir():
- continue
- target_dir = self.model_root / device_backup.name
- # 递归复制恢复
- shutil.copytree(device_backup, target_dir, dirs_exist_ok=True)
- restored += 1
- logger.info(f"恢复完成: {restored} 个设备")
- return restored > 0
- def _update_classifier_baseline(self, device_files: Dict[str, List[Path]]):
- # 从训练数据计算并更新分类器基线
- logger.info("更新分类器基线")
- try:
- import librosa
- from core.anomaly_classifier import AnomalyClassifier
- classifier = AnomalyClassifier()
- all_files = []
- for files in device_files.values():
- all_files.extend(files)
- if not all_files:
- logger.warning("无音频文件,跳过基线更新")
- return
- sample_files = random.sample(all_files, min(50, len(all_files)))
- all_features = []
- for wav_file in sample_files:
- try:
- y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
- if len(y) < CFG.SR:
- continue
- features = classifier.extract_features(y, sr=CFG.SR)
- if features:
- all_features.append(features)
- except Exception:
- continue
- if not all_features:
- logger.warning("无法提取特征,跳过基线更新")
- return
- baseline = {}
- keys = all_features[0].keys()
- for key in keys:
- if key == 'has_periodic':
- values = [f[key] for f in all_features]
- baseline[key] = sum(values) > len(values) / 2
- else:
- values = [f[key] for f in all_features]
- baseline[key] = float(np.mean(values))
- classifier.save_baseline(baseline)
- logger.info(f" 基线已更新 (样本数: {len(all_features)})")
- except Exception as e:
- logger.warning(f"更新基线失败: {e}")
def main():
    """Command-line entry point for the daily incremental training."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s | %(levelname)-8s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml"
    trainer = IncrementalTrainer(config_file)
    # Exit status mirrors the training result for cron/systemd callers.
    sys.exit(0 if trainer.run_daily_training() else 1)


if __name__ == "__main__":
    main()
|