incremental_trainer.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. incremental_trainer.py
  5. ----------------------
  6. 模型训练模块
  7. 功能:
  8. 1. 支持全量训练(指定外部数据目录,每设备独立模型)
  9. 2. 支持增量训练(每日自动训练)
  10. 3. 滑动窗口提取Mel特征(8秒patches)
  11. 4. 每设备独立计算标准化参数和阈值
  12. 5. 产出目录结构与推理端 MultiModelPredictor 对齐
  13. 产出目录结构:
  14. models/
  15. ├── {device_code_1}/
  16. │ ├── ae_model.pth
  17. │ ├── global_scale.npy
  18. │ └── thresholds/
  19. │ └── threshold_{device_code_1}.npy
  20. └── {device_code_2}/
  21. ├── ae_model.pth
  22. ├── global_scale.npy
  23. └── thresholds/
  24. └── threshold_{device_code_2}.npy
  25. """
  26. import sys
  27. import random
  28. import shutil
  29. import logging
  30. import numpy as np
  31. import torch
  32. import torch.nn as nn
  33. from pathlib import Path
  34. from datetime import datetime, timedelta
  35. from typing import List, Dict, Tuple, Optional
  36. import yaml
  37. # 添加父目录到路径
  38. sys.path.insert(0, str(Path(__file__).parent.parent))
  39. from predictor import CFG
  40. from predictor.model_def import ConvAutoencoder
  41. from predictor.datasets import MelNPYDataset
  42. from predictor.utils import align_to_target
  43. logger = logging.getLogger('IncrementalTrainer')
  44. class IncrementalTrainer:
  45. """
  46. 模型训练器
  47. 支持两种训练模式:
  48. 1. 全量训练:指定外部数据目录,每设备从零训练独立模型
  49. 2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑)
  50. """
    def __init__(self, config_file: Path = None, config: dict = None):
        """Initialize the trainer.

        Two mutually exclusive initialization modes:
        1. Pass a ``config`` dict (read from the database and handed in
           directly; used by the main program).
        2. Pass a ``config_file`` YAML path (used by standalone tools such
           as standalone_train.py).

        Raises:
            ValueError: if neither ``config_file`` nor ``config`` is given.
        """
        if config is not None:
            self.config = config
            self.config_file = None
        elif config_file is not None:
            self.config_file = config_file
            self.config = self._load_config()
        else:
            raise ValueError("必须提供 config_file 或 config 之一")
        # Path configuration (deploy root = parent of this package directory)
        self.deploy_root = Path(__file__).parent.parent
        self.audio_root = self.deploy_root / "data" / "audio"
        # Model root directory (parent of all per-device subdirectories)
        self.model_root = self.deploy_root / "models"
        self.backup_dir = self.model_root / "backups"
        # Temporary directory for extracted Mel features
        self.temp_mel_dir = self.deploy_root / "data" / "temp_mels"
        # Ensure the backup directory exists
        self.backup_dir.mkdir(parents=True, exist_ok=True)
        # Runtime-overridable settings (used for cold start); None = use config
        self.use_days_ago = None
        self.sample_hours = None
        self.epochs = None
        self.learning_rate = None
        self.cold_start_mode = False
  79. def _load_config(self) -> Dict:
  80. # 从 YAML 文件加载配置(仅 config_file 模式使用)
  81. with open(self.config_file, 'r', encoding='utf-8') as f:
  82. return yaml.safe_load(f)
  83. # ========================================
  84. # 数据收集
  85. # ========================================
  86. def collect_from_external_dir(self, data_dir: Path,
  87. device_filter: Optional[List[str]] = None
  88. ) -> Dict[str, List[Path]]:
  89. """
  90. 从外部数据目录收集训练数据
  91. 目录结构约定:
  92. data_dir/
  93. ├── LT-2/ <-- 子文件夹名 = device_code
  94. │ ├── xxx.wav
  95. │ └── ...
  96. └── LT-5/
  97. ├── yyy.wav
  98. └── ...
  99. 支持两种子目录结构:
  100. 1. 扁平结构:data_dir/{device_code}/*.wav
  101. 2. 日期嵌套:data_dir/{device_code}/{YYYYMMDD}/*.wav
  102. Args:
  103. data_dir: 外部数据目录路径
  104. device_filter: 只训练指定设备(None=全部)
  105. Returns:
  106. {device_code: [wav_files]} 按设备分组的音频文件列表
  107. """
  108. data_dir = Path(data_dir)
  109. if not data_dir.exists():
  110. raise FileNotFoundError(f"数据目录不存在: {data_dir}")
  111. logger.info(f"扫描外部数据目录: {data_dir}")
  112. device_files = {}
  113. for sub_dir in sorted(data_dir.iterdir()):
  114. # 跳过非目录
  115. if not sub_dir.is_dir():
  116. continue
  117. device_code = sub_dir.name
  118. # 应用设备过滤
  119. if device_filter and device_code not in device_filter:
  120. logger.info(f" 跳过设备(不在过滤列表中): {device_code}")
  121. continue
  122. audio_files = []
  123. # 扁平结构:直接查找 wav 文件
  124. audio_files.extend(list(sub_dir.glob("*.wav")))
  125. audio_files.extend(list(sub_dir.glob("*.mp4")))
  126. # 日期嵌套结构:查找子目录中的 wav 文件
  127. for nested_dir in sub_dir.iterdir():
  128. if nested_dir.is_dir():
  129. audio_files.extend(list(nested_dir.glob("*.wav")))
  130. audio_files.extend(list(nested_dir.glob("*.mp4")))
  131. # 去重
  132. audio_files = list(set(audio_files))
  133. if audio_files:
  134. device_files[device_code] = audio_files
  135. logger.info(f" {device_code}: {len(audio_files)} 个音频文件")
  136. else:
  137. logger.warning(f" {device_code}: 无音频文件,跳过")
  138. total = sum(len(f) for f in device_files.values())
  139. logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
  140. return device_files
    def collect_training_data(self, target_date: str) -> Dict[str, List[Path]]:
        """Collect training data from the internal data directory (incremental training).

        Args:
            target_date: date string 'YYYYMMDD'.

        Returns:
            {device_code: [wav_files]} audio files grouped by device.
        """
        logger.info(f"收集 {target_date} 的训练数据")
        sample_hours = self.config['auto_training']['incremental'].get('sample_hours', 0)
        device_files = {}
        if not self.audio_root.exists():
            logger.warning(f"音频目录不存在: {self.audio_root}")
            return device_files
        for device_dir in self.audio_root.iterdir():
            if not device_dir.is_dir():
                continue
            device_code = device_dir.name
            audio_files = []
            # Cold-start mode: collect normal audio from every archived date
            # directory. current/ is skipped because it may hold incomplete
            # files that FFmpeg is still writing.
            if self.cold_start_mode:
                for sub_dir in device_dir.iterdir():
                    # Archived date directories are exactly 8 digits (YYYYMMDD)
                    if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
                        # New layout: read from the {date}/normal/ sub-directory
                        normal_dir = sub_dir / "normal"
                        if normal_dir.exists():
                            audio_files.extend(list(normal_dir.glob("*.wav")))
                            audio_files.extend(list(normal_dir.glob("*.mp4")))
                        # Legacy layout: audio directly under the date directory
                        audio_files.extend(list(sub_dir.glob("*.wav")))
                        audio_files.extend(list(sub_dir.glob("*.mp4")))
            else:
                # Normal mode: only collect normal audio for the given date
                date_dir = device_dir / target_date
                if date_dir.exists():
                    # New layout: read from the {date}/normal/ sub-directory
                    normal_dir = date_dir / "normal"
                    if normal_dir.exists():
                        audio_files.extend(list(normal_dir.glob("*.wav")))
                        audio_files.extend(list(normal_dir.glob("*.mp4")))
                    # Legacy layout: audio directly under the date directory
                    audio_files.extend(list(date_dir.glob("*.wav")))
                    audio_files.extend(list(date_dir.glob("*.mp4")))
            # Also collect verified_normal (kept separate: excluded from
            # sampling and quality pre-filtering — human-confirmed normals)
            verified_dir = device_dir / "verified_normal"
            verified_files = []
            if verified_dir.exists():
                verified_files.extend(list(verified_dir.glob("*.wav")))
                verified_files.extend(list(verified_dir.glob("*.mp4")))
            # Deduplicate (date-directory audio only)
            audio_files = list(set(audio_files))
            # Quality pre-filter: only applied to date-directory audio and
            # skipped in cold-start mode; verified_normal is trusted as-is
            if audio_files and not self.cold_start_mode:
                before_count = len(audio_files)
                audio_files = self._filter_audio_quality(audio_files, device_code)
                filtered = before_count - len(audio_files)
                if filtered > 0:
                    logger.info(f" {device_code}: 质量预筛过滤 {filtered} 个异常音频")
            # Random sampling (date-directory audio only)
            if sample_hours > 0 and audio_files:
                # Hours -> file count; assumes ~60-second clips — TODO confirm
                files_needed = int(sample_hours * 3600 / 60)
                if len(audio_files) > files_needed:
                    audio_files = random.sample(audio_files, files_needed)
                    logger.info(f" {device_code}: 随机采样 {len(audio_files)} 个音频")
                else:
                    logger.info(f" {device_code}: {len(audio_files)} 个音频(全部使用)")
            else:
                logger.info(f" {device_code}: {len(audio_files)} 个音频")
            # Append verified_normal after sampling so all of it joins training
            if verified_files:
                audio_files.extend(verified_files)
                logger.info(f" {device_code}: +{len(verified_files)} 个核查确认音频(verified_normal)")
            if audio_files:
                device_files[device_code] = audio_files
        total = sum(len(f) for f in device_files.values())
        logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
        return device_files
  219. def _filter_audio_quality(self, audio_files: List[Path],
  220. device_code: str) -> List[Path]:
  221. """
  222. 音频质量预筛:基于 RMS 能量和频谱质心过滤明显异常的样本
  223. 使用 IQR (四分位距) 方法检测离群值:
  224. - 计算所有文件的 RMS 能量和频谱质心
  225. - 过滤超出 [Q1 - 2*IQR, Q3 + 2*IQR] 范围的文件
  226. 需要至少 10 个文件才执行过滤,否则样本太少无统计意义。
  227. Args:
  228. audio_files: 待过滤的音频文件列表
  229. device_code: 设备编码(用于日志)
  230. Returns:
  231. 过滤后的文件列表
  232. """
  233. if len(audio_files) < 10:
  234. return audio_files
  235. import librosa
  236. # 快速计算每个文件的 RMS 能量
  237. rms_values = []
  238. valid_files = []
  239. for wav_file in audio_files:
  240. try:
  241. y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True,
  242. duration=10) # 只读前10秒加速
  243. if len(y) < CFG.SR:
  244. continue
  245. rms = float(np.sqrt(np.mean(y ** 2)))
  246. rms_values.append(rms)
  247. valid_files.append(wav_file)
  248. except Exception:
  249. continue
  250. if len(rms_values) < 10:
  251. return audio_files
  252. # IQR 离群值检测
  253. rms_arr = np.array(rms_values)
  254. q1, q3 = np.percentile(rms_arr, [25, 75])
  255. iqr = q3 - q1
  256. lower_bound = q1 - 2 * iqr
  257. upper_bound = q3 + 2 * iqr
  258. filtered = []
  259. for f, rms in zip(valid_files, rms_values):
  260. if lower_bound <= rms <= upper_bound:
  261. filtered.append(f)
  262. return filtered
  263. # ========================================
  264. # 特征提取(每设备独立标准化参数)
  265. # ========================================
    def _extract_mel_for_device(self, device_code: str,
                                wav_files: List[Path]
                                ) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
        """Extract Mel features for one device with its own Min-Max scaling.

        Streaming two-pass scan (memory optimization):
        1. Pass 1: track only running min/max (O(1) memory); mel_db not kept.
        2. Pass 2: normalize with the pass-1 min/max and write .npy files.
        Every audio file is therefore decoded twice — a deliberate
        CPU-for-memory trade-off.

        Args:
            device_code: device code.
            wav_files: audio files belonging to this device.

        Returns:
            (mel_dir, (global_min, global_max)); (None, None) on failure.
        """
        import librosa
        # Sliding-window parameters (seconds -> samples)
        win_samples = int(CFG.WIN_SEC * CFG.SR)
        hop_samples = int(CFG.HOP_SEC * CFG.SR)
        def _iter_mel_patches(files):
            """Generator: yield mel_db per file per patch to avoid loading everything."""
            for wav_file in files:
                try:
                    y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
                    if len(y) < CFG.SR:
                        # Shorter than one second: skip
                        continue
                    for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)):
                        segment = y[start:start + win_samples]
                        mel_spec = librosa.feature.melspectrogram(
                            y=segment, sr=CFG.SR, n_fft=CFG.N_FFT,
                            hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0
                        )
                        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
                        # Align frame count to the model's expected width
                        if mel_db.shape[1] < CFG.TARGET_FRAMES:
                            pad = CFG.TARGET_FRAMES - mel_db.shape[1]
                            mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant")
                        else:
                            mel_db = mel_db[:, :CFG.TARGET_FRAMES]
                        yield wav_file, idx, mel_db
                except Exception as e:
                    logger.warning(f"跳过文件 {wav_file.name}: {e}")
                    continue
        # ── Pass 1: streaming running min/max (O(1) memory) ──
        global_min = float('inf')
        global_max = float('-inf')
        patch_count = 0
        for _, _, mel_db in _iter_mel_patches(wav_files):
            local_min = float(mel_db.min())
            local_max = float(mel_db.max())
            if local_min < global_min:
                global_min = local_min
            if local_max > global_max:
                global_max = local_max
            patch_count += 1
        if patch_count == 0:
            logger.warning(f" {device_code}: 无有效数据")
            return None, None
        logger.info(f" {device_code}: {patch_count} patches | "
                    f"min={global_min:.4f} max={global_max:.4f}")
        # ── Pass 2: Min-Max normalize and save ──
        device_mel_dir = self.temp_mel_dir / device_code
        device_mel_dir.mkdir(parents=True, exist_ok=True)
        # Epsilon guards against division by zero when min == max
        scale_range = global_max - global_min + 1e-6
        for wav_file, idx, mel_db in _iter_mel_patches(wav_files):
            mel_norm = (mel_db - global_min) / scale_range
            npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy"
            np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32))
        return device_mel_dir, (global_min, global_max)
  334. def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]]
  335. ) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
  336. """
  337. 为每个设备独立提取 Mel 特征
  338. 每设备分别计算自己的 Min-Max 标准化参数 (global_min, global_max)
  339. Args:
  340. device_files: {device_code: [wav_files]}
  341. Returns:
  342. {device_code: (mel_dir, (global_min, global_max))}
  343. """
  344. logger.info("提取 Mel 特征(每设备独立标准化)")
  345. # 清空临时目录
  346. if self.temp_mel_dir.exists():
  347. shutil.rmtree(self.temp_mel_dir)
  348. self.temp_mel_dir.mkdir(parents=True, exist_ok=True)
  349. device_results = {}
  350. for device_code, wav_files in device_files.items():
  351. mel_dir, scale = self._extract_mel_for_device(device_code, wav_files)
  352. if mel_dir is not None:
  353. device_results[device_code] = (mel_dir, scale)
  354. total_patches = sum(
  355. len(list(mel_dir.glob("*.npy")))
  356. for mel_dir, _ in device_results.values()
  357. )
  358. logger.info(f"提取完成: {total_patches} patches,{len(device_results)} 个设备")
  359. return device_results
  360. # ========================================
  361. # 模型训练(每设备独立)
  362. # ========================================
  363. def _select_training_device(self) -> torch.device:
  364. # 智能选择训练设备:GPU 显存充足则使用,否则回退 CPU
  365. # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda)
  366. training_cfg = self.config['auto_training']['incremental']
  367. forced_device = training_cfg.get('training_device', 'auto')
  368. # 强制指定设备时直接返回
  369. if forced_device == 'cpu':
  370. logger.info("训练设备: CPU(配置强制指定)")
  371. return torch.device('cpu')
  372. if forced_device == 'cuda':
  373. if torch.cuda.is_available():
  374. return torch.device('cuda')
  375. logger.warning("配置指定 CUDA 但不可用,回退 CPU")
  376. return torch.device('cpu')
  377. # auto 模式:检测 CUDA → CPU
  378. if torch.cuda.is_available():
  379. try:
  380. free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024)
  381. min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
  382. if free_mem >= min_gpu_mem_mb:
  383. logger.info(f"训练设备: CUDA(空闲显存 {free_mem:.0f}MB)")
  384. return torch.device('cuda')
  385. logger.info(
  386. f"CUDA 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
  387. )
  388. except Exception as e:
  389. logger.warning(f"CUDA 显存检测失败: {e}")
  390. logger.info("训练设备: CPU")
  391. return torch.device('cpu')
    def _run_training_loop(self, device_code: str, model: nn.Module,
                           train_loader, val_loader, epochs: int, lr: float,
                           device: torch.device) -> Tuple[nn.Module, float]:
        """Run the actual training loop, decoupled from device selection.

        Early stopping monitors validation loss when ``val_loader`` is given,
        otherwise it falls back to the training loss.

        Returns:
            (trained model, average training loss of the last completed epoch)
        """
        model = model.to(device)
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()
        # AMP mixed precision (only effective on GPU; reduces VRAM usage)
        use_amp = device.type == 'cuda'
        scaler = torch.amp.GradScaler(device.type) if use_amp else None
        # Early-stop configuration
        early_stop_cfg = self.config['auto_training']['incremental']
        patience = early_stop_cfg.get('early_stop_patience', 5)
        best_loss = float('inf')
        no_improve_count = 0
        avg_loss = 0.0
        actual_epochs = 0
        for epoch in range(epochs):
            # ── Training phase ──
            model.train()
            epoch_loss = 0.0
            batch_count = 0
            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()
                if use_amp:
                    with torch.amp.autocast(device.type):
                        output = model(batch)
                        output = align_to_target(output, batch)
                        loss = criterion(output, batch)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    output = model(batch)
                    output = align_to_target(output, batch)
                    loss = criterion(output, batch)
                    loss.backward()
                    optimizer.step()
                epoch_loss += loss.item()
                batch_count += 1
            avg_loss = epoch_loss / batch_count
            actual_epochs = epoch + 1
            # ── Validation phase (when a validation set exists) ──
            if val_loader is not None:
                model.eval()
                val_loss = 0.0
                val_count = 0
                with torch.no_grad():
                    for batch in val_loader:
                        batch = batch.to(device)
                        if use_amp:
                            with torch.amp.autocast(device.type):
                                output = model(batch)
                                output = align_to_target(output, batch)
                                loss = criterion(output, batch)
                        else:
                            output = model(batch)
                            output = align_to_target(output, batch)
                            loss = criterion(output, batch)
                        val_loss += loss.item()
                        val_count += 1
                avg_val_loss = val_loss / val_count
                monitor_loss = avg_val_loss  # early stop watches validation loss
            else:
                avg_val_loss = None
                monitor_loss = avg_loss  # no validation set: fall back to train loss
            # Progress log every 10 epochs and on the final epoch
            if actual_epochs % 10 == 0 or epoch == epochs - 1:
                val_str = f" | ValLoss: {avg_val_loss:.6f}" if avg_val_loss is not None else ""
                logger.info(f" [{device_code}] Epoch {actual_epochs}/{epochs} | "
                            f"Loss: {avg_loss:.6f}{val_str} | device={device.type}")
            # Early stop: abort after `patience` epochs with no improvement,
            # but only once at least 10 epochs have run
            if monitor_loss < best_loss:
                best_loss = monitor_loss
                no_improve_count = 0
            else:
                no_improve_count += 1
                if no_improve_count >= patience and actual_epochs >= 10:
                    logger.info(f" [{device_code}] 早停触发: 连续{patience}轮无改善 | "
                                f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}")
                    break
        # Release cached GPU memory after training
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        if actual_epochs < epochs:
            logger.info(f" [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练")
        return model, avg_loss
    def train_single_device(self, device_code: str, mel_dir: Path,
                            epochs: int, lr: float,
                            from_scratch: bool = True
                            ) -> Tuple[nn.Module, float]:
        """Train an independent model for a single device.

        Strategy: prefer GPU training and fall back to CPU when memory is
        short; an OOM raised mid-training is also caught and retried on CPU.

        Args:
            device_code: device code.
            mel_dir: directory holding the device's Mel feature .npy files.
            epochs: number of training epochs.
            lr: learning rate.
            from_scratch: True = full training; False = fine-tune an
                existing model (falls back to full training if none exists).

        Returns:
            (trained model, final average training loss)

        Raises:
            ValueError: when the device has no training data.
        """
        logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, "
                    f"mode={'全量' if from_scratch else '增量'}")
        # Pick the training device intelligently (GPU/CPU)
        device = self._select_training_device()
        # Clear the CUDA cache first to release fragments left by inference
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            import gc
            gc.collect()
        model = ConvAutoencoder()
        # Incremental mode: load the existing model if present
        if not from_scratch:
            model_path = self.model_root / device_code / "ae_model.pth"
            if model_path.exists():
                model.load_state_dict(torch.load(model_path, map_location='cpu'))
                logger.info(f" 已加载已有模型: {model_path}")
            else:
                logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}")
        # Load data and split 80/20 into train/validation
        dataset = MelNPYDataset(mel_dir)
        if len(dataset) == 0:
            raise ValueError(f"设备 {device_code} 无训练数据")
        batch_size = self.config['auto_training']['incremental']['batch_size']
        # Only split a validation set when >= 20 samples (fewer is meaningless)
        val_loader = None
        if len(dataset) >= 20:
            val_size = max(1, int(len(dataset) * 0.2))
            train_size = len(dataset) - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True,
                num_workers=0, pin_memory=False
            )
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=batch_size, shuffle=False,
                num_workers=0, pin_memory=False
            )
            logger.info(f" 数据集划分: 训练={train_size}, 验证={val_size}")
        else:
            train_loader = torch.utils.data.DataLoader(
                dataset, batch_size=batch_size, shuffle=True,
                num_workers=0, pin_memory=False
            )
            logger.info(f" 数据量不足20,跳过验证集划分(共{len(dataset)}样本)")
        # Attempt training on the selected device
        if device.type == 'cuda':
            try:
                return self._run_training_loop(
                    device_code, model, train_loader, val_loader,
                    epochs, lr, device
                )
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                # GPU OOM -> clear memory, then retry on CPU
                if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError):
                    raise  # re-raise RuntimeErrors that are not OOM
                logger.warning(
                    f" [{device_code}] CUDA OOM,清理显存后回退 CPU 训练"
                )
                import gc
                gc.collect()
                torch.cuda.empty_cache()
                # The model may be in a dirty state: re-initialize it
                model = ConvAutoencoder()
                if not from_scratch:
                    model_path = self.model_root / device_code / "ae_model.pth"
                    if model_path.exists():
                        model.load_state_dict(
                            torch.load(model_path, map_location='cpu')
                        )
                return self._run_training_loop(
                    device_code, model, train_loader, val_loader,
                    epochs, lr, torch.device('cpu')
                )
        else:
            # CPU training: no OOM protection needed
            return self._run_training_loop(
                device_code, model, train_loader, val_loader,
                epochs, lr, device
            )
  568. # ========================================
  569. # 产出部署(每设备独立目录)
  570. # ========================================
    def deploy_device_model(self, device_code: str, model: nn.Module,
                            scale: Tuple[float, float], mel_dir: Path):
        """Deploy one device's model into models/{device_code}/.

        Output files:
        - models/{device_code}/ae_model.pth
        - models/{device_code}/global_scale.npy → [min, max]
          (the previous doc claimed [mean, std]; the code actually saves the
          Min-Max pair used for normalization)
        - models/{device_code}/thresholds/threshold_{device_code}.npy

        Args:
            device_code: device code.
            model: trained model.
            scale: (global_min, global_max) Min-Max normalization parameters.
            mel_dir: the device's Mel feature directory (used for threshold).
        """
        # Create the device's model directory
        device_model_dir = self.model_root / device_code
        device_model_dir.mkdir(parents=True, exist_ok=True)
        # 1. Save model weights
        model_path = device_model_dir / "ae_model.pth"
        torch.save(model.state_dict(), model_path)
        logger.info(f" 模型已保存: {model_path}")
        # 2. Save Min-Max normalization parameters [min, max]
        scale_path = device_model_dir / "global_scale.npy"
        np.save(scale_path, np.array([scale[0], scale[1]]))
        logger.info(f" 标准化参数已保存: {scale_path}")
        # 3. Compute and save the anomaly threshold
        threshold_dir = device_model_dir / "thresholds"
        threshold_dir.mkdir(parents=True, exist_ok=True)
        threshold = self._compute_threshold(model, mel_dir)
        threshold_file = threshold_dir / f"threshold_{device_code}.npy"
        np.save(threshold_file, threshold)
        logger.info(f" 阈值已保存: {threshold_file} (值={threshold:.6f})")
  603. def _compute_threshold(self, model: nn.Module, mel_dir: Path) -> float:
  604. """
  605. 计算单个设备的阈值
  606. 使用 3σ 法则:threshold = mean + 3 * std
  607. Args:
  608. model: 训练后的模型
  609. mel_dir: 该设备的 Mel 特征目录
  610. Returns:
  611. 阈值(标量)
  612. """
  613. device = next(model.parameters()).device
  614. model.eval()
  615. dataset = MelNPYDataset(mel_dir)
  616. if len(dataset) == 0:
  617. logger.warning("无数据计算阈值,使用默认值 0.01")
  618. return 0.01
  619. dataloader = torch.utils.data.DataLoader(
  620. dataset, batch_size=64, shuffle=False
  621. )
  622. all_errors = []
  623. with torch.no_grad():
  624. for batch in dataloader:
  625. batch = batch.to(device)
  626. output = model(batch)
  627. output = align_to_target(output, batch)
  628. mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
  629. all_errors.append(mse.cpu().numpy())
  630. errors = np.concatenate(all_errors)
  631. # 3σ 法则
  632. mean_err = float(np.mean(errors))
  633. std_err = float(np.std(errors))
  634. threshold = mean_err + 3 * std_err
  635. logger.info(f" 阈值统计: 3σ={threshold:.6f} | "
  636. f"mean={mean_err:.6f} std={std_err:.6f} | "
  637. f"样本数={len(errors)}")
  638. return threshold
  639. def _eval_model_error(self, model: nn.Module, mel_dir: Path) -> float:
  640. """在验证数据上计算模型的平均重建误差,用于新旧模型对比"""
  641. device = next(model.parameters()).device
  642. model.eval()
  643. dataset = MelNPYDataset(mel_dir)
  644. if len(dataset) == 0:
  645. return float('inf')
  646. dataloader = torch.utils.data.DataLoader(
  647. dataset, batch_size=64, shuffle=False
  648. )
  649. all_errors = []
  650. with torch.no_grad():
  651. for batch in dataloader:
  652. batch = batch.to(device)
  653. output = model(batch)
  654. output = align_to_target(output, batch)
  655. mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
  656. all_errors.append(mse.cpu().numpy())
  657. errors = np.concatenate(all_errors)
  658. return float(np.mean(errors))
  659. # ========================================
  660. # 全量训练入口
  661. # ========================================
    def run_full_training(self, data_dir: Path,
                          epochs: Optional[int] = None,
                          lr: Optional[float] = None,
                          device_filter: Optional[List[str]] = None) -> bool:
        """Full-training entry point: train each device's model from scratch.

        Pipeline:
        1. Scan the external data directory.
        2. Per device: extract Mel features + normalization parameters.
        3. Per device: train an independent model.
        4. Per device: deploy (model + scaling + threshold).

        Args:
            data_dir: data directory path.
            epochs: number of epochs (None = use the configured value).
            lr: learning rate (None = use the configured value).
            device_filter: only train the listed devices (None = all).

        Returns:
            bool: True only when no device failed — a single failure yields
            False even if other devices trained and deployed successfully.
        """
        try:
            epochs = epochs or self.config['auto_training']['incremental']['epochs']
            lr = lr or self.config['auto_training']['incremental']['learning_rate']
            logger.info("=" * 60)
            logger.info(f"全量训练 | 数据目录: {data_dir}")
            logger.info(f"参数: epochs={epochs}, lr={lr}")
            if device_filter:
                logger.info(f"设备过滤: {device_filter}")
            logger.info("=" * 60)
            # 1. Collect data
            device_files = self.collect_from_external_dir(data_dir, device_filter)
            if not device_files:
                logger.error("无可用训练数据")
                return False
            # 2. Extract features per device
            device_results = self.prepare_mel_features_per_device(device_files)
            if not device_results:
                logger.error("特征提取失败")
                return False
            # 3. Train + deploy each device independently
            success_count = 0
            fail_count = 0
            for device_code, (mel_dir, scale) in device_results.items():
                try:
                    logger.info(f"\n--- 训练设备: {device_code} ---")
                    # Train
                    model, final_loss = self.train_single_device(
                        device_code, mel_dir, epochs, lr, from_scratch=True
                    )
                    # Validate before deploying
                    if not self._validate_model(model):
                        logger.error(f" {device_code}: 模型验证失败,跳过部署")
                        fail_count += 1
                        continue
                    # Deploy into models/{device_code}/
                    self.deploy_device_model(device_code, model, scale, mel_dir)
                    success_count += 1
                    logger.info(f" {device_code}: 训练+部署完成 | loss={final_loss:.6f}")
                except Exception as e:
                    # One device failing must not abort the others
                    logger.error(f" {device_code}: 训练失败 | {e}", exc_info=True)
                    fail_count += 1
            logger.info("=" * 60)
            logger.info(f"全量训练完成: 成功={success_count}, 失败={fail_count}")
            logger.info("=" * 60)
            return fail_count == 0
        except Exception as e:
            logger.error(f"全量训练异常: {e}", exc_info=True)
            return False
        finally:
            # Always clean up temporary feature files
            if self.temp_mel_dir.exists():
                shutil.rmtree(self.temp_mel_dir)
  733. # ========================================
  734. # 增量训练入口(保留兼容)
  735. # ========================================
def run_daily_training(self, on_device_trained=None) -> bool:
    """Run one daily incremental-training pass, one device at a time.

    Flow (each device is fully processed before the next, to keep the
    memory/CPU peak low):
      1. Collect every device's file list (paths only, very cheap).
      2. Back up the current models.
      3. Per device: extract features -> train -> validate -> deploy ->
         clean temp files -> fire the reload callback.
      4. (Classifier-baseline update happens outside this method.)

    Args:
        on_device_trained: Optional callback ``fn(device_code: str)``,
            invoked after each device trains and deploys successfully —
            used to hot-reload that device's model immediately.

    Returns:
        bool: True if at least one device succeeded.
    """
    try:
        # Resolve which day's recordings to train on: explicit override
        # wins over the configured `use_days_ago`.
        days_ago = (self.use_days_ago if self.use_days_ago is not None
                    else self.config['auto_training']['incremental']['use_days_ago'])
        target_date = (datetime.now() - timedelta(days=days_ago)).strftime('%Y%m%d')
        mode_str = "冷启动训练" if self.cold_start_mode else "增量训练"
        logger.info("=" * 60)
        logger.info(f"{mode_str} - {target_date}")
        logger.info("=" * 60)
        # 1. Collect data (file paths only — no audio is loaded here).
        device_files = self.collect_training_data(target_date)
        total = sum(len(f) for f in device_files.values())
        min_samples = self.config['auto_training']['incremental']['min_samples']
        if total < min_samples:
            logger.warning(f"数据不足 ({total} < {min_samples}),跳过")
            return False
        # 2. Snapshot current models so a bad run can be rolled back.
        if self.config['auto_training']['model']['backup_before_train']:
            self.backup_model(target_date)
        # 3. Training hyper-parameters: constructor overrides beat config.
        epochs = (self.epochs if self.epochs is not None
                  else self.config['auto_training']['incremental']['epochs'])
        lr = (self.learning_rate if self.learning_rate is not None
              else self.config['auto_training']['incremental']['learning_rate'])
        from_scratch = self.cold_start_mode
        model_cfg = self.config['auto_training']['model']
        rollback_enabled = model_cfg.get('rollback_on_degradation', True)
        rollback_factor = model_cfg.get('rollback_factor', 2.0)
        # 4. Serial per-device pipeline: extract -> train -> deploy -> clean.
        success_count = 0
        degraded_count = 0
        device_count = len(device_files)
        for idx, (device_code, wav_files) in enumerate(device_files.items(), 1):
            logger.info(f"\n{'='*40}")
            logger.info(f"[{idx}/{device_count}] 设备: {device_code}")
            logger.info(f"{'='*40}")
            try:
                # -- 4a. Feature extraction for this single device --
                mel_dir, scale = self._extract_mel_for_device(
                    device_code, wav_files
                )
                if mel_dir is None:
                    logger.warning(f"{device_code}: 特征提取无有效数据,跳过")
                    continue
                # -- 4b. Train --
                model, final_loss = self.train_single_device(
                    device_code, mel_dir, epochs, lr, from_scratch=from_scratch
                )
                # -- 4c. Output-shape validation --
                if not self._validate_model(model):
                    logger.error(f"{device_code}: 形状验证失败,跳过部署")
                    continue
                # -- 4d. Old-vs-new comparison (incremental mode only) --
                # Compare reconstruction error of both models on the same
                # validation data; skip deployment if the new one is worse.
                if rollback_enabled and not self.cold_start_mode:
                    old_model_path = self.model_root / device_code / "ae_model.pth"
                    if old_model_path.exists():
                        new_avg_err = self._eval_model_error(model, mel_dir)
                        old_model = ConvAutoencoder()
                        old_model.load_state_dict(
                            torch.load(old_model_path, map_location='cpu')
                        )
                        old_avg_err = self._eval_model_error(old_model, mel_dir)
                        logger.info(
                            f" {device_code}: 新旧模型对比 | "
                            f"旧模型误差={old_avg_err:.6f} 新模型误差={new_avg_err:.6f}"
                        )
                        # Degraded = new error exceeds old error by the
                        # configured multiplicative tolerance.
                        if new_avg_err > old_avg_err * rollback_factor:
                            logger.warning(
                                f"{device_code}: 新模型退化 | "
                                f"新={new_avg_err:.6f} > 旧={old_avg_err:.6f} × {rollback_factor},跳过部署"
                            )
                            degraded_count += 1
                            continue
                # -- 4e. Deploy (optional per config) --
                # NOTE(review): success is counted once training passed all
                # checks, even when auto_deploy is disabled — confirm intended.
                if model_cfg.get('auto_deploy', True):
                    self.deploy_device_model(device_code, model, scale, mel_dir)
                success_count += 1
                logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}")
                # -- 4f. Purge the verified_normal directory used in training --
                # Those human-verified clips are now absorbed by the model;
                # clearing them after training frees disk space.
                verified_dir = self.audio_root / device_code / "verified_normal"
                if verified_dir.exists():
                    v_count = len(list(verified_dir.glob("*")))
                    if v_count > 0:
                        shutil.rmtree(verified_dir)
                        verified_dir.mkdir(parents=True, exist_ok=True)
                        logger.info(f"{device_code}: 已清理 verified_normal ({v_count} 个文件)")
                # -- 4g. Notify caller so this device's model is hot-reloaded --
                if on_device_trained:
                    try:
                        on_device_trained(device_code)
                    except Exception as e:
                        # Reload failure must not abort the remaining devices.
                        logger.warning(f"{device_code}: 模型重载回调失败 | {e}")
            except Exception as e:
                logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True)
            finally:
                # -- 4h. Always drop this device's temp Mel files --
                device_mel_dir = self.temp_mel_dir / device_code
                if device_mel_dir.exists():
                    shutil.rmtree(device_mel_dir)
        # 5. If every device degraded, roll everything back to the
        #    pre-training backup taken in step 2.
        if degraded_count > 0 and success_count == 0:
            logger.error(
                f"所有设备训练后损失退化({degraded_count}个),执行整体回滚"
            )
            self.restore_backup(target_date)
            return False
        if degraded_count > 0:
            logger.warning(
                f"{degraded_count} 个设备因损失退化跳过部署,"
                f"{success_count} 个设备部署成功"
            )
        logger.info("=" * 60)
        logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功")
        if degraded_count > 0:
            logger.info(f" 其中 {degraded_count} 个设备因损失退化跳过")
        logger.info("=" * 60)
        return success_count > 0
    except Exception as e:
        logger.error(f"训练失败: {e}", exc_info=True)
        return False
    finally:
        # Temp Mel root is removed unconditionally, success or failure.
        if self.temp_mel_dir.exists():
            shutil.rmtree(self.temp_mel_dir)
  875. # ========================================
  876. # 辅助方法
  877. # ========================================
  878. def _validate_model(self, model: nn.Module) -> bool:
  879. # 验证模型输出形状是否合理
  880. if not self.config['auto_training']['validation']['enabled']:
  881. return True
  882. try:
  883. device = next(model.parameters()).device
  884. test_input = torch.randn(1, 1, CFG.N_MELS, CFG.TARGET_FRAMES).to(device)
  885. with torch.no_grad():
  886. output = model(test_input)
  887. h_diff = abs(output.shape[2] - test_input.shape[2])
  888. w_diff = abs(output.shape[3] - test_input.shape[3])
  889. if h_diff > 8 or w_diff > 8:
  890. logger.error(f"形状差异过大: {output.shape} vs {test_input.shape}")
  891. return False
  892. return True
  893. except Exception as e:
  894. logger.error(f"验证失败: {e}")
  895. return False
  896. def backup_model(self, date_tag: str):
  897. """
  898. 完整备份当前所有设备的模型
  899. 备份目录结构:
  900. backups/{date_tag}/{device_code}/
  901. ├── ae_model.pth
  902. ├── global_scale.npy
  903. └── thresholds/
  904. """
  905. backup_date_dir = self.backup_dir / date_tag
  906. backup_date_dir.mkdir(parents=True, exist_ok=True)
  907. backed_up = 0
  908. # 遍历 models/ 下的所有设备子目录
  909. for device_dir in self.model_root.iterdir():
  910. if not device_dir.is_dir():
  911. continue
  912. # 跳过 backups 目录本身
  913. if device_dir.name == "backups":
  914. continue
  915. # 检查是否包含模型文件(判断是否为设备目录)
  916. if not (device_dir / "ae_model.pth").exists():
  917. continue
  918. device_backup = backup_date_dir / device_dir.name
  919. # 递归复制整个设备目录
  920. shutil.copytree(device_dir, device_backup, dirs_exist_ok=True)
  921. backed_up += 1
  922. logger.info(f"备份完成: {backed_up} 个设备 -> {backup_date_dir}")
  923. # 清理旧备份
  924. keep = self.config['auto_training']['model']['keep_backups']
  925. backup_dirs = sorted(
  926. [d for d in self.backup_dir.iterdir() if d.is_dir() and d.name.isdigit()],
  927. reverse=True
  928. )
  929. for old_dir in backup_dirs[keep:]:
  930. shutil.rmtree(old_dir)
  931. logger.info(f"已删除旧备份: {old_dir.name}")
  932. def restore_backup(self, date_tag: str) -> bool:
  933. """
  934. 从备份恢复所有设备的模型
  935. Args:
  936. date_tag: 备份日期标签 'YYYYMMDD'
  937. Returns:
  938. bool: 是否恢复成功
  939. """
  940. backup_date_dir = self.backup_dir / date_tag
  941. if not backup_date_dir.exists():
  942. logger.error(f"备份目录不存在: {backup_date_dir}")
  943. return False
  944. logger.info(f"从备份恢复: {date_tag}")
  945. restored = 0
  946. for device_backup in backup_date_dir.iterdir():
  947. if not device_backup.is_dir():
  948. continue
  949. target_dir = self.model_root / device_backup.name
  950. # 递归复制恢复
  951. shutil.copytree(device_backup, target_dir, dirs_exist_ok=True)
  952. restored += 1
  953. logger.info(f"恢复完成: {restored} 个设备")
  954. return restored > 0
  955. def main():
  956. # 命令行入口(增量训练)
  957. logging.basicConfig(
  958. level=logging.INFO,
  959. format='%(asctime)s | %(levelname)-8s | %(message)s',
  960. datefmt='%Y-%m-%d %H:%M:%S'
  961. )
  962. config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml"
  963. trainer = IncrementalTrainer(config_file)
  964. success = trainer.run_daily_training()
  965. sys.exit(0 if success else 1)
  966. if __name__ == "__main__":
  967. main()