incremental_trainer.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. incremental_trainer.py
  5. ----------------------
  6. 模型训练模块
  7. 功能:
  8. 1. 支持全量训练(指定外部数据目录,每设备独立模型)
  9. 2. 支持增量训练(每日自动训练)
  10. 3. 滑动窗口提取Mel特征(8秒patches)
  11. 4. 每设备独立计算标准化参数和阈值
  12. 5. 产出目录结构与推理端 MultiModelPredictor 对齐
  13. 产出目录结构:
  14. models/
  15. ├── {device_code_1}/
  16. │ ├── ae_model.pth
  17. │ ├── global_scale.npy
  18. │ └── thresholds/
  19. │ └── threshold_{device_code_1}.npy
  20. └── {device_code_2}/
  21. ├── ae_model.pth
  22. ├── global_scale.npy
  23. └── thresholds/
  24. └── threshold_{device_code_2}.npy
  25. """
  26. import sys
  27. import random
  28. import shutil
  29. import logging
  30. import numpy as np
  31. import torch
  32. import torch.nn as nn
  33. from pathlib import Path
  34. from datetime import datetime, timedelta
  35. from typing import List, Dict, Tuple, Optional
  36. import yaml
  37. # 添加父目录到路径
  38. sys.path.insert(0, str(Path(__file__).parent.parent))
  39. from predictor import CFG
  40. from predictor.model_def import ConvAutoencoder
  41. from predictor.datasets import MelNPYDataset
  42. from predictor.utils import align_to_target
  43. logger = logging.getLogger('IncrementalTrainer')
  44. class IncrementalTrainer:
  45. """
  46. 模型训练器
  47. 支持两种训练模式:
  48. 1. 全量训练:指定外部数据目录,每设备从零训练独立模型
  49. 2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑)
  50. """
  51. def __init__(self, config_file: Path):
  52. """
  53. 初始化训练器
  54. Args:
  55. config_file: auto_training.yaml 配置文件路径
  56. """
  57. self.config_file = config_file
  58. self.config = self._load_config()
  59. # 路径配置
  60. self.deploy_root = Path(__file__).parent.parent
  61. self.audio_root = self.deploy_root / "data" / "audio"
  62. # 模型根目录(所有设备子目录的父目录)
  63. self.model_root = self.deploy_root / "models"
  64. self.backup_dir = self.model_root / "backups"
  65. # 临时目录
  66. self.temp_mel_dir = self.deploy_root / "data" / "temp_mels"
  67. # 确保目录存在
  68. self.backup_dir.mkdir(parents=True, exist_ok=True)
  69. # 运行时可覆盖的配置(用于冷启动)
  70. self.use_days_ago = None
  71. self.sample_hours = None
  72. self.epochs = None
  73. self.learning_rate = None
  74. self.cold_start_mode = False
  75. def _load_config(self) -> Dict:
  76. # 加载配置文件
  77. with open(self.config_file, 'r', encoding='utf-8') as f:
  78. return yaml.safe_load(f)
  79. # ========================================
  80. # 数据收集
  81. # ========================================
  82. def collect_from_external_dir(self, data_dir: Path,
  83. device_filter: Optional[List[str]] = None
  84. ) -> Dict[str, List[Path]]:
  85. """
  86. 从外部数据目录收集训练数据
  87. 目录结构约定:
  88. data_dir/
  89. ├── LT-2/ <-- 子文件夹名 = device_code
  90. │ ├── xxx.wav
  91. │ └── ...
  92. └── LT-5/
  93. ├── yyy.wav
  94. └── ...
  95. 支持两种子目录结构:
  96. 1. 扁平结构:data_dir/{device_code}/*.wav
  97. 2. 日期嵌套:data_dir/{device_code}/{YYYYMMDD}/*.wav
  98. Args:
  99. data_dir: 外部数据目录路径
  100. device_filter: 只训练指定设备(None=全部)
  101. Returns:
  102. {device_code: [wav_files]} 按设备分组的音频文件列表
  103. """
  104. data_dir = Path(data_dir)
  105. if not data_dir.exists():
  106. raise FileNotFoundError(f"数据目录不存在: {data_dir}")
  107. logger.info(f"扫描外部数据目录: {data_dir}")
  108. device_files = {}
  109. for sub_dir in sorted(data_dir.iterdir()):
  110. # 跳过非目录
  111. if not sub_dir.is_dir():
  112. continue
  113. device_code = sub_dir.name
  114. # 应用设备过滤
  115. if device_filter and device_code not in device_filter:
  116. logger.info(f" 跳过设备(不在过滤列表中): {device_code}")
  117. continue
  118. audio_files = []
  119. # 扁平结构:直接查找 wav 文件
  120. audio_files.extend(list(sub_dir.glob("*.wav")))
  121. audio_files.extend(list(sub_dir.glob("*.mp4")))
  122. # 日期嵌套结构:查找子目录中的 wav 文件
  123. for nested_dir in sub_dir.iterdir():
  124. if nested_dir.is_dir():
  125. audio_files.extend(list(nested_dir.glob("*.wav")))
  126. audio_files.extend(list(nested_dir.glob("*.mp4")))
  127. # 去重
  128. audio_files = list(set(audio_files))
  129. if audio_files:
  130. device_files[device_code] = audio_files
  131. logger.info(f" {device_code}: {len(audio_files)} 个音频文件")
  132. else:
  133. logger.warning(f" {device_code}: 无音频文件,跳过")
  134. total = sum(len(f) for f in device_files.values())
  135. logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
  136. return device_files
  137. def collect_training_data(self, target_date: str) -> Dict[str, List[Path]]:
  138. """
  139. 从内部数据目录收集训练数据(增量训练用)
  140. Args:
  141. target_date: 日期字符串 'YYYYMMDD'
  142. Returns:
  143. {device_code: [wav_files]} 按设备分组的音频文件
  144. """
  145. logger.info(f"收集 {target_date} 的训练数据")
  146. sample_hours = self.config['auto_training']['incremental'].get('sample_hours', 0)
  147. device_files = {}
  148. if not self.audio_root.exists():
  149. logger.warning(f"音频目录不存在: {self.audio_root}")
  150. return device_files
  151. for device_dir in self.audio_root.iterdir():
  152. if not device_dir.is_dir():
  153. continue
  154. device_code = device_dir.name
  155. audio_files = []
  156. # 冷启动模式:收集所有目录的数据
  157. if self.cold_start_mode:
  158. # 收集current目录
  159. current_dir = device_dir / "current"
  160. if current_dir.exists():
  161. audio_files.extend(list(current_dir.glob("*.wav")))
  162. audio_files.extend(list(current_dir.glob("*.mp4")))
  163. # 收集所有日期目录
  164. for sub_dir in device_dir.iterdir():
  165. if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
  166. audio_files.extend(list(sub_dir.glob("*.wav")))
  167. audio_files.extend(list(sub_dir.glob("*.mp4")))
  168. else:
  169. # 正常模式:只收集指定日期的目录
  170. date_dir = device_dir / target_date
  171. if date_dir.exists():
  172. audio_files.extend(list(date_dir.glob("*.wav")))
  173. audio_files.extend(list(date_dir.glob("*.mp4")))
  174. # 加上 verified_normal 目录
  175. verified_dir = device_dir / "verified_normal"
  176. if verified_dir.exists():
  177. audio_files.extend(list(verified_dir.glob("*.wav")))
  178. audio_files.extend(list(verified_dir.glob("*.mp4")))
  179. # 去重
  180. audio_files = list(set(audio_files))
  181. # 随机采样(如果配置了采样时长)
  182. if sample_hours > 0 and audio_files:
  183. files_needed = int(sample_hours * 3600 / 60)
  184. if len(audio_files) > files_needed:
  185. audio_files = random.sample(audio_files, files_needed)
  186. logger.info(f" {device_code}: 随机采样 {len(audio_files)} 个音频")
  187. else:
  188. logger.info(f" {device_code}: {len(audio_files)} 个音频(全部使用)")
  189. else:
  190. logger.info(f" {device_code}: {len(audio_files)} 个音频")
  191. if audio_files:
  192. device_files[device_code] = audio_files
  193. total = sum(len(f) for f in device_files.values())
  194. logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
  195. return device_files
  196. # ========================================
  197. # 特征提取(每设备独立标准化参数)
  198. # ========================================
  199. def _extract_mel_for_device(self, device_code: str,
  200. wav_files: List[Path]) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
  201. """
  202. 为单个设备提取 Mel 特征并计算独立的 Z-score 标准化参数
  203. 两遍扫描:
  204. 1. 第一遍:收集所有 mel_db 计算 mean/std
  205. 2. 第二遍:Z-score 标准化后保存 npy 文件
  206. Args:
  207. device_code: 设备编码
  208. wav_files: 该设备的音频文件列表
  209. Returns:
  210. (mel_dir, (global_mean, global_std)),失败返回 (None, None)
  211. """
  212. import librosa
  213. # 滑动窗口参数
  214. win_samples = int(CFG.WIN_SEC * CFG.SR)
  215. hop_samples = int(CFG.HOP_SEC * CFG.SR)
  216. # 第一遍:收集所有 mel_db 值,用于计算 mean/std
  217. all_mel_data = []
  218. all_values = [] # 收集所有像素值用于全局统计
  219. for wav_file in wav_files:
  220. try:
  221. y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
  222. # 跳过过短的音频
  223. if len(y) < CFG.SR:
  224. continue
  225. for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)):
  226. segment = y[start:start + win_samples]
  227. mel_spec = librosa.feature.melspectrogram(
  228. y=segment, sr=CFG.SR, n_fft=CFG.N_FFT,
  229. hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0
  230. )
  231. mel_db = librosa.power_to_db(mel_spec, ref=np.max)
  232. # 对齐帧数
  233. if mel_db.shape[1] < CFG.TARGET_FRAMES:
  234. pad = CFG.TARGET_FRAMES - mel_db.shape[1]
  235. mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant")
  236. else:
  237. mel_db = mel_db[:, :CFG.TARGET_FRAMES]
  238. # 收集所有值用于 min/max 计算
  239. all_values.append(mel_db.flatten())
  240. all_mel_data.append((wav_file, idx, mel_db))
  241. except Exception as e:
  242. logger.warning(f"跳过文件 {wav_file.name}: {e}")
  243. continue
  244. if not all_mel_data:
  245. logger.warning(f" {device_code}: 无有效数据")
  246. return None, None
  247. # 计算全局 min/max(Min-Max 标准化参数)
  248. all_values_concat = np.concatenate(all_values)
  249. global_min = float(np.min(all_values_concat))
  250. global_max = float(np.max(all_values_concat))
  251. logger.info(f" {device_code}: {len(all_mel_data)} patches | "
  252. f"min={global_min:.4f} max={global_max:.4f}")
  253. # 第二遍:Min-Max 标准化并保存
  254. device_mel_dir = self.temp_mel_dir / device_code
  255. device_mel_dir.mkdir(parents=True, exist_ok=True)
  256. for wav_file, idx, mel_db in all_mel_data:
  257. # Min-Max: (x - min) / (max - min)
  258. mel_norm = (mel_db - global_min) / (global_max - global_min + 1e-6)
  259. npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy"
  260. np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32))
  261. return device_mel_dir, (global_min, global_max)
  262. def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]]
  263. ) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
  264. """
  265. 为每个设备独立提取 Mel 特征
  266. 每设备分别计算自己的 Min-Max 标准化参数 (global_min, global_max)
  267. Args:
  268. device_files: {device_code: [wav_files]}
  269. Returns:
  270. {device_code: (mel_dir, (global_min, global_max))}
  271. """
  272. logger.info("提取 Mel 特征(每设备独立标准化)")
  273. # 清空临时目录
  274. if self.temp_mel_dir.exists():
  275. shutil.rmtree(self.temp_mel_dir)
  276. self.temp_mel_dir.mkdir(parents=True, exist_ok=True)
  277. device_results = {}
  278. for device_code, wav_files in device_files.items():
  279. mel_dir, scale = self._extract_mel_for_device(device_code, wav_files)
  280. if mel_dir is not None:
  281. device_results[device_code] = (mel_dir, scale)
  282. total_patches = sum(
  283. len(list(mel_dir.glob("*.npy")))
  284. for mel_dir, _ in device_results.values()
  285. )
  286. logger.info(f"提取完成: {total_patches} patches,{len(device_results)} 个设备")
  287. return device_results
  288. # ========================================
  289. # 模型训练(每设备独立)
  290. # ========================================
  291. def train_single_device(self, device_code: str, mel_dir: Path,
  292. epochs: int, lr: float,
  293. from_scratch: bool = True
  294. ) -> Tuple[nn.Module, float]:
  295. """
  296. 训练单个设备的独立模型
  297. Args:
  298. device_code: 设备编码
  299. mel_dir: 该设备的 Mel 特征目录
  300. epochs: 训练轮数
  301. lr: 学习率
  302. from_scratch: True=从零训练(全量),False=加载已有模型微调(增量)
  303. Returns:
  304. (model, final_loss)
  305. """
  306. logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, "
  307. f"mode={'全量' if from_scratch else '增量'}")
  308. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  309. model = ConvAutoencoder().to(device)
  310. # 增量模式下加载已有模型
  311. if not from_scratch:
  312. model_path = self.model_root / device_code / "ae_model.pth"
  313. if model_path.exists():
  314. model.load_state_dict(torch.load(model_path, map_location=device))
  315. logger.info(f" 已加载已有模型: {model_path}")
  316. else:
  317. logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}")
  318. # 加载数据
  319. dataset = MelNPYDataset(mel_dir)
  320. if len(dataset) == 0:
  321. raise ValueError(f"设备 {device_code} 无训练数据")
  322. batch_size = self.config['auto_training']['incremental']['batch_size']
  323. dataloader = torch.utils.data.DataLoader(
  324. dataset, batch_size=batch_size, shuffle=True, num_workers=0
  325. )
  326. # 训练
  327. model.train()
  328. optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  329. criterion = nn.MSELoss()
  330. avg_loss = 0.0
  331. for epoch in range(epochs):
  332. epoch_loss = 0.0
  333. batch_count = 0
  334. for batch in dataloader:
  335. batch = batch.to(device)
  336. optimizer.zero_grad()
  337. output = model(batch)
  338. output = align_to_target(output, batch)
  339. loss = criterion(output, batch)
  340. loss.backward()
  341. optimizer.step()
  342. epoch_loss += loss.item()
  343. batch_count += 1
  344. avg_loss = epoch_loss / batch_count
  345. # 每10轮或最后一轮打印日志,避免日志刷屏
  346. if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
  347. logger.info(f" [{device_code}] Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.6f}")
  348. return model, avg_loss
  349. # ========================================
  350. # 产出部署(每设备独立目录)
  351. # ========================================
  352. def deploy_device_model(self, device_code: str, model: nn.Module,
  353. scale: Tuple[float, float], mel_dir: Path):
  354. """
  355. 部署单个设备的模型到 models/{device_code}/ 目录
  356. 产出文件:
  357. - models/{device_code}/ae_model.pth
  358. - models/{device_code}/global_scale.npy → [mean, std]
  359. - models/{device_code}/thresholds/threshold_{device_code}.npy
  360. Args:
  361. device_code: 设备编码
  362. model: 训练后的模型
  363. scale: (global_min, global_max) Min-Max 标准化参数
  364. mel_dir: 该设备的 Mel 特征目录(用于计算阈值)
  365. """
  366. # 创建设备模型目录
  367. device_model_dir = self.model_root / device_code
  368. device_model_dir.mkdir(parents=True, exist_ok=True)
  369. # 1. 保存模型权重
  370. model_path = device_model_dir / "ae_model.pth"
  371. torch.save(model.state_dict(), model_path)
  372. logger.info(f" 模型已保存: {model_path}")
  373. # 2. 保存 Min-Max 标准化参数 [min, max]
  374. scale_path = device_model_dir / "global_scale.npy"
  375. np.save(scale_path, np.array([scale[0], scale[1]]))
  376. logger.info(f" 标准化参数已保存: {scale_path}")
  377. # 3. 计算并保存阈值
  378. threshold_dir = device_model_dir / "thresholds"
  379. threshold_dir.mkdir(parents=True, exist_ok=True)
  380. threshold = self._compute_threshold(model, mel_dir)
  381. threshold_file = threshold_dir / f"threshold_{device_code}.npy"
  382. np.save(threshold_file, threshold)
  383. logger.info(f" 阈值已保存: {threshold_file} (值={threshold:.6f})")
  384. def _compute_threshold(self, model: nn.Module, mel_dir: Path) -> float:
  385. """
  386. 计算单个设备的阈值
  387. 使用 3σ 法则:threshold = mean + 3 * std
  388. Args:
  389. model: 训练后的模型
  390. mel_dir: 该设备的 Mel 特征目录
  391. Returns:
  392. 阈值(标量)
  393. """
  394. device = next(model.parameters()).device
  395. model.eval()
  396. dataset = MelNPYDataset(mel_dir)
  397. if len(dataset) == 0:
  398. logger.warning("无数据计算阈值,使用默认值 0.01")
  399. return 0.01
  400. dataloader = torch.utils.data.DataLoader(
  401. dataset, batch_size=64, shuffle=False
  402. )
  403. all_errors = []
  404. with torch.no_grad():
  405. for batch in dataloader:
  406. batch = batch.to(device)
  407. output = model(batch)
  408. output = align_to_target(output, batch)
  409. mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
  410. all_errors.append(mse.cpu().numpy())
  411. errors = np.concatenate(all_errors)
  412. # 3σ 法则
  413. mean_err = float(np.mean(errors))
  414. std_err = float(np.std(errors))
  415. threshold = mean_err + 3 * std_err
  416. logger.info(f" 阈值统计: 3σ={threshold:.6f} | "
  417. f"mean={mean_err:.6f} std={std_err:.6f} | "
  418. f"样本数={len(errors)}")
  419. return threshold
  420. # ========================================
  421. # 全量训练入口
  422. # ========================================
  423. def run_full_training(self, data_dir: Path,
  424. epochs: Optional[int] = None,
  425. lr: Optional[float] = None,
  426. device_filter: Optional[List[str]] = None) -> bool:
  427. """
  428. 全量训练入口:每设备从零训练独立模型
  429. 流程:
  430. 1. 扫描外部数据目录
  431. 2. 每设备独立提取 Mel 特征+标准化参数
  432. 3. 每设备独立训练模型
  433. 4. 每设备独立部署(模型+标准化+阈值)
  434. Args:
  435. data_dir: 数据目录路径
  436. epochs: 训练轮数(None=使用配置文件值)
  437. lr: 学习率(None=使用配置文件值)
  438. device_filter: 只训练指定设备(None=全部)
  439. Returns:
  440. bool: 是否成功
  441. """
  442. try:
  443. epochs = epochs or self.config['auto_training']['incremental']['epochs']
  444. lr = lr or self.config['auto_training']['incremental']['learning_rate']
  445. logger.info("=" * 60)
  446. logger.info(f"全量训练 | 数据目录: {data_dir}")
  447. logger.info(f"参数: epochs={epochs}, lr={lr}")
  448. if device_filter:
  449. logger.info(f"设备过滤: {device_filter}")
  450. logger.info("=" * 60)
  451. # 1. 收集数据
  452. device_files = self.collect_from_external_dir(data_dir, device_filter)
  453. if not device_files:
  454. logger.error("无可用训练数据")
  455. return False
  456. # 2. 每设备提取特征
  457. device_results = self.prepare_mel_features_per_device(device_files)
  458. if not device_results:
  459. logger.error("特征提取失败")
  460. return False
  461. # 3. 每设备独立训练+部署
  462. success_count = 0
  463. fail_count = 0
  464. for device_code, (mel_dir, scale) in device_results.items():
  465. try:
  466. logger.info(f"\n--- 训练设备: {device_code} ---")
  467. # 训练
  468. model, final_loss = self.train_single_device(
  469. device_code, mel_dir, epochs, lr, from_scratch=True
  470. )
  471. # 验证
  472. if not self._validate_model(model):
  473. logger.error(f" {device_code}: 模型验证失败,跳过部署")
  474. fail_count += 1
  475. continue
  476. # 部署到 models/{device_code}/
  477. self.deploy_device_model(device_code, model, scale, mel_dir)
  478. success_count += 1
  479. logger.info(f" {device_code}: 训练+部署完成 | loss={final_loss:.6f}")
  480. except Exception as e:
  481. logger.error(f" {device_code}: 训练失败 | {e}", exc_info=True)
  482. fail_count += 1
  483. logger.info("=" * 60)
  484. logger.info(f"全量训练完成: 成功={success_count}, 失败={fail_count}")
  485. logger.info("=" * 60)
  486. return fail_count == 0
  487. except Exception as e:
  488. logger.error(f"全量训练异常: {e}", exc_info=True)
  489. return False
  490. finally:
  491. # 清理临时文件
  492. if self.temp_mel_dir.exists():
  493. shutil.rmtree(self.temp_mel_dir)
  494. # ========================================
  495. # 增量训练入口(保留兼容)
  496. # ========================================
  497. def run_daily_training(self) -> bool:
  498. """
  499. 执行每日增量训练(保留原有逻辑)
  500. 改造点:每设备独立训练+部署,不再共享模型
  501. Returns:
  502. bool: 是否成功
  503. """
  504. try:
  505. days_ago = (self.use_days_ago if self.use_days_ago is not None
  506. else self.config['auto_training']['incremental']['use_days_ago'])
  507. target_date = (datetime.now() - timedelta(days=days_ago)).strftime('%Y%m%d')
  508. mode_str = "冷启动训练" if self.cold_start_mode else "增量训练"
  509. logger.info("=" * 60)
  510. logger.info(f"{mode_str} - {target_date}")
  511. logger.info("=" * 60)
  512. # 1. 收集数据
  513. device_files = self.collect_training_data(target_date)
  514. total = sum(len(f) for f in device_files.values())
  515. min_samples = self.config['auto_training']['incremental']['min_samples']
  516. if total < min_samples:
  517. logger.warning(f"数据不足 ({total} < {min_samples}),跳过")
  518. return False
  519. # 2. 备份模型
  520. if self.config['auto_training']['model']['backup_before_train']:
  521. self.backup_model(target_date)
  522. # 3. 每设备提取特征
  523. device_results = self.prepare_mel_features_per_device(device_files)
  524. if not device_results:
  525. logger.error("特征提取失败")
  526. return False
  527. # 4. 训练参数
  528. epochs = (self.epochs if self.epochs is not None
  529. else self.config['auto_training']['incremental']['epochs'])
  530. lr = (self.learning_rate if self.learning_rate is not None
  531. else self.config['auto_training']['incremental']['learning_rate'])
  532. # 5. 每设备独立训练+部署
  533. # 冷启动=全量训练(从零),增量=加载已有模型微调
  534. from_scratch = self.cold_start_mode
  535. success_count = 0
  536. for device_code, (mel_dir, scale) in device_results.items():
  537. try:
  538. model, _ = self.train_single_device(
  539. device_code, mel_dir, epochs, lr, from_scratch=from_scratch
  540. )
  541. if self._validate_model(model):
  542. if self.config['auto_training']['model']['auto_deploy']:
  543. self.deploy_device_model(device_code, model, scale, mel_dir)
  544. success_count += 1
  545. else:
  546. logger.error(f"{device_code}: 验证失败,跳过部署")
  547. except Exception as e:
  548. logger.error(f"{device_code}: 训练失败 | {e}")
  549. # 6. 更新分类器基线
  550. self._update_classifier_baseline(device_files)
  551. logger.info("=" * 60)
  552. logger.info(f"增量训练完成: {success_count}/{len(device_results)} 个设备成功")
  553. logger.info("=" * 60)
  554. return success_count > 0
  555. except Exception as e:
  556. logger.error(f"训练失败: {e}", exc_info=True)
  557. return False
  558. finally:
  559. if self.temp_mel_dir.exists():
  560. shutil.rmtree(self.temp_mel_dir)
  561. # ========================================
  562. # 辅助方法
  563. # ========================================
  564. def _validate_model(self, model: nn.Module) -> bool:
  565. # 验证模型输出形状是否合理
  566. if not self.config['auto_training']['validation']['enabled']:
  567. return True
  568. try:
  569. device = next(model.parameters()).device
  570. test_input = torch.randn(1, 1, CFG.N_MELS, CFG.TARGET_FRAMES).to(device)
  571. with torch.no_grad():
  572. output = model(test_input)
  573. h_diff = abs(output.shape[2] - test_input.shape[2])
  574. w_diff = abs(output.shape[3] - test_input.shape[3])
  575. if h_diff > 8 or w_diff > 8:
  576. logger.error(f"形状差异过大: {output.shape} vs {test_input.shape}")
  577. return False
  578. return True
  579. except Exception as e:
  580. logger.error(f"验证失败: {e}")
  581. return False
  582. def backup_model(self, date_tag: str):
  583. """
  584. 完整备份当前所有设备的模型
  585. 备份目录结构:
  586. backups/{date_tag}/{device_code}/
  587. ├── ae_model.pth
  588. ├── global_scale.npy
  589. └── thresholds/
  590. """
  591. backup_date_dir = self.backup_dir / date_tag
  592. backup_date_dir.mkdir(parents=True, exist_ok=True)
  593. backed_up = 0
  594. # 遍历 models/ 下的所有设备子目录
  595. for device_dir in self.model_root.iterdir():
  596. if not device_dir.is_dir():
  597. continue
  598. # 跳过 backups 目录本身
  599. if device_dir.name == "backups":
  600. continue
  601. # 检查是否包含模型文件(判断是否为设备目录)
  602. if not (device_dir / "ae_model.pth").exists():
  603. continue
  604. device_backup = backup_date_dir / device_dir.name
  605. # 递归复制整个设备目录
  606. shutil.copytree(device_dir, device_backup, dirs_exist_ok=True)
  607. backed_up += 1
  608. logger.info(f"备份完成: {backed_up} 个设备 -> {backup_date_dir}")
  609. # 清理旧备份
  610. keep = self.config['auto_training']['model']['keep_backups']
  611. backup_dirs = sorted(
  612. [d for d in self.backup_dir.iterdir() if d.is_dir() and d.name.isdigit()],
  613. reverse=True
  614. )
  615. for old_dir in backup_dirs[keep:]:
  616. shutil.rmtree(old_dir)
  617. logger.info(f"已删除旧备份: {old_dir.name}")
  618. def restore_backup(self, date_tag: str) -> bool:
  619. """
  620. 从备份恢复所有设备的模型
  621. Args:
  622. date_tag: 备份日期标签 'YYYYMMDD'
  623. Returns:
  624. bool: 是否恢复成功
  625. """
  626. backup_date_dir = self.backup_dir / date_tag
  627. if not backup_date_dir.exists():
  628. logger.error(f"备份目录不存在: {backup_date_dir}")
  629. return False
  630. logger.info(f"从备份恢复: {date_tag}")
  631. restored = 0
  632. for device_backup in backup_date_dir.iterdir():
  633. if not device_backup.is_dir():
  634. continue
  635. target_dir = self.model_root / device_backup.name
  636. # 递归复制恢复
  637. shutil.copytree(device_backup, target_dir, dirs_exist_ok=True)
  638. restored += 1
  639. logger.info(f"恢复完成: {restored} 个设备")
  640. return restored > 0
  641. def _update_classifier_baseline(self, device_files: Dict[str, List[Path]]):
  642. # 从训练数据计算并更新分类器基线
  643. logger.info("更新分类器基线")
  644. try:
  645. import librosa
  646. from core.anomaly_classifier import AnomalyClassifier
  647. classifier = AnomalyClassifier()
  648. all_files = []
  649. for files in device_files.values():
  650. all_files.extend(files)
  651. if not all_files:
  652. logger.warning("无音频文件,跳过基线更新")
  653. return
  654. sample_files = random.sample(all_files, min(50, len(all_files)))
  655. all_features = []
  656. for wav_file in sample_files:
  657. try:
  658. y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
  659. if len(y) < CFG.SR:
  660. continue
  661. features = classifier.extract_features(y, sr=CFG.SR)
  662. if features:
  663. all_features.append(features)
  664. except Exception:
  665. continue
  666. if not all_features:
  667. logger.warning("无法提取特征,跳过基线更新")
  668. return
  669. baseline = {}
  670. keys = all_features[0].keys()
  671. for key in keys:
  672. if key == 'has_periodic':
  673. values = [f[key] for f in all_features]
  674. baseline[key] = sum(values) > len(values) / 2
  675. else:
  676. values = [f[key] for f in all_features]
  677. baseline[key] = float(np.mean(values))
  678. classifier.save_baseline(baseline)
  679. logger.info(f" 基线已更新 (样本数: {len(all_features)})")
  680. except Exception as e:
  681. logger.warning(f"更新基线失败: {e}")
  682. def main():
  683. # 命令行入口(增量训练)
  684. logging.basicConfig(
  685. level=logging.INFO,
  686. format='%(asctime)s | %(levelname)-8s | %(message)s',
  687. datefmt='%Y-%m-%d %H:%M:%S'
  688. )
  689. config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml"
  690. trainer = IncrementalTrainer(config_file)
  691. success = trainer.run_daily_training()
  692. sys.exit(0 if success else 1)
  693. if __name__ == "__main__":
  694. main()