incremental_trainer.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. incremental_trainer.py
  5. ----------------------
  6. 模型训练模块
  7. 功能:
  8. 1. 支持全量训练(指定外部数据目录,每设备独立模型)
  9. 2. 支持增量训练(每日自动训练)
  10. 3. 滑动窗口提取Mel特征(8秒patches)
  11. 4. 每设备独立计算标准化参数和阈值
  12. 5. 产出目录结构与推理端 MultiModelPredictor 对齐
  13. 产出目录结构:
  14. models/
  15. ├── {device_code_1}/
  16. │ ├── ae_model.pth
  17. │ ├── global_scale.npy
  18. │ └── thresholds/
  19. │ └── threshold_{device_code_1}.npy
  20. └── {device_code_2}/
  21. ├── ae_model.pth
  22. ├── global_scale.npy
  23. └── thresholds/
  24. └── threshold_{device_code_2}.npy
  25. """
  26. import sys
  27. import random
  28. import shutil
  29. import logging
  30. import numpy as np
  31. import torch
  32. import torch.nn as nn
  33. from pathlib import Path
  34. from datetime import datetime, timedelta
  35. from typing import List, Dict, Tuple, Optional
  36. import yaml
  37. # 添加父目录到路径
  38. sys.path.insert(0, str(Path(__file__).parent.parent))
  39. from predictor import CFG
  40. from predictor.model_def import ConvAutoencoder
  41. from predictor.datasets import MelNPYDataset
  42. from predictor.utils import align_to_target
  43. logger = logging.getLogger('IncrementalTrainer')
  44. class IncrementalTrainer:
  45. """
  46. 模型训练器
  47. 支持两种训练模式:
  48. 1. 全量训练:指定外部数据目录,每设备从零训练独立模型
  49. 2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑)
  50. """
  51. def __init__(self, config_file: Path):
  52. """
  53. 初始化训练器
  54. Args:
  55. config_file: auto_training.yaml 配置文件路径
  56. """
  57. self.config_file = config_file
  58. self.config = self._load_config()
  59. # 路径配置
  60. self.deploy_root = Path(__file__).parent.parent
  61. self.audio_root = self.deploy_root / "data" / "audio"
  62. # 模型根目录(所有设备子目录的父目录)
  63. self.model_root = self.deploy_root / "models"
  64. self.backup_dir = self.model_root / "backups"
  65. # 临时目录
  66. self.temp_mel_dir = self.deploy_root / "data" / "temp_mels"
  67. # 确保目录存在
  68. self.backup_dir.mkdir(parents=True, exist_ok=True)
  69. # 运行时可覆盖的配置(用于冷启动)
  70. self.use_days_ago = None
  71. self.sample_hours = None
  72. self.epochs = None
  73. self.learning_rate = None
  74. self.cold_start_mode = False
  75. def _load_config(self) -> Dict:
  76. # 加载配置文件
  77. with open(self.config_file, 'r', encoding='utf-8') as f:
  78. return yaml.safe_load(f)
  79. # ========================================
  80. # 数据收集
  81. # ========================================
  82. def collect_from_external_dir(self, data_dir: Path,
  83. device_filter: Optional[List[str]] = None
  84. ) -> Dict[str, List[Path]]:
  85. """
  86. 从外部数据目录收集训练数据
  87. 目录结构约定:
  88. data_dir/
  89. ├── LT-2/ <-- 子文件夹名 = device_code
  90. │ ├── xxx.wav
  91. │ └── ...
  92. └── LT-5/
  93. ├── yyy.wav
  94. └── ...
  95. 支持两种子目录结构:
  96. 1. 扁平结构:data_dir/{device_code}/*.wav
  97. 2. 日期嵌套:data_dir/{device_code}/{YYYYMMDD}/*.wav
  98. Args:
  99. data_dir: 外部数据目录路径
  100. device_filter: 只训练指定设备(None=全部)
  101. Returns:
  102. {device_code: [wav_files]} 按设备分组的音频文件列表
  103. """
  104. data_dir = Path(data_dir)
  105. if not data_dir.exists():
  106. raise FileNotFoundError(f"数据目录不存在: {data_dir}")
  107. logger.info(f"扫描外部数据目录: {data_dir}")
  108. device_files = {}
  109. for sub_dir in sorted(data_dir.iterdir()):
  110. # 跳过非目录
  111. if not sub_dir.is_dir():
  112. continue
  113. device_code = sub_dir.name
  114. # 应用设备过滤
  115. if device_filter and device_code not in device_filter:
  116. logger.info(f" 跳过设备(不在过滤列表中): {device_code}")
  117. continue
  118. audio_files = []
  119. # 扁平结构:直接查找 wav 文件
  120. audio_files.extend(list(sub_dir.glob("*.wav")))
  121. audio_files.extend(list(sub_dir.glob("*.mp4")))
  122. # 日期嵌套结构:查找子目录中的 wav 文件
  123. for nested_dir in sub_dir.iterdir():
  124. if nested_dir.is_dir():
  125. audio_files.extend(list(nested_dir.glob("*.wav")))
  126. audio_files.extend(list(nested_dir.glob("*.mp4")))
  127. # 去重
  128. audio_files = list(set(audio_files))
  129. if audio_files:
  130. device_files[device_code] = audio_files
  131. logger.info(f" {device_code}: {len(audio_files)} 个音频文件")
  132. else:
  133. logger.warning(f" {device_code}: 无音频文件,跳过")
  134. total = sum(len(f) for f in device_files.values())
  135. logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
  136. return device_files
  137. def collect_training_data(self, target_date: str) -> Dict[str, List[Path]]:
  138. """
  139. 从内部数据目录收集训练数据(增量训练用)
  140. Args:
  141. target_date: 日期字符串 'YYYYMMDD'
  142. Returns:
  143. {device_code: [wav_files]} 按设备分组的音频文件
  144. """
  145. logger.info(f"收集 {target_date} 的训练数据")
  146. sample_hours = self.config['auto_training']['incremental'].get('sample_hours', 0)
  147. device_files = {}
  148. if not self.audio_root.exists():
  149. logger.warning(f"音频目录不存在: {self.audio_root}")
  150. return device_files
  151. for device_dir in self.audio_root.iterdir():
  152. if not device_dir.is_dir():
  153. continue
  154. device_code = device_dir.name
  155. audio_files = []
  156. # 冷启动模式:收集所有已归档日期目录的数据(跳过 current/)
  157. if self.cold_start_mode:
  158. # 注意:跳过 current/ 目录,因其中可能包含 FFmpeg 正在写入的不完整文件
  159. for sub_dir in device_dir.iterdir():
  160. if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
  161. audio_files.extend(list(sub_dir.glob("*.wav")))
  162. audio_files.extend(list(sub_dir.glob("*.mp4")))
  163. else:
  164. # 正常模式:只收集指定日期的目录
  165. date_dir = device_dir / target_date
  166. if date_dir.exists():
  167. audio_files.extend(list(date_dir.glob("*.wav")))
  168. audio_files.extend(list(date_dir.glob("*.mp4")))
  169. # 加上 verified_normal 目录
  170. verified_dir = device_dir / "verified_normal"
  171. if verified_dir.exists():
  172. audio_files.extend(list(verified_dir.glob("*.wav")))
  173. audio_files.extend(list(verified_dir.glob("*.mp4")))
  174. # 去重
  175. audio_files = list(set(audio_files))
  176. # 数据质量预筛:过滤能量/频谱异常的音频
  177. if audio_files and not self.cold_start_mode:
  178. before_count = len(audio_files)
  179. audio_files = self._filter_audio_quality(audio_files, device_code)
  180. filtered = before_count - len(audio_files)
  181. if filtered > 0:
  182. logger.info(f" {device_code}: 质量预筛过滤 {filtered} 个异常音频")
  183. # 随机采样(如果配置了采样时长)
  184. if sample_hours > 0 and audio_files:
  185. files_needed = int(sample_hours * 3600 / 60)
  186. if len(audio_files) > files_needed:
  187. audio_files = random.sample(audio_files, files_needed)
  188. logger.info(f" {device_code}: 随机采样 {len(audio_files)} 个音频")
  189. else:
  190. logger.info(f" {device_code}: {len(audio_files)} 个音频(全部使用)")
  191. else:
  192. logger.info(f" {device_code}: {len(audio_files)} 个音频")
  193. if audio_files:
  194. device_files[device_code] = audio_files
  195. total = sum(len(f) for f in device_files.values())
  196. logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
  197. return device_files
  198. def _filter_audio_quality(self, audio_files: List[Path],
  199. device_code: str) -> List[Path]:
  200. """
  201. 音频质量预筛:基于 RMS 能量和频谱质心过滤明显异常的样本
  202. 使用 IQR (四分位距) 方法检测离群值:
  203. - 计算所有文件的 RMS 能量和频谱质心
  204. - 过滤超出 [Q1 - 2*IQR, Q3 + 2*IQR] 范围的文件
  205. 需要至少 10 个文件才执行过滤,否则样本太少无统计意义。
  206. Args:
  207. audio_files: 待过滤的音频文件列表
  208. device_code: 设备编码(用于日志)
  209. Returns:
  210. 过滤后的文件列表
  211. """
  212. if len(audio_files) < 10:
  213. return audio_files
  214. import librosa
  215. # 快速计算每个文件的 RMS 能量
  216. rms_values = []
  217. valid_files = []
  218. for wav_file in audio_files:
  219. try:
  220. y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True,
  221. duration=10) # 只读前10秒加速
  222. if len(y) < CFG.SR:
  223. continue
  224. rms = float(np.sqrt(np.mean(y ** 2)))
  225. rms_values.append(rms)
  226. valid_files.append(wav_file)
  227. except Exception:
  228. continue
  229. if len(rms_values) < 10:
  230. return audio_files
  231. # IQR 离群值检测
  232. rms_arr = np.array(rms_values)
  233. q1, q3 = np.percentile(rms_arr, [25, 75])
  234. iqr = q3 - q1
  235. lower_bound = q1 - 2 * iqr
  236. upper_bound = q3 + 2 * iqr
  237. filtered = []
  238. for f, rms in zip(valid_files, rms_values):
  239. if lower_bound <= rms <= upper_bound:
  240. filtered.append(f)
  241. return filtered
  242. # ========================================
  243. # 特征提取(每设备独立标准化参数)
  244. # ========================================
def _extract_mel_for_device(self, device_code: str,
                            wav_files: List[Path],
                            inherit_scale: bool = False
                            ) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
    """
    Extract Mel features for one device and compute that device's own
    Min-Max normalisation parameters.

    Streaming two-pass scan (memory optimisation):
        1. pass 1: only track a running min/max (O(1) memory), without
           keeping any mel_db arrays around
        2. pass 2: normalise with the pass-1 min/max and write the .npy
           files directly

    Args:
        device_code: device code
        wav_files: the device's audio file list
        inherit_scale: during incremental training, blend the freshly
            computed scale with the already-deployed one (EMA)

    Returns:
        (mel_dir, (global_min, global_max)); (None, None) on failure
    """
    import librosa
    # Sliding-window parameters, in samples.
    win_samples = int(CFG.WIN_SEC * CFG.SR)
    hop_samples = int(CFG.HOP_SEC * CFG.SR)
    def _iter_mel_patches(files):
        """Generator: yield mel_db patch by patch, file by file, so the
        full feature set never has to fit in memory."""
        for wav_file in files:
            try:
                y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
                if len(y) < CFG.SR:
                    continue  # shorter than one second: skip
                for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)):
                    segment = y[start:start + win_samples]
                    mel_spec = librosa.feature.melspectrogram(
                        y=segment, sr=CFG.SR, n_fft=CFG.N_FFT,
                        hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0
                    )
                    mel_db = librosa.power_to_db(mel_spec, ref=np.max)
                    # Align the frame count to TARGET_FRAMES (pad or crop).
                    if mel_db.shape[1] < CFG.TARGET_FRAMES:
                        pad = CFG.TARGET_FRAMES - mel_db.shape[1]
                        mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant")
                    else:
                        mel_db = mel_db[:, :CFG.TARGET_FRAMES]
                    yield wav_file, idx, mel_db
            except Exception as e:
                logger.warning(f"跳过文件 {wav_file.name}: {e}")
                continue
    # ── Pass 1: streaming running min/max (O(1) memory) ──
    global_min = float('inf')
    global_max = float('-inf')
    patch_count = 0
    for _, _, mel_db in _iter_mel_patches(wav_files):
        local_min = float(mel_db.min())
        local_max = float(mel_db.max())
        if local_min < global_min:
            global_min = local_min
        if local_max > global_max:
            global_max = local_max
        patch_count += 1
    if patch_count == 0:
        logger.warning(f" {device_code}: 无有效数据")
        return None, None
    # During incremental training, smooth against the deployed scale with
    # an EMA to avoid an abrupt normalisation shift between model versions.
    if inherit_scale:
        old_scale = self._load_deployed_scale(device_code)
        if old_scale is not None:
            ema_alpha = 0.3  # weight given to the new data
            old_min, old_max = old_scale
            global_min = ema_alpha * global_min + (1 - ema_alpha) * old_min
            global_max = ema_alpha * global_max + (1 - ema_alpha) * old_max
            logger.info(f" {device_code}: scale EMA 融合 | "
                        f"old=[{old_min:.4f}, {old_max:.4f}] → "
                        f"new=[{global_min:.4f}, {global_max:.4f}]")
    logger.info(f" {device_code}: {patch_count} patches | "
                f"min={global_min:.4f} max={global_max:.4f}")
    # ── Pass 2: Min-Max normalise and save ──
    device_mel_dir = self.temp_mel_dir / device_code
    device_mel_dir.mkdir(parents=True, exist_ok=True)
    scale_range = global_max - global_min + 1e-6  # epsilon guards division by zero
    for wav_file, idx, mel_db in _iter_mel_patches(wav_files):
        mel_norm = (mel_db - global_min) / scale_range
        npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy"
        np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32))
    return device_mel_dir, (global_min, global_max)
  326. def _load_deployed_scale(self, device_code: str) -> Optional[Tuple[float, float]]:
  327. """加载已部署的 global_scale.npy,用于增量训练时的 scale 继承"""
  328. scale_path = self.model_root / device_code / "global_scale.npy"
  329. if not scale_path.exists():
  330. return None
  331. try:
  332. scale = np.load(scale_path)
  333. return float(scale[0]), float(scale[1])
  334. except Exception as e:
  335. logger.warning(f"加载旧 scale 失败: {device_code} | {e}")
  336. return None
  337. def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]],
  338. inherit_scale: bool = False
  339. ) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
  340. """
  341. 为每个设备独立提取 Mel 特征
  342. 每设备分别计算自己的 Min-Max 标准化参数 (global_min, global_max)
  343. Args:
  344. device_files: {device_code: [wav_files]}
  345. inherit_scale: 增量训练时传 True,将新旧 scale 做 EMA 融合
  346. Returns:
  347. {device_code: (mel_dir, (global_min, global_max))}
  348. """
  349. logger.info("提取 Mel 特征(每设备独立标准化)")
  350. # 清空临时目录
  351. if self.temp_mel_dir.exists():
  352. shutil.rmtree(self.temp_mel_dir)
  353. self.temp_mel_dir.mkdir(parents=True, exist_ok=True)
  354. device_results = {}
  355. for device_code, wav_files in device_files.items():
  356. mel_dir, scale = self._extract_mel_for_device(
  357. device_code, wav_files, inherit_scale=inherit_scale
  358. )
  359. if mel_dir is not None:
  360. device_results[device_code] = (mel_dir, scale)
  361. total_patches = sum(
  362. len(list(mel_dir.glob("*.npy")))
  363. for mel_dir, _ in device_results.values()
  364. )
  365. logger.info(f"提取完成: {total_patches} patches,{len(device_results)} 个设备")
  366. return device_results
  367. # ========================================
  368. # 模型训练(每设备独立)
  369. # ========================================
def _select_training_device(self) -> torch.device:
    """
    Pick the training device: use GPU/NPU when enough free memory is
    available, otherwise fall back to CPU.

    The config key ``training_device`` can force a choice
    (auto/cpu/cuda/npu); forced accelerators still fall back to CPU
    when unavailable.
    """
    training_cfg = self.config['auto_training']['incremental']
    forced_device = training_cfg.get('training_device', 'auto')
    # Forced device: return directly (with an availability fallback).
    if forced_device == 'cpu':
        logger.info("训练设备: CPU(配置强制指定)")
        return torch.device('cpu')
    if forced_device == 'cuda':
        if torch.cuda.is_available():
            return torch.device('cuda')
        logger.warning("配置指定 CUDA 但不可用,回退 CPU")
        return torch.device('cpu')
    if forced_device == 'npu':
        if self._npu_available():
            return torch.device('npu')
        logger.warning("配置指定 NPU 但不可用,回退 CPU")
        return torch.device('cpu')
    # auto mode: probe CUDA → NPU → CPU, in that order.
    # 1. CUDA
    if torch.cuda.is_available():
        try:
            free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024)  # free MB
            min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
            if free_mem >= min_gpu_mem_mb:
                logger.info(f"训练设备: CUDA(空闲显存 {free_mem:.0f}MB)")
                return torch.device('cuda')
            logger.info(
                f"CUDA 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
            )
        except Exception as e:
            logger.warning(f"CUDA 显存检测失败: {e}")
    # 2. NPU (Huawei Ascend)
    if self._npu_available():
        try:
            free_mem = torch.npu.mem_get_info()[0] / (1024 * 1024)
            min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
            if free_mem >= min_gpu_mem_mb:
                logger.info(f"训练设备: NPU(空闲显存 {free_mem:.0f}MB)")
                return torch.device('npu')
            logger.info(
                f"NPU 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
            )
        except Exception as e:
            logger.warning(f"NPU 显存检测失败: {e},回退 CPU")
    logger.info("训练设备: CPU")
    return torch.device('cpu')
  418. @staticmethod
  419. def _npu_available() -> bool:
  420. """检查华为昇腾 NPU 是否可用"""
  421. try:
  422. import torch_npu # noqa: F401
  423. return torch.npu.is_available()
  424. except ImportError:
  425. return False
def _run_training_loop(self, device_code: str, model: nn.Module,
                       train_loader, val_loader, epochs: int, lr: float,
                       device: torch.device) -> Tuple[nn.Module, float]:
    """
    Run the actual training loop, decoupled from device selection.

    Early stopping monitors the validation loss when a validation
    loader is provided, otherwise the training loss.

    Args:
        device_code: device code (used only in log prefixes)
        model: autoencoder to train (moved onto ``device`` here)
        train_loader: training DataLoader
        val_loader: validation DataLoader, or None
        epochs: maximum number of epochs
        lr: Adam learning rate
        device: torch device to train on

    Returns:
        (trained model, last average training loss)
    """
    model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # AMP mixed precision (CUDA/NPU only; reduces VRAM use by ~40%).
    use_amp = device.type in ('cuda', 'npu')
    scaler = torch.amp.GradScaler(device.type) if use_amp else None
    # Early-stopping configuration.
    early_stop_cfg = self.config['auto_training']['incremental']
    patience = early_stop_cfg.get('early_stop_patience', 5)
    best_loss = float('inf')
    no_improve_count = 0
    avg_loss = 0.0
    actual_epochs = 0
    for epoch in range(epochs):
        # ── training phase ──
        model.train()
        epoch_loss = 0.0
        batch_count = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            if use_amp:
                with torch.amp.autocast(device.type):
                    output = model(batch)
                    output = align_to_target(output, batch)
                    loss = criterion(output, batch)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                output = model(batch)
                output = align_to_target(output, batch)
                loss = criterion(output, batch)
                loss.backward()
                optimizer.step()
            epoch_loss += loss.item()
            batch_count += 1
        avg_loss = epoch_loss / batch_count
        actual_epochs = epoch + 1
        # ── validation phase (when a validation set exists) ──
        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            val_count = 0
            with torch.no_grad():
                for batch in val_loader:
                    batch = batch.to(device)
                    if use_amp:
                        with torch.amp.autocast(device.type):
                            output = model(batch)
                            output = align_to_target(output, batch)
                            loss = criterion(output, batch)
                    else:
                        output = model(batch)
                        output = align_to_target(output, batch)
                        loss = criterion(output, batch)
                    val_loss += loss.item()
                    val_count += 1
            avg_val_loss = val_loss / val_count
            monitor_loss = avg_val_loss  # early stop watches validation loss
        else:
            avg_val_loss = None
            monitor_loss = avg_loss  # no validation set: fall back to training loss
        # Log every 10th epoch and the final one.
        if actual_epochs % 10 == 0 or epoch == epochs - 1:
            val_str = f" | ValLoss: {avg_val_loss:.6f}" if avg_val_loss is not None else ""
            logger.info(f" [{device_code}] Epoch {actual_epochs}/{epochs} | "
                        f"Loss: {avg_loss:.6f}{val_str} | device={device.type}")
        # Early stop: abort after `patience` consecutive epochs without
        # improvement, but only once at least 10 epochs have run.
        if monitor_loss < best_loss:
            best_loss = monitor_loss
            no_improve_count = 0
        else:
            no_improve_count += 1
            if no_improve_count >= patience and actual_epochs >= 10:
                logger.info(f" [{device_code}] 早停触发: 连续{patience}轮无改善 | "
                            f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}")
                break
    # Free accelerator caches after training.
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    elif device.type == 'npu':
        torch.npu.empty_cache()
    if actual_epochs < epochs:
        logger.info(f" [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练")
    return model, avg_loss
def train_single_device(self, device_code: str, mel_dir: Path,
                        epochs: int, lr: float,
                        from_scratch: bool = True
                        ) -> Tuple[nn.Module, float]:
    """
    Train one device's independent model.

    Strategy: prefer GPU/NPU training and fall back to CPU when free
    memory is insufficient; an OOM raised mid-training is also caught
    and the run is retried on CPU.

    Args:
        device_code: device code
        mel_dir: directory with the device's normalised Mel patches
        epochs: number of training epochs
        lr: learning rate
        from_scratch: True = full training from random init;
            False = fine-tune the deployed model (automatically falls
            back to full training when no deployed model exists)

    Returns:
        (trained model, final average loss)

    Raises:
        ValueError: when the device has no training data
    """
    logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, "
                f"mode={'全量' if from_scratch else '增量'}")
    # Smart training-device selection.
    device = self._select_training_device()
    # Clear accelerator caches before training to release memory
    # fragments left behind by inference.
    if device.type == 'cuda':
        torch.cuda.empty_cache()
        import gc
        gc.collect()
    elif device.type == 'npu':
        torch.npu.empty_cache()
        import gc
        gc.collect()
    model = ConvAutoencoder()
    # Incremental mode: start from the deployed weights when present.
    if not from_scratch:
        model_path = self.model_root / device_code / "ae_model.pth"
        if model_path.exists():
            model.load_state_dict(torch.load(model_path, map_location='cpu'))
            logger.info(f" 已加载已有模型: {model_path}")
        else:
            logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}")
    # Load the data and split 80/20 into train/validation sets.
    dataset = MelNPYDataset(mel_dir)
    if len(dataset) == 0:
        raise ValueError(f"设备 {device_code} 无训练数据")
    batch_size = self.config['auto_training']['incremental']['batch_size']
    # Only split off a validation set with >= 20 samples (fewer has no
    # statistical meaning).
    val_loader = None
    if len(dataset) >= 20:
        val_size = max(1, int(len(dataset) * 0.2))
        train_size = len(dataset) - val_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size]
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=0, pin_memory=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=False
        )
        logger.info(f" 数据集划分: 训练={train_size}, 验证={val_size}")
    else:
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True,
            num_workers=0, pin_memory=False
        )
        logger.info(f" 数据量不足20,跳过验证集划分(共{len(dataset)}样本)")
    # Try to train on the selected device.
    if device.type in ('cuda', 'npu'):
        try:
            return self._run_training_loop(
                device_code, model, train_loader, val_loader,
                epochs, lr, device
            )
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            # GPU/NPU OOM -> clear caches, then retry on CPU.
            # NOTE: torch.cuda.OutOfMemoryError subclasses RuntimeError,
            # so the message check below is what separates OOM (retry)
            # from other RuntimeErrors (re-raise).
            if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError):
                raise  # do not intercept non-OOM RuntimeErrors
            logger.warning(
                f" [{device_code}] {device.type.upper()} OOM,"
                f"清理显存后回退 CPU 训练"
            )
            import gc
            gc.collect()
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            elif device.type == 'npu':
                torch.npu.empty_cache()
            # The model may be in a dirty (partially-updated) state:
            # re-initialise it before the CPU retry.
            model = ConvAutoencoder()
            if not from_scratch:
                model_path = self.model_root / device_code / "ae_model.pth"
                if model_path.exists():
                    model.load_state_dict(
                        torch.load(model_path, map_location='cpu')
                    )
            return self._run_training_loop(
                device_code, model, train_loader, val_loader,
                epochs, lr, torch.device('cpu')
            )
    else:
        # CPU training needs no OOM protection.
        return self._run_training_loop(
            device_code, model, train_loader, val_loader,
            epochs, lr, device
        )
  612. # ========================================
  613. # 产出部署(每设备独立目录)
  614. # ========================================
def deploy_device_model(self, device_code: str, model: nn.Module,
                        scale: Tuple[float, float], mel_dir: Path):
    """
    Deploy one device's model into the models/{device_code}/ directory.

    Files produced:
        - models/{device_code}/ae_model.pth
        - models/{device_code}/global_scale.npy → [min, max]
          (Min-Max normalisation parameters — NOT mean/std, despite
          what an earlier version of this docstring claimed)
        - models/{device_code}/thresholds/threshold_{device_code}.npy

    Args:
        device_code: device code
        model: trained model
        scale: (global_min, global_max) Min-Max normalisation params
        mel_dir: the device's Mel feature directory (used to compute
            the detection threshold)
    """
    # Create the device's model directory.
    device_model_dir = self.model_root / device_code
    device_model_dir.mkdir(parents=True, exist_ok=True)
    # 1. Save the model weights.
    model_path = device_model_dir / "ae_model.pth"
    torch.save(model.state_dict(), model_path)
    logger.info(f" 模型已保存: {model_path}")
    # 2. Save the Min-Max normalisation parameters [min, max].
    scale_path = device_model_dir / "global_scale.npy"
    np.save(scale_path, np.array([scale[0], scale[1]]))
    logger.info(f" 标准化参数已保存: {scale_path}")
    # 3. Compute and save the threshold.
    threshold_dir = device_model_dir / "thresholds"
    threshold_dir.mkdir(parents=True, exist_ok=True)
    threshold = self._compute_threshold(model, mel_dir)
    threshold_file = threshold_dir / f"threshold_{device_code}.npy"
    np.save(threshold_file, threshold)
    logger.info(f" 阈值已保存: {threshold_file} (值={threshold:.6f})")
  647. def _compute_threshold(self, model: nn.Module, mel_dir: Path) -> float:
  648. """
  649. 计算单个设备的阈值
  650. 使用 3σ 法则:threshold = mean + 3 * std
  651. Args:
  652. model: 训练后的模型
  653. mel_dir: 该设备的 Mel 特征目录
  654. Returns:
  655. 阈值(标量)
  656. """
  657. device = next(model.parameters()).device
  658. model.eval()
  659. dataset = MelNPYDataset(mel_dir)
  660. if len(dataset) == 0:
  661. logger.warning("无数据计算阈值,使用默认值 0.01")
  662. return 0.01
  663. dataloader = torch.utils.data.DataLoader(
  664. dataset, batch_size=64, shuffle=False
  665. )
  666. all_errors = []
  667. with torch.no_grad():
  668. for batch in dataloader:
  669. batch = batch.to(device)
  670. output = model(batch)
  671. output = align_to_target(output, batch)
  672. mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
  673. all_errors.append(mse.cpu().numpy())
  674. errors = np.concatenate(all_errors)
  675. # 3σ 法则
  676. mean_err = float(np.mean(errors))
  677. std_err = float(np.std(errors))
  678. threshold = mean_err + 3 * std_err
  679. logger.info(f" 阈值统计: 3σ={threshold:.6f} | "
  680. f"mean={mean_err:.6f} std={std_err:.6f} | "
  681. f"样本数={len(errors)}")
  682. return threshold
  683. # ========================================
  684. # 全量训练入口
  685. # ========================================
def run_full_training(self, data_dir: Path,
                      epochs: Optional[int] = None,
                      lr: Optional[float] = None,
                      device_filter: Optional[List[str]] = None) -> bool:
    """
    Full-training entry point: train an independent model per device,
    from scratch.

    Flow:
        1. scan the external data directory
        2. per device: extract Mel features + normalisation parameters
        3. per device: train an independent model
        4. per device: deploy (model + scale + threshold)

    Args:
        data_dir: data directory path
        epochs: number of epochs (None = value from config file)
        lr: learning rate (None = value from config file)
        device_filter: only train the listed devices (None = all)

    Returns:
        bool: True only when no device failed (all trained + deployed)
    """
    try:
        # Fall back to config-file values when not given explicitly.
        epochs = epochs or self.config['auto_training']['incremental']['epochs']
        lr = lr or self.config['auto_training']['incremental']['learning_rate']
        logger.info("=" * 60)
        logger.info(f"全量训练 | 数据目录: {data_dir}")
        logger.info(f"参数: epochs={epochs}, lr={lr}")
        if device_filter:
            logger.info(f"设备过滤: {device_filter}")
        logger.info("=" * 60)
        # 1. Collect the data.
        device_files = self.collect_from_external_dir(data_dir, device_filter)
        if not device_files:
            logger.error("无可用训练数据")
            return False
        # 2. Extract features per device.
        device_results = self.prepare_mel_features_per_device(device_files)
        if not device_results:
            logger.error("特征提取失败")
            return False
        # 3. Train + deploy each device independently; one device's
        # failure does not abort the others.
        success_count = 0
        fail_count = 0
        for device_code, (mel_dir, scale) in device_results.items():
            try:
                logger.info(f"\n--- 训练设备: {device_code} ---")
                # Train.
                model, final_loss = self.train_single_device(
                    device_code, mel_dir, epochs, lr, from_scratch=True
                )
                # Validate before deploying (defined elsewhere in this class).
                if not self._validate_model(model):
                    logger.error(f" {device_code}: 模型验证失败,跳过部署")
                    fail_count += 1
                    continue
                # Deploy into models/{device_code}/.
                self.deploy_device_model(device_code, model, scale, mel_dir)
                success_count += 1
                logger.info(f" {device_code}: 训练+部署完成 | loss={final_loss:.6f}")
            except Exception as e:
                logger.error(f" {device_code}: 训练失败 | {e}", exc_info=True)
                fail_count += 1
        logger.info("=" * 60)
        logger.info(f"全量训练完成: 成功={success_count}, 失败={fail_count}")
        logger.info("=" * 60)
        return fail_count == 0
    except Exception as e:
        logger.error(f"全量训练异常: {e}", exc_info=True)
        return False
    finally:
        # Always remove the temporary Mel feature directory.
        if self.temp_mel_dir.exists():
            shutil.rmtree(self.temp_mel_dir)
  757. # ========================================
  758. # 增量训练入口(保留兼容)
  759. # ========================================
def run_daily_training(self, on_device_trained=None) -> bool:
    """
    Run the daily incremental training, processing devices strictly one at a time.

    Pipeline (each device is fully finished before the next one starts,
    which keeps memory/CPU peaks low):
        1. Collect every device's file list (paths only, very cheap)
        2. Back up the current models
        3. Per device: extract features -> train -> validate -> deploy ->
           clean up temp files -> fire the reload callback
        4. Update the classifier baseline

    Args:
        on_device_trained: optional callback fn(device_code: str), invoked
            after each device trains and deploys successfully; used to
            hot-reload that device's model immediately.

    Returns:
        bool: True when at least one device succeeded.
    """
    try:
        # Explicit CLI override wins over the configured lookback window.
        days_ago = (self.use_days_ago if self.use_days_ago is not None
                    else self.config['auto_training']['incremental']['use_days_ago'])
        target_date = (datetime.now() - timedelta(days=days_ago)).strftime('%Y%m%d')
        mode_str = "冷启动训练" if self.cold_start_mode else "增量训练"
        logger.info("=" * 60)
        logger.info(f"{mode_str} - {target_date}")
        logger.info("=" * 60)
        # 1. Collect data (file paths only — no audio is loaded here).
        device_files = self.collect_training_data(target_date)
        total = sum(len(f) for f in device_files.values())
        min_samples = self.config['auto_training']['incremental']['min_samples']
        if total < min_samples:
            logger.warning(f"数据不足 ({total} < {min_samples}),跳过")
            return False
        # 2. Back up current models before touching anything.
        if self.config['auto_training']['model']['backup_before_train']:
            self.backup_model(target_date)
        # 3. Resolve training parameters (instance overrides beat config).
        epochs = (self.epochs if self.epochs is not None
                  else self.config['auto_training']['incremental']['epochs'])
        lr = (self.learning_rate if self.learning_rate is not None
              else self.config['auto_training']['incremental']['learning_rate'])
        # Cold start trains from scratch and recomputes the feature scale;
        # incremental runs inherit the previously deployed scale.
        from_scratch = self.cold_start_mode
        inherit_scale = not self.cold_start_mode
        model_cfg = self.config['auto_training']['model']
        rollback_enabled = model_cfg.get('rollback_on_degradation', True)
        rollback_factor = model_cfg.get('rollback_factor', 2.0)
        # 4. Serial per-device loop: extract -> train -> deploy -> clean up.
        success_count = 0
        degraded_count = 0
        device_count = len(device_files)
        for idx, (device_code, wav_files) in enumerate(device_files.items(), 1):
            logger.info(f"\n{'='*40}")
            logger.info(f"[{idx}/{device_count}] 设备: {device_code}")
            logger.info(f"{'='*40}")
            try:
                # -- 4a. Feature extraction for this single device --
                mel_dir, scale = self._extract_mel_for_device(
                    device_code, wav_files, inherit_scale=inherit_scale
                )
                if mel_dir is None:
                    logger.warning(f"{device_code}: 特征提取无有效数据,跳过")
                    continue
                # -- 4b. Training --
                model, final_loss = self.train_single_device(
                    device_code, mel_dir, epochs, lr, from_scratch=from_scratch
                )
                # -- 4c. Output-shape sanity check --
                if not self._validate_model(model):
                    logger.error(f"{device_code}: 形状验证失败,跳过部署")
                    continue
                # -- 4d. Loss-degradation check (incremental runs only):
                #        a freshly trained model whose loss far exceeds the
                #        previously deployed threshold is treated as degraded. --
                if rollback_enabled and not self.cold_start_mode:
                    old_threshold = self._get_old_threshold(device_code)
                    if old_threshold and old_threshold > 0:
                        if final_loss > old_threshold * rollback_factor:
                            logger.warning(
                                f"{device_code}: 损失退化检测触发 | "
                                f"训练损失={final_loss:.6f} > "
                                f"旧阈值={old_threshold:.6f} × {rollback_factor} = "
                                f"{old_threshold * rollback_factor:.6f}"
                            )
                            degraded_count += 1
                            continue
                # -- 4e. Threshold-drift check + deployment --
                if model_cfg.get('auto_deploy', True):
                    if rollback_enabled and not self.cold_start_mode:
                        old_threshold = self._get_old_threshold(device_code)
                        if old_threshold and old_threshold > 0:
                            new_threshold = self._compute_threshold(model, mel_dir)
                            drift_ratio = abs(new_threshold - old_threshold) / old_threshold
                            # Record the threshold trend (long-term drift monitoring).
                            self._log_threshold_history(
                                device_code, target_date,
                                old_threshold, new_threshold, final_loss
                            )
                            # >30% drift: warn only; >100% drift: skip deployment.
                            if drift_ratio > 0.3:
                                logger.warning(
                                    f"{device_code}: 阈值偏移告警 | "
                                    f"旧={old_threshold:.6f} → "
                                    f"新={new_threshold:.6f} | "
                                    f"偏移={drift_ratio:.1%}"
                                )
                            if drift_ratio > 1.0:
                                logger.warning(
                                    f"{device_code}: 阈值偏移过大"
                                    f"(>{drift_ratio:.0%}),跳过部署"
                                )
                                degraded_count += 1
                                continue
                    self.deploy_device_model(device_code, model, scale, mel_dir)
                    success_count += 1
                    logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}")
                    # -- 4f. Notify caller so this device's model can be hot-reloaded --
                    if on_device_trained:
                        try:
                            on_device_trained(device_code)
                        except Exception as e:
                            # Callback failure must not abort the training loop.
                            logger.warning(f"{device_code}: 模型重载回调失败 | {e}")
            except Exception as e:
                logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True)
            finally:
                # -- 4g. Drop this device's temporary Mel files to free disk space --
                device_mel_dir = self.temp_mel_dir / device_code
                if device_mel_dir.exists():
                    shutil.rmtree(device_mel_dir)
        # 5. If every device degraded, roll everything back to the pre-training backup.
        if degraded_count > 0 and success_count == 0:
            logger.error(
                f"所有设备训练后损失退化({degraded_count}个),执行整体回滚"
            )
            self.restore_backup(target_date)
            return False
        if degraded_count > 0:
            logger.warning(
                f"{degraded_count} 个设备因损失退化跳过部署,"
                f"{success_count} 个设备部署成功"
            )
        # 6. Refresh the classifier baseline from today's training data.
        self._update_classifier_baseline(device_files)
        logger.info("=" * 60)
        logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功")
        if degraded_count > 0:
            logger.info(f"  其中 {degraded_count} 个设备因损失退化跳过")
        logger.info("=" * 60)
        return success_count > 0
    except Exception as e:
        logger.error(f"训练失败: {e}", exc_info=True)
        return False
    finally:
        # Always remove the shared temp Mel directory, even on failure.
        if self.temp_mel_dir.exists():
            shutil.rmtree(self.temp_mel_dir)
# ========================================
# Helper methods
# ========================================
  911. def _get_old_threshold(self, device_code: str) -> float:
  912. """
  913. 读取设备当前已部署的阈值(训练前的旧阈值)
  914. 用于损失退化校验:新模型的训练损失不应远超旧阈值。
  915. 阈值文件路径: models/{device_code}/thresholds/threshold_{device_code}.npy
  916. 返回:
  917. 阈值浮点数,文件不存在时返回 0.0
  918. """
  919. threshold_file = self.model_root / device_code / "thresholds" / f"threshold_{device_code}.npy"
  920. if not threshold_file.exists():
  921. return 0.0
  922. try:
  923. data = np.load(threshold_file)
  924. return float(data.flat[0])
  925. except Exception as e:
  926. logger.warning(f"读取旧阈值失败: {device_code} | {e}")
  927. return 0.0
  928. def _log_threshold_history(self, device_code: str, date_str: str,
  929. old_threshold: float, new_threshold: float,
  930. train_loss: float):
  931. """
  932. 记录阈值变化历史到 CSV,用于监控模型长期漂移趋势
  933. 文件路径: logs/threshold_history.csv
  934. 格式: date,device_code,old_threshold,new_threshold,drift_ratio,train_loss
  935. """
  936. import csv
  937. log_dir = self.deploy_root / "logs"
  938. log_dir.mkdir(parents=True, exist_ok=True)
  939. csv_path = log_dir / "threshold_history.csv"
  940. drift_ratio = (new_threshold - old_threshold) / old_threshold if old_threshold > 0 else 0.0
  941. write_header = not csv_path.exists()
  942. try:
  943. with open(csv_path, 'a', newline='', encoding='utf-8') as f:
  944. writer = csv.writer(f)
  945. if write_header:
  946. writer.writerow([
  947. 'date', 'device_code', 'old_threshold', 'new_threshold',
  948. 'drift_ratio', 'train_loss'
  949. ])
  950. writer.writerow([
  951. date_str, device_code,
  952. f"{old_threshold:.8f}", f"{new_threshold:.8f}",
  953. f"{drift_ratio:.4f}", f"{train_loss:.8f}"
  954. ])
  955. except Exception as e:
  956. logger.warning(f"写入阈值历史失败: {e}")
  957. def _validate_model(self, model: nn.Module) -> bool:
  958. # 验证模型输出形状是否合理
  959. if not self.config['auto_training']['validation']['enabled']:
  960. return True
  961. try:
  962. device = next(model.parameters()).device
  963. test_input = torch.randn(1, 1, CFG.N_MELS, CFG.TARGET_FRAMES).to(device)
  964. with torch.no_grad():
  965. output = model(test_input)
  966. h_diff = abs(output.shape[2] - test_input.shape[2])
  967. w_diff = abs(output.shape[3] - test_input.shape[3])
  968. if h_diff > 8 or w_diff > 8:
  969. logger.error(f"形状差异过大: {output.shape} vs {test_input.shape}")
  970. return False
  971. return True
  972. except Exception as e:
  973. logger.error(f"验证失败: {e}")
  974. return False
  975. def backup_model(self, date_tag: str):
  976. """
  977. 完整备份当前所有设备的模型
  978. 备份目录结构:
  979. backups/{date_tag}/{device_code}/
  980. ├── ae_model.pth
  981. ├── global_scale.npy
  982. └── thresholds/
  983. """
  984. backup_date_dir = self.backup_dir / date_tag
  985. backup_date_dir.mkdir(parents=True, exist_ok=True)
  986. backed_up = 0
  987. # 遍历 models/ 下的所有设备子目录
  988. for device_dir in self.model_root.iterdir():
  989. if not device_dir.is_dir():
  990. continue
  991. # 跳过 backups 目录本身
  992. if device_dir.name == "backups":
  993. continue
  994. # 检查是否包含模型文件(判断是否为设备目录)
  995. if not (device_dir / "ae_model.pth").exists():
  996. continue
  997. device_backup = backup_date_dir / device_dir.name
  998. # 递归复制整个设备目录
  999. shutil.copytree(device_dir, device_backup, dirs_exist_ok=True)
  1000. backed_up += 1
  1001. logger.info(f"备份完成: {backed_up} 个设备 -> {backup_date_dir}")
  1002. # 清理旧备份
  1003. keep = self.config['auto_training']['model']['keep_backups']
  1004. backup_dirs = sorted(
  1005. [d for d in self.backup_dir.iterdir() if d.is_dir() and d.name.isdigit()],
  1006. reverse=True
  1007. )
  1008. for old_dir in backup_dirs[keep:]:
  1009. shutil.rmtree(old_dir)
  1010. logger.info(f"已删除旧备份: {old_dir.name}")
  1011. def restore_backup(self, date_tag: str) -> bool:
  1012. """
  1013. 从备份恢复所有设备的模型
  1014. Args:
  1015. date_tag: 备份日期标签 'YYYYMMDD'
  1016. Returns:
  1017. bool: 是否恢复成功
  1018. """
  1019. backup_date_dir = self.backup_dir / date_tag
  1020. if not backup_date_dir.exists():
  1021. logger.error(f"备份目录不存在: {backup_date_dir}")
  1022. return False
  1023. logger.info(f"从备份恢复: {date_tag}")
  1024. restored = 0
  1025. for device_backup in backup_date_dir.iterdir():
  1026. if not device_backup.is_dir():
  1027. continue
  1028. target_dir = self.model_root / device_backup.name
  1029. # 递归复制恢复
  1030. shutil.copytree(device_backup, target_dir, dirs_exist_ok=True)
  1031. restored += 1
  1032. logger.info(f"恢复完成: {restored} 个设备")
  1033. return restored > 0
  1034. def _update_classifier_baseline(self, device_files: Dict[str, List[Path]]):
  1035. # 从训练数据计算并更新分类器基线
  1036. logger.info("更新分类器基线")
  1037. try:
  1038. import librosa
  1039. from core.anomaly_classifier import AnomalyClassifier
  1040. classifier = AnomalyClassifier()
  1041. all_files = []
  1042. for files in device_files.values():
  1043. all_files.extend(files)
  1044. if not all_files:
  1045. logger.warning("无音频文件,跳过基线更新")
  1046. return
  1047. sample_files = random.sample(all_files, min(50, len(all_files)))
  1048. all_features = []
  1049. for wav_file in sample_files:
  1050. try:
  1051. y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
  1052. if len(y) < CFG.SR:
  1053. continue
  1054. features = classifier.extract_features(y, sr=CFG.SR)
  1055. if features:
  1056. all_features.append(features)
  1057. except Exception:
  1058. continue
  1059. if not all_features:
  1060. logger.warning("无法提取特征,跳过基线更新")
  1061. return
  1062. baseline = {}
  1063. keys = all_features[0].keys()
  1064. for key in keys:
  1065. if key == 'has_periodic':
  1066. values = [f[key] for f in all_features]
  1067. baseline[key] = sum(values) > len(values) / 2
  1068. else:
  1069. values = [f[key] for f in all_features]
  1070. baseline[key] = float(np.mean(values))
  1071. classifier.save_baseline(baseline)
  1072. logger.info(f" 基线已更新 (样本数: {len(all_features)})")
  1073. except Exception as e:
  1074. logger.warning(f"更新基线失败: {e}")
  1075. def main():
  1076. # 命令行入口(增量训练)
  1077. logging.basicConfig(
  1078. level=logging.INFO,
  1079. format='%(asctime)s | %(levelname)-8s | %(message)s',
  1080. datefmt='%Y-%m-%d %H:%M:%S'
  1081. )
  1082. config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml"
  1083. trainer = IncrementalTrainer(config_file)
  1084. success = trainer.run_daily_training()
  1085. sys.exit(0 if success else 1)
  1086. if __name__ == "__main__":
  1087. main()