|
|
@@ -192,15 +192,9 @@ class IncrementalTrainer:
|
|
|
device_code = device_dir.name
|
|
|
audio_files = []
|
|
|
|
|
|
- # 冷启动模式:收集所有目录的数据
|
|
|
+ # 冷启动模式:收集所有已归档日期目录的数据(跳过 current/)
|
|
|
if self.cold_start_mode:
|
|
|
- # 收集current目录
|
|
|
- current_dir = device_dir / "current"
|
|
|
- if current_dir.exists():
|
|
|
- audio_files.extend(list(current_dir.glob("*.wav")))
|
|
|
- audio_files.extend(list(current_dir.glob("*.mp4")))
|
|
|
-
|
|
|
- # 收集所有日期目录
|
|
|
+ # 注意:跳过 current/ 目录,因其中可能包含 FFmpeg 正在写入的不完整文件
|
|
|
for sub_dir in device_dir.iterdir():
|
|
|
if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
|
|
|
audio_files.extend(list(sub_dir.glob("*.wav")))
|
|
|
@@ -221,6 +215,14 @@ class IncrementalTrainer:
|
|
|
# 去重
|
|
|
audio_files = list(set(audio_files))
|
|
|
|
|
|
+ # 数据质量预筛:过滤能量/频谱异常的音频
|
|
|
+ if audio_files and not self.cold_start_mode:
|
|
|
+ before_count = len(audio_files)
|
|
|
+ audio_files = self._filter_audio_quality(audio_files, device_code)
|
|
|
+ filtered = before_count - len(audio_files)
|
|
|
+ if filtered > 0:
|
|
|
+ logger.info(f" {device_code}: 质量预筛过滤 {filtered} 个异常音频")
|
|
|
+
|
|
|
# 随机采样(如果配置了采样时长)
|
|
|
if sample_hours > 0 and audio_files:
|
|
|
files_needed = int(sample_hours * 3600 / 60)
|
|
|
@@ -239,25 +241,84 @@ class IncrementalTrainer:
|
|
|
logger.info(f"总计: {total} 个音频文件,{len(device_files)} 个设备")
|
|
|
return device_files
|
|
|
|
|
|
+ def _filter_audio_quality(self, audio_files: List[Path],
|
|
|
+ device_code: str) -> List[Path]:
|
|
|
+ """
|
|
|
+ 音频质量预筛:基于 RMS 能量和频谱质心过滤明显异常的样本
|
|
|
+
|
|
|
+ 使用 IQR (四分位距) 方法检测离群值:
|
|
|
+ - 计算所有文件的 RMS 能量和频谱质心
|
|
|
+ - 过滤超出 [Q1 - 2*IQR, Q3 + 2*IQR] 范围的文件
|
|
|
+
|
|
|
+ 需要至少 10 个文件才执行过滤,否则样本太少无统计意义。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ audio_files: 待过滤的音频文件列表
|
|
|
+ device_code: 设备编码(用于日志)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 过滤后的文件列表
|
|
|
+ """
|
|
|
+ if len(audio_files) < 10:
|
|
|
+ return audio_files
|
|
|
+
|
|
|
+ import librosa
|
|
|
+
|
|
|
+ # 快速计算每个文件的 RMS 能量
|
|
|
+ rms_values = []
|
|
|
+ valid_files = []
|
|
|
+
|
|
|
+ for wav_file in audio_files:
|
|
|
+ try:
|
|
|
+ y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True,
|
|
|
+ duration=10) # 只读前10秒加速
|
|
|
+ if len(y) < CFG.SR:
|
|
|
+ continue
|
|
|
+ rms = float(np.sqrt(np.mean(y ** 2)))
|
|
|
+ rms_values.append(rms)
|
|
|
+ valid_files.append(wav_file)
|
|
|
+ except Exception:
|
|
|
+ continue
|
|
|
+
|
|
|
+ if len(rms_values) < 10:
|
|
|
+ return audio_files
|
|
|
+
|
|
|
+ # IQR 离群值检测
|
|
|
+ rms_arr = np.array(rms_values)
|
|
|
+ q1, q3 = np.percentile(rms_arr, [25, 75])
|
|
|
+ iqr = q3 - q1
|
|
|
+ lower_bound = q1 - 2 * iqr
|
|
|
+ upper_bound = q3 + 2 * iqr
|
|
|
+
|
|
|
+ filtered = []
|
|
|
+ for f, rms in zip(valid_files, rms_values):
|
|
|
+ if lower_bound <= rms <= upper_bound:
|
|
|
+ filtered.append(f)
|
|
|
+
|
|
|
+ return filtered
|
|
|
+
|
|
|
# ========================================
|
|
|
# 特征提取(每设备独立标准化参数)
|
|
|
# ========================================
|
|
|
|
|
|
def _extract_mel_for_device(self, device_code: str,
|
|
|
- wav_files: List[Path]) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
|
|
|
+ wav_files: List[Path],
|
|
|
+ inherit_scale: bool = False
|
|
|
+ ) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
|
|
|
"""
|
|
|
- 为单个设备提取 Mel 特征并计算独立的 Z-score 标准化参数
|
|
|
+ 为单个设备提取 Mel 特征并计算独立的 Min-Max 标准化参数
|
|
|
|
|
|
- 两遍扫描:
|
|
|
- 1. 第一遍:收集所有 mel_db 计算 mean/std
|
|
|
- 2. 第二遍:Z-score 标准化后保存 npy 文件
|
|
|
+ 流式两遍扫描(内存优化):
|
|
|
+ 1. 第一遍:只计算 running min/max(O(1) 内存),不保存 mel_db
|
|
|
+ 2. 第二遍:用第一遍的 min/max 标准化后直接写 npy 文件
|
|
|
|
|
|
Args:
|
|
|
device_code: 设备编码
|
|
|
wav_files: 该设备的音频文件列表
|
|
|
+ inherit_scale: 增量训练时是否继承已部署的 scale 参数
|
|
|
|
|
|
Returns:
|
|
|
- (mel_dir, (global_mean, global_std)),失败返回 (None, None)
|
|
|
+ (mel_dir, (global_min, global_max)),失败返回 (None, None)
|
|
|
"""
|
|
|
import librosa
|
|
|
|
|
|
@@ -265,67 +326,90 @@ class IncrementalTrainer:
|
|
|
win_samples = int(CFG.WIN_SEC * CFG.SR)
|
|
|
hop_samples = int(CFG.HOP_SEC * CFG.SR)
|
|
|
|
|
|
- # 第一遍:收集所有 mel_db 值,用于计算 mean/std
|
|
|
- all_mel_data = []
|
|
|
- all_values = [] # 收集所有像素值用于全局统计
|
|
|
-
|
|
|
- for wav_file in wav_files:
|
|
|
- try:
|
|
|
- y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
|
|
|
-
|
|
|
- # 跳过过短的音频
|
|
|
- if len(y) < CFG.SR:
|
|
|
+ def _iter_mel_patches(files):
|
|
|
+ """生成器:逐文件逐 patch 产出 mel_db,避免全量加载到内存"""
|
|
|
+ for wav_file in files:
|
|
|
+ try:
|
|
|
+ y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
|
|
|
+ if len(y) < CFG.SR:
|
|
|
+ continue
|
|
|
+ for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)):
|
|
|
+ segment = y[start:start + win_samples]
|
|
|
+ mel_spec = librosa.feature.melspectrogram(
|
|
|
+ y=segment, sr=CFG.SR, n_fft=CFG.N_FFT,
|
|
|
+ hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0
|
|
|
+ )
|
|
|
+ mel_db = librosa.power_to_db(mel_spec, ref=np.max)
|
|
|
+ # 对齐帧数
|
|
|
+ if mel_db.shape[1] < CFG.TARGET_FRAMES:
|
|
|
+ pad = CFG.TARGET_FRAMES - mel_db.shape[1]
|
|
|
+ mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant")
|
|
|
+ else:
|
|
|
+ mel_db = mel_db[:, :CFG.TARGET_FRAMES]
|
|
|
+ yield wav_file, idx, mel_db
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"跳过文件 {wav_file.name}: {e}")
|
|
|
continue
|
|
|
|
|
|
- for idx, start in enumerate(range(0, len(y) - win_samples + 1, hop_samples)):
|
|
|
- segment = y[start:start + win_samples]
|
|
|
-
|
|
|
- mel_spec = librosa.feature.melspectrogram(
|
|
|
- y=segment, sr=CFG.SR, n_fft=CFG.N_FFT,
|
|
|
- hop_length=CFG.HOP_LENGTH, n_mels=CFG.N_MELS, power=2.0
|
|
|
- )
|
|
|
- mel_db = librosa.power_to_db(mel_spec, ref=np.max)
|
|
|
-
|
|
|
- # 对齐帧数
|
|
|
- if mel_db.shape[1] < CFG.TARGET_FRAMES:
|
|
|
- pad = CFG.TARGET_FRAMES - mel_db.shape[1]
|
|
|
- mel_db = np.pad(mel_db, ((0, 0), (0, pad)), mode="constant")
|
|
|
- else:
|
|
|
- mel_db = mel_db[:, :CFG.TARGET_FRAMES]
|
|
|
-
|
|
|
- # 收集所有值用于 min/max 计算
|
|
|
- all_values.append(mel_db.flatten())
|
|
|
- all_mel_data.append((wav_file, idx, mel_db))
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"跳过文件 {wav_file.name}: {e}")
|
|
|
- continue
|
|
|
-
|
|
|
- if not all_mel_data:
|
|
|
+ # ── 第一遍:流式计算 running min/max(O(1) 内存) ──
|
|
|
+ global_min = float('inf')
|
|
|
+ global_max = float('-inf')
|
|
|
+ patch_count = 0
|
|
|
+
|
|
|
+ for _, _, mel_db in _iter_mel_patches(wav_files):
|
|
|
+ local_min = float(mel_db.min())
|
|
|
+ local_max = float(mel_db.max())
|
|
|
+ if local_min < global_min:
|
|
|
+ global_min = local_min
|
|
|
+ if local_max > global_max:
|
|
|
+ global_max = local_max
|
|
|
+ patch_count += 1
|
|
|
+
|
|
|
+ if patch_count == 0:
|
|
|
logger.warning(f" {device_code}: 无有效数据")
|
|
|
return None, None
|
|
|
|
|
|
- # 计算全局 min/max(Min-Max 标准化参数)
|
|
|
- all_values_concat = np.concatenate(all_values)
|
|
|
- global_min = float(np.min(all_values_concat))
|
|
|
- global_max = float(np.max(all_values_concat))
|
|
|
-
|
|
|
- logger.info(f" {device_code}: {len(all_mel_data)} patches | "
|
|
|
+ # 增量训练时,用已部署的 scale 做 EMA 平滑,避免剧烈偏移
|
|
|
+ if inherit_scale:
|
|
|
+ old_scale = self._load_deployed_scale(device_code)
|
|
|
+ if old_scale is not None:
|
|
|
+ ema_alpha = 0.3 # 新数据权重
|
|
|
+ old_min, old_max = old_scale
|
|
|
+ global_min = ema_alpha * global_min + (1 - ema_alpha) * old_min
|
|
|
+ global_max = ema_alpha * global_max + (1 - ema_alpha) * old_max
|
|
|
+ logger.info(f" {device_code}: scale EMA 融合 | "
|
|
|
+ f"old=[{old_min:.4f}, {old_max:.4f}] → "
|
|
|
+ f"new=[{global_min:.4f}, {global_max:.4f}]")
|
|
|
+
|
|
|
+ logger.info(f" {device_code}: {patch_count} patches | "
|
|
|
f"min={global_min:.4f} max={global_max:.4f}")
|
|
|
|
|
|
- # 第二遍:Min-Max 标准化并保存
|
|
|
+ # ── 第二遍:Min-Max 标准化并保存 ──
|
|
|
device_mel_dir = self.temp_mel_dir / device_code
|
|
|
device_mel_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
- for wav_file, idx, mel_db in all_mel_data:
|
|
|
- # Min-Max: (x - min) / (max - min)
|
|
|
- mel_norm = (mel_db - global_min) / (global_max - global_min + 1e-6)
|
|
|
+ scale_range = global_max - global_min + 1e-6
|
|
|
+ for wav_file, idx, mel_db in _iter_mel_patches(wav_files):
|
|
|
+ mel_norm = (mel_db - global_min) / scale_range
|
|
|
npy_name = f"{device_code}@@{wav_file.stem}@@win{idx:05d}.npy"
|
|
|
np.save(device_mel_dir / npy_name, mel_norm.astype(np.float32))
|
|
|
|
|
|
return device_mel_dir, (global_min, global_max)
|
|
|
|
|
|
- def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]]
|
|
|
+ def _load_deployed_scale(self, device_code: str) -> Optional[Tuple[float, float]]:
|
|
|
+ """加载已部署的 global_scale.npy,用于增量训练时的 scale 继承"""
|
|
|
+ scale_path = self.model_root / device_code / "global_scale.npy"
|
|
|
+ if not scale_path.exists():
|
|
|
+ return None
|
|
|
+ try:
|
|
|
+ scale = np.load(scale_path)
|
|
|
+ return float(scale[0]), float(scale[1])
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"加载旧 scale 失败: {device_code} | {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]],
|
|
|
+ inherit_scale: bool = False
|
|
|
) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
|
|
|
"""
|
|
|
为每个设备独立提取 Mel 特征
|
|
|
@@ -334,6 +418,7 @@ class IncrementalTrainer:
|
|
|
|
|
|
Args:
|
|
|
device_files: {device_code: [wav_files]}
|
|
|
+ inherit_scale: 增量训练时传 True,将新旧 scale 做 EMA 融合
|
|
|
|
|
|
Returns:
|
|
|
{device_code: (mel_dir, (global_min, global_max))}
|
|
|
@@ -348,7 +433,9 @@ class IncrementalTrainer:
|
|
|
device_results = {}
|
|
|
|
|
|
for device_code, wav_files in device_files.items():
|
|
|
- mel_dir, scale = self._extract_mel_for_device(device_code, wav_files)
|
|
|
+ mel_dir, scale = self._extract_mel_for_device(
|
|
|
+ device_code, wav_files, inherit_scale=inherit_scale
|
|
|
+ )
|
|
|
if mel_dir is not None:
|
|
|
device_results[device_code] = (mel_dir, scale)
|
|
|
|
|
|
@@ -363,75 +450,279 @@ class IncrementalTrainer:
|
|
|
# 模型训练(每设备独立)
|
|
|
# ========================================
|
|
|
|
|
|
+ def _select_training_device(self) -> torch.device:
|
|
|
+ # 智能选择训练设备:GPU/NPU 显存充足则使用,否则回退 CPU
|
|
|
+ # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda/npu)
|
|
|
+ training_cfg = self.config['auto_training']['incremental']
|
|
|
+ forced_device = training_cfg.get('training_device', 'auto')
|
|
|
+
|
|
|
+ # 强制指定设备时直接返回
|
|
|
+ if forced_device == 'cpu':
|
|
|
+ logger.info("训练设备: CPU(配置强制指定)")
|
|
|
+ return torch.device('cpu')
|
|
|
+ if forced_device == 'cuda':
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ return torch.device('cuda')
|
|
|
+ logger.warning("配置指定 CUDA 但不可用,回退 CPU")
|
|
|
+ return torch.device('cpu')
|
|
|
+ if forced_device == 'npu':
|
|
|
+ if self._npu_available():
|
|
|
+ return torch.device('npu')
|
|
|
+ logger.warning("配置指定 NPU 但不可用,回退 CPU")
|
|
|
+ return torch.device('cpu')
|
|
|
+
|
|
|
+ # auto 模式:依次检测 CUDA → NPU → CPU
|
|
|
+ # 1. 检测 CUDA
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ try:
|
|
|
+ free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024)
|
|
|
+ min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
|
|
|
+ if free_mem >= min_gpu_mem_mb:
|
|
|
+ logger.info(f"训练设备: CUDA(空闲显存 {free_mem:.0f}MB)")
|
|
|
+ return torch.device('cuda')
|
|
|
+ logger.info(
|
|
|
+ f"CUDA 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"CUDA 显存检测失败: {e}")
|
|
|
+
|
|
|
+ # 2. 检测 NPU (华为昇腾)
|
|
|
+ if self._npu_available():
|
|
|
+ try:
|
|
|
+ free_mem = torch.npu.mem_get_info()[0] / (1024 * 1024)
|
|
|
+ min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
|
|
|
+ if free_mem >= min_gpu_mem_mb:
|
|
|
+ logger.info(f"训练设备: NPU(空闲显存 {free_mem:.0f}MB)")
|
|
|
+ return torch.device('npu')
|
|
|
+ logger.info(
|
|
|
+ f"NPU 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"NPU 显存检测失败: {e},回退 CPU")
|
|
|
+
|
|
|
+ logger.info("训练设备: CPU")
|
|
|
+ return torch.device('cpu')
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _npu_available() -> bool:
|
|
|
+ """检查华为昇腾 NPU 是否可用"""
|
|
|
+ try:
|
|
|
+ import torch_npu # noqa: F401
|
|
|
+ return torch.npu.is_available()
|
|
|
+ except ImportError:
|
|
|
+ return False
|
|
|
+
|
|
|
    def _run_training_loop(self, device_code: str, model: nn.Module,
                           train_loader, val_loader, epochs: int, lr: float,
                           device: torch.device) -> Tuple[nn.Module, float]:
        """Run the actual training loop, decoupled from device selection.

        Trains the autoencoder with Adam + MSE reconstruction loss.  Mixed
        precision (AMP) is enabled on CUDA/NPU.  Early stopping monitors the
        validation loss when ``val_loader`` is given, otherwise the training
        loss; it can only trigger after at least 10 epochs.

        Args:
            device_code: device code, used only for log context
            model: autoencoder to train (moved onto ``device`` here)
            train_loader: training DataLoader (assumed non-empty)
            val_loader: validation DataLoader, or None to skip validation
            epochs: maximum number of epochs
            lr: Adam learning rate
            device: torch device to train on

        Returns:
            (trained model, average training loss of the last completed epoch)
        """
        model = model.to(device)
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        # AMP mixed precision (effective on GPU/NPU; roughly 40% less memory)
        use_amp = device.type in ('cuda', 'npu')
        scaler = torch.amp.GradScaler(device.type) if use_amp else None

        # Early-stopping configuration
        early_stop_cfg = self.config['auto_training']['incremental']
        patience = early_stop_cfg.get('early_stop_patience', 5)
        best_loss = float('inf')
        no_improve_count = 0

        avg_loss = 0.0
        actual_epochs = 0

        for epoch in range(epochs):
            # ── training phase ──
            model.train()
            epoch_loss = 0.0
            batch_count = 0

            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()

                if use_amp:
                    with torch.amp.autocast(device.type):
                        output = model(batch)
                        output = align_to_target(output, batch)
                        loss = criterion(output, batch)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    output = model(batch)
                    output = align_to_target(output, batch)
                    loss = criterion(output, batch)
                    loss.backward()
                    optimizer.step()

                epoch_loss += loss.item()
                batch_count += 1

            avg_loss = epoch_loss / batch_count
            actual_epochs = epoch + 1

            # ── validation phase (when a validation set exists) ──
            if val_loader is not None:
                model.eval()
                val_loss = 0.0
                val_count = 0
                with torch.no_grad():
                    for batch in val_loader:
                        batch = batch.to(device)
                        if use_amp:
                            with torch.amp.autocast(device.type):
                                output = model(batch)
                                output = align_to_target(output, batch)
                                loss = criterion(output, batch)
                        else:
                            output = model(batch)
                            output = align_to_target(output, batch)
                            loss = criterion(output, batch)
                        val_loss += loss.item()
                        val_count += 1
                avg_val_loss = val_loss / val_count
                monitor_loss = avg_val_loss  # early stopping watches val loss
            else:
                avg_val_loss = None
                monitor_loss = avg_loss  # fall back to train loss without a val set

            if actual_epochs % 10 == 0 or epoch == epochs - 1:
                val_str = f" | ValLoss: {avg_val_loss:.6f}" if avg_val_loss is not None else ""
                logger.info(f" [{device_code}] Epoch {actual_epochs}/{epochs} | "
                            f"Loss: {avg_loss:.6f}{val_str} | device={device.type}")

            # Early stopping: abort after `patience` epochs without improvement
            if monitor_loss < best_loss:
                best_loss = monitor_loss
                no_improve_count = 0
            else:
                no_improve_count += 1

            if no_improve_count >= patience and actual_epochs >= 10:
                logger.info(f" [{device_code}] 早停触发: 连续{patience}轮无改善 | "
                            f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}")
                break

        # Release accelerator cache after training
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        elif device.type == 'npu':
            torch.npu.empty_cache()

        if actual_epochs < epochs:
            logger.info(f" [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练")

        return model, avg_loss
|
|
|
+
|
|
|
    def train_single_device(self, device_code: str, mel_dir: Path,
                            epochs: int, lr: float,
                            from_scratch: bool = True
                            ) -> Tuple[nn.Module, float]:
        """Train the independent model for a single device.

        Strategy: prefer GPU/NPU training; fall back to CPU when accelerator
        memory is insufficient, and also catch mid-training OOM errors and
        retry the whole run on CPU with a freshly (re)loaded model.

        Args:
            device_code: device code
            mel_dir: directory of this device's Mel feature .npy files
            epochs: maximum training epochs
            lr: learning rate
            from_scratch: True = full training from scratch (cold start),
                          False = fine-tune the already deployed model

        Returns:
            (model, final average training loss)

        Raises:
            ValueError: when the dataset for this device is empty.
        """
        logger.info(f"训练设备 {device_code}: epochs={epochs}, lr={lr}, "
                    f"mode={'全量' if from_scratch else '增量'}")

        # Select the training device (GPU/NPU/CPU)
        device = self._select_training_device()

        # Clear accelerator cache before training to free fragments
        # left behind by inference.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            import gc
            gc.collect()
        elif device.type == 'npu':
            torch.npu.empty_cache()
            import gc
            gc.collect()

        model = ConvAutoencoder()

        # Incremental mode: load the already deployed model as the start point
        if not from_scratch:
            model_path = self.model_root / device_code / "ae_model.pth"
            if model_path.exists():
                model.load_state_dict(torch.load(model_path, map_location='cpu'))
                logger.info(f" 已加载已有模型: {model_path}")
            else:
                logger.warning(f" 模型不存在,自动切换为全量训练: {model_path}")

        # Load the data and split it 80/20 into train/validation sets
        dataset = MelNPYDataset(mel_dir)
        if len(dataset) == 0:
            raise ValueError(f"设备 {device_code} 无训练数据")

        batch_size = self.config['auto_training']['incremental']['batch_size']
        # Validation split only when the dataset has >= 20 samples
        # (fewer samples give no statistically meaningful validation).
        val_loader = None
        if len(dataset) >= 20:
            val_size = max(1, int(len(dataset) * 0.2))
            train_size = len(dataset) - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True,
                num_workers=0, pin_memory=False
            )
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=batch_size, shuffle=False,
                num_workers=0, pin_memory=False
            )
            logger.info(f" 数据集划分: 训练={train_size}, 验证={val_size}")
        else:
            train_loader = torch.utils.data.DataLoader(
                dataset, batch_size=batch_size, shuffle=True,
                num_workers=0, pin_memory=False
            )
            logger.info(f" 数据量不足20,跳过验证集划分(共{len(dataset)}样本)")

        # Try to train on the selected device
        if device.type in ('cuda', 'npu'):
            try:
                return self._run_training_loop(
                    device_code, model, train_loader, val_loader,
                    epochs, lr, device
                )
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                # GPU/NPU OOM -> clear accelerator memory and retry on CPU
                if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError):
                    raise  # do not intercept non-OOM RuntimeErrors
                logger.warning(
                    f" [{device_code}] {device.type.upper()} OOM,"
                    f"清理显存后回退 CPU 训练"
                )
                import gc
                gc.collect()
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                elif device.type == 'npu':
                    torch.npu.empty_cache()
                # The model may be in a dirty state — reinitialize it
                model = ConvAutoencoder()
                if not from_scratch:
                    model_path = self.model_root / device_code / "ae_model.pth"
                    if model_path.exists():
                        model.load_state_dict(
                            torch.load(model_path, map_location='cpu')
                        )
                return self._run_training_loop(
                    device_code, model, train_loader, val_loader,
                    epochs, lr, torch.device('cpu')
                )
        else:
            # CPU training needs no OOM protection
            return self._run_training_loop(
                device_code, model, train_loader, val_loader,
                epochs, lr, device
            )
|
|
|
|
|
|
- avg_loss = epoch_loss / batch_count
|
|
|
- # 每10轮或最后一轮打印日志,避免日志刷屏
|
|
|
- if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
|
|
|
- logger.info(f" [{device_code}] Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.6f}")
|
|
|
-
|
|
|
- return model, avg_loss
|
|
|
|
|
|
# ========================================
|
|
|
# 产出部署(每设备独立目录)
|
|
|
@@ -616,14 +907,23 @@ class IncrementalTrainer:
|
|
|
# 增量训练入口(保留兼容)
|
|
|
# ========================================
|
|
|
|
|
|
- def run_daily_training(self) -> bool:
|
|
|
+ def run_daily_training(self, on_device_trained=None) -> bool:
|
|
|
"""
|
|
|
- 执行每日增量训练(保留原有逻辑)
|
|
|
+ 执行每日增量训练 — 逐设备串行处理
|
|
|
+
|
|
|
+ 流程(每个设备完整走完再处理下一个,降低内存/CPU 峰值):
|
|
|
+ 1. 收集所有设备的文件列表(仅路径,开销极低)
|
|
|
+ 2. 备份模型
|
|
|
+ 3. 逐设备:提取特征 → 训练 → 验证 → 部署 → 清理临时文件 → 回调通知
|
|
|
+ 4. 更新分类器基线
|
|
|
|
|
|
- 改造点:每设备独立训练+部署,不再共享模型
|
|
|
+ Args:
|
|
|
+ on_device_trained: 可选回调 fn(device_code: str),
|
|
|
+ 每个设备训练+部署成功后调用,
|
|
|
+ 用于即时触发该设备的模型热重载
|
|
|
|
|
|
Returns:
|
|
|
- bool: 是否成功
|
|
|
+ bool: 是否至少有一个设备成功
|
|
|
"""
|
|
|
try:
|
|
|
days_ago = (self.use_days_ago if self.use_days_ago is not None
|
|
|
@@ -635,7 +935,7 @@ class IncrementalTrainer:
|
|
|
logger.info(f"{mode_str} - {target_date}")
|
|
|
logger.info("=" * 60)
|
|
|
|
|
|
- # 1. 收集数据
|
|
|
+ # 1. 收集数据(仅文件路径,不加载音频,开销极低)
|
|
|
device_files = self.collect_training_data(target_date)
|
|
|
total = sum(len(f) for f in device_files.values())
|
|
|
min_samples = self.config['auto_training']['incremental']['min_samples']
|
|
|
@@ -647,43 +947,130 @@ class IncrementalTrainer:
|
|
|
if self.config['auto_training']['model']['backup_before_train']:
|
|
|
self.backup_model(target_date)
|
|
|
|
|
|
- # 3. 每设备提取特征
|
|
|
- device_results = self.prepare_mel_features_per_device(device_files)
|
|
|
- if not device_results:
|
|
|
- logger.error("特征提取失败")
|
|
|
- return False
|
|
|
-
|
|
|
- # 4. 训练参数
|
|
|
+ # 3. 训练参数
|
|
|
epochs = (self.epochs if self.epochs is not None
|
|
|
else self.config['auto_training']['incremental']['epochs'])
|
|
|
lr = (self.learning_rate if self.learning_rate is not None
|
|
|
else self.config['auto_training']['incremental']['learning_rate'])
|
|
|
|
|
|
- # 5. 每设备独立训练+部署
|
|
|
- # 冷启动=全量训练(从零),增量=加载已有模型微调
|
|
|
from_scratch = self.cold_start_mode
|
|
|
+ inherit_scale = not self.cold_start_mode
|
|
|
+
|
|
|
+ model_cfg = self.config['auto_training']['model']
|
|
|
+ rollback_enabled = model_cfg.get('rollback_on_degradation', True)
|
|
|
+ rollback_factor = model_cfg.get('rollback_factor', 2.0)
|
|
|
|
|
|
+ # 4. 逐设备串行处理:提取 → 训练 → 部署 → 清理
|
|
|
success_count = 0
|
|
|
- for device_code, (mel_dir, scale) in device_results.items():
|
|
|
+ degraded_count = 0
|
|
|
+ device_count = len(device_files)
|
|
|
+
|
|
|
+ for idx, (device_code, wav_files) in enumerate(device_files.items(), 1):
|
|
|
+ logger.info(f"\n{'='*40}")
|
|
|
+ logger.info(f"[{idx}/{device_count}] 设备: {device_code}")
|
|
|
+ logger.info(f"{'='*40}")
|
|
|
+
|
|
|
try:
|
|
|
- model, _ = self.train_single_device(
|
|
|
+ # ── 4a. 单设备特征提取 ──
|
|
|
+ mel_dir, scale = self._extract_mel_for_device(
|
|
|
+ device_code, wav_files, inherit_scale=inherit_scale
|
|
|
+ )
|
|
|
+ if mel_dir is None:
|
|
|
+ logger.warning(f"{device_code}: 特征提取无有效数据,跳过")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # ── 4b. 训练 ──
|
|
|
+ model, final_loss = self.train_single_device(
|
|
|
device_code, mel_dir, epochs, lr, from_scratch=from_scratch
|
|
|
)
|
|
|
|
|
|
- if self._validate_model(model):
|
|
|
- if self.config['auto_training']['model']['auto_deploy']:
|
|
|
- self.deploy_device_model(device_code, model, scale, mel_dir)
|
|
|
- success_count += 1
|
|
|
- else:
|
|
|
- logger.error(f"{device_code}: 验证失败,跳过部署")
|
|
|
+ # ── 4c. 形状验证 ──
|
|
|
+ if not self._validate_model(model):
|
|
|
+ logger.error(f"{device_code}: 形状验证失败,跳过部署")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # ── 4d. 损失退化检测(增量训练时生效) ──
|
|
|
+ if rollback_enabled and not self.cold_start_mode:
|
|
|
+ old_threshold = self._get_old_threshold(device_code)
|
|
|
+ if old_threshold and old_threshold > 0:
|
|
|
+ if final_loss > old_threshold * rollback_factor:
|
|
|
+ logger.warning(
|
|
|
+ f"{device_code}: 损失退化检测触发 | "
|
|
|
+ f"训练损失={final_loss:.6f} > "
|
|
|
+ f"旧阈值={old_threshold:.6f} × {rollback_factor} = "
|
|
|
+ f"{old_threshold * rollback_factor:.6f}"
|
|
|
+ )
|
|
|
+ degraded_count += 1
|
|
|
+ continue
|
|
|
+
|
|
|
+ # ── 4e. 阈值偏移检测 + 部署 ──
|
|
|
+ if model_cfg.get('auto_deploy', True):
|
|
|
+ if rollback_enabled and not self.cold_start_mode:
|
|
|
+ old_threshold = self._get_old_threshold(device_code)
|
|
|
+ if old_threshold and old_threshold > 0:
|
|
|
+ new_threshold = self._compute_threshold(model, mel_dir)
|
|
|
+ drift_ratio = abs(new_threshold - old_threshold) / old_threshold
|
|
|
+ # 记录阈值变化趋势(用于长期漂移监控)
|
|
|
+ self._log_threshold_history(
|
|
|
+ device_code, target_date,
|
|
|
+ old_threshold, new_threshold, final_loss
|
|
|
+ )
|
|
|
+ if drift_ratio > 0.3:
|
|
|
+ logger.warning(
|
|
|
+ f"{device_code}: 阈值偏移告警 | "
|
|
|
+ f"旧={old_threshold:.6f} → "
|
|
|
+ f"新={new_threshold:.6f} | "
|
|
|
+ f"偏移={drift_ratio:.1%}"
|
|
|
+ )
|
|
|
+ if drift_ratio > 1.0:
|
|
|
+ logger.warning(
|
|
|
+ f"{device_code}: 阈值偏移过大"
|
|
|
+ f"(>{drift_ratio:.0%}),跳过部署"
|
|
|
+ )
|
|
|
+ degraded_count += 1
|
|
|
+ continue
|
|
|
+ self.deploy_device_model(device_code, model, scale, mel_dir)
|
|
|
+
|
|
|
+ success_count += 1
|
|
|
+ logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}")
|
|
|
+
|
|
|
+ # ── 4f. 即时通知该设备模型重载 ──
|
|
|
+ if on_device_trained:
|
|
|
+ try:
|
|
|
+ on_device_trained(device_code)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"{device_code}: 模型重载回调失败 | {e}")
|
|
|
+
|
|
|
except Exception as e:
|
|
|
- logger.error(f"{device_code}: 训练失败 | {e}")
|
|
|
+ logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True)
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # ── 4g. 清理该设备的临时 Mel 文件,释放磁盘空间 ──
|
|
|
+ device_mel_dir = self.temp_mel_dir / device_code
|
|
|
+ if device_mel_dir.exists():
|
|
|
+ shutil.rmtree(device_mel_dir)
|
|
|
+
|
|
|
+ # 5. 如果所有设备都退化,整体回滚到训练前备份
|
|
|
+ if degraded_count > 0 and success_count == 0:
|
|
|
+ logger.error(
|
|
|
+ f"所有设备训练后损失退化({degraded_count}个),执行整体回滚"
|
|
|
+ )
|
|
|
+ self.restore_backup(target_date)
|
|
|
+ return False
|
|
|
+
|
|
|
+ if degraded_count > 0:
|
|
|
+ logger.warning(
|
|
|
+ f"{degraded_count} 个设备因损失退化跳过部署,"
|
|
|
+ f"{success_count} 个设备部署成功"
|
|
|
+ )
|
|
|
|
|
|
# 6. 更新分类器基线
|
|
|
self._update_classifier_baseline(device_files)
|
|
|
|
|
|
logger.info("=" * 60)
|
|
|
- logger.info(f"增量训练完成: {success_count}/{len(device_results)} 个设备成功")
|
|
|
+ logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功")
|
|
|
+ if degraded_count > 0:
|
|
|
+ logger.info(f" 其中 {degraded_count} 个设备因损失退化跳过")
|
|
|
logger.info("=" * 60)
|
|
|
return success_count > 0
|
|
|
|
|
|
@@ -699,6 +1086,60 @@ class IncrementalTrainer:
|
|
|
# 辅助方法
|
|
|
# ========================================
|
|
|
|
|
|
+ def _get_old_threshold(self, device_code: str) -> float:
|
|
|
+ """
|
|
|
+ 读取设备当前已部署的阈值(训练前的旧阈值)
|
|
|
+
|
|
|
+ 用于损失退化校验:新模型的训练损失不应远超旧阈值。
|
|
|
+ 阈值文件路径: models/{device_code}/thresholds/threshold_{device_code}.npy
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ 阈值浮点数,文件不存在时返回 0.0
|
|
|
+ """
|
|
|
+ threshold_file = self.model_root / device_code / "thresholds" / f"threshold_{device_code}.npy"
|
|
|
+ if not threshold_file.exists():
|
|
|
+ return 0.0
|
|
|
+ try:
|
|
|
+ data = np.load(threshold_file)
|
|
|
+ return float(data.flat[0])
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"读取旧阈值失败: {device_code} | {e}")
|
|
|
+ return 0.0
|
|
|
+
|
|
|
+ def _log_threshold_history(self, device_code: str, date_str: str,
|
|
|
+ old_threshold: float, new_threshold: float,
|
|
|
+ train_loss: float):
|
|
|
+ """
|
|
|
+ 记录阈值变化历史到 CSV,用于监控模型长期漂移趋势
|
|
|
+
|
|
|
+ 文件路径: logs/threshold_history.csv
|
|
|
+ 格式: date,device_code,old_threshold,new_threshold,drift_ratio,train_loss
|
|
|
+ """
|
|
|
+ import csv
|
|
|
+
|
|
|
+ log_dir = self.deploy_root / "logs"
|
|
|
+ log_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ csv_path = log_dir / "threshold_history.csv"
|
|
|
+
|
|
|
+ drift_ratio = (new_threshold - old_threshold) / old_threshold if old_threshold > 0 else 0.0
|
|
|
+ write_header = not csv_path.exists()
|
|
|
+
|
|
|
+ try:
|
|
|
+ with open(csv_path, 'a', newline='', encoding='utf-8') as f:
|
|
|
+ writer = csv.writer(f)
|
|
|
+ if write_header:
|
|
|
+ writer.writerow([
|
|
|
+ 'date', 'device_code', 'old_threshold', 'new_threshold',
|
|
|
+ 'drift_ratio', 'train_loss'
|
|
|
+ ])
|
|
|
+ writer.writerow([
|
|
|
+ date_str, device_code,
|
|
|
+ f"{old_threshold:.8f}", f"{new_threshold:.8f}",
|
|
|
+ f"{drift_ratio:.4f}", f"{train_loss:.8f}"
|
|
|
+ ])
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"写入阈值历史失败: {e}")
|
|
|
+
|
|
|
def _validate_model(self, model: nn.Module) -> bool:
|
|
|
# 验证模型输出形状是否合理
|
|
|
if not self.config['auto_training']['validation']['enabled']:
|