|
@@ -57,15 +57,18 @@ class IncrementalTrainer:
|
|
|
2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑)
|
|
2. 增量训练:使用运行中采集的数据,对已有模型微调(兼容旧逻辑)
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
- def __init__(self, config_file: Path):
|
|
|
|
|
- """
|
|
|
|
|
- 初始化训练器
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- config_file: auto_training.yaml 配置文件路径
|
|
|
|
|
- """
|
|
|
|
|
- self.config_file = config_file
|
|
|
|
|
- self.config = self._load_config()
|
|
|
|
|
|
|
+ def __init__(self, config_file: Path = None, config: dict = None):
|
|
|
|
|
+ # 支持两种初始化方式:
|
|
|
|
|
+ # 1. 传 config dict(从数据库读取后直接传入,主程序使用)
|
|
|
|
|
+ # 2. 传 config_file YAML 路径(standalone_train.py 等独立工具使用)
|
|
|
|
|
+ if config is not None:
|
|
|
|
|
+ self.config = config
|
|
|
|
|
+ self.config_file = None
|
|
|
|
|
+ elif config_file is not None:
|
|
|
|
|
+ self.config_file = config_file
|
|
|
|
|
+ self.config = self._load_config()
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError("必须提供 config_file 或 config 之一")
|
|
|
|
|
|
|
|
# 路径配置
|
|
# 路径配置
|
|
|
self.deploy_root = Path(__file__).parent.parent
|
|
self.deploy_root = Path(__file__).parent.parent
|
|
@@ -88,7 +91,7 @@ class IncrementalTrainer:
|
|
|
self.cold_start_mode = False
|
|
self.cold_start_mode = False
|
|
|
|
|
|
|
|
def _load_config(self) -> Dict:
|
|
def _load_config(self) -> Dict:
|
|
|
- # 加载配置文件
|
|
|
|
|
|
|
+ # 从 YAML 文件加载配置(仅 config_file 模式使用)
|
|
|
with open(self.config_file, 'r', encoding='utf-8') as f:
|
|
with open(self.config_file, 'r', encoding='utf-8') as f:
|
|
|
return yaml.safe_load(f)
|
|
return yaml.safe_load(f)
|
|
|
|
|
|
|
@@ -192,30 +195,43 @@ class IncrementalTrainer:
|
|
|
device_code = device_dir.name
|
|
device_code = device_dir.name
|
|
|
audio_files = []
|
|
audio_files = []
|
|
|
|
|
|
|
|
- # 冷启动模式:收集所有已归档日期目录的数据(跳过 current/)
|
|
|
|
|
|
|
+ # 冷启动模式:收集所有已归档日期目录的正常音频(跳过 current/)
|
|
|
if self.cold_start_mode:
|
|
if self.cold_start_mode:
|
|
|
# 注意:跳过 current/ 目录,因其中可能包含 FFmpeg 正在写入的不完整文件
|
|
# 注意:跳过 current/ 目录,因其中可能包含 FFmpeg 正在写入的不完整文件
|
|
|
for sub_dir in device_dir.iterdir():
|
|
for sub_dir in device_dir.iterdir():
|
|
|
if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
|
|
if sub_dir.is_dir() and sub_dir.name.isdigit() and len(sub_dir.name) == 8:
|
|
|
|
|
+ # 新结构:从 {date}/normal/ 子目录读取
|
|
|
|
|
+ normal_dir = sub_dir / "normal"
|
|
|
|
|
+ if normal_dir.exists():
|
|
|
|
|
+ audio_files.extend(list(normal_dir.glob("*.wav")))
|
|
|
|
|
+ audio_files.extend(list(normal_dir.glob("*.mp4")))
|
|
|
|
|
+ # 兼容旧结构:日期目录下直接存放的音频文件
|
|
|
audio_files.extend(list(sub_dir.glob("*.wav")))
|
|
audio_files.extend(list(sub_dir.glob("*.wav")))
|
|
|
audio_files.extend(list(sub_dir.glob("*.mp4")))
|
|
audio_files.extend(list(sub_dir.glob("*.mp4")))
|
|
|
else:
|
|
else:
|
|
|
- # 正常模式:只收集指定日期的目录
|
|
|
|
|
|
|
+ # 正常模式:只收集指定日期的正常音频
|
|
|
date_dir = device_dir / target_date
|
|
date_dir = device_dir / target_date
|
|
|
if date_dir.exists():
|
|
if date_dir.exists():
|
|
|
|
|
+ # 新结构:从 {date}/normal/ 子目录读取
|
|
|
|
|
+ normal_dir = date_dir / "normal"
|
|
|
|
|
+ if normal_dir.exists():
|
|
|
|
|
+ audio_files.extend(list(normal_dir.glob("*.wav")))
|
|
|
|
|
+ audio_files.extend(list(normal_dir.glob("*.mp4")))
|
|
|
|
|
+ # 兼容旧结构:日期目录下直接存放的音频文件
|
|
|
audio_files.extend(list(date_dir.glob("*.wav")))
|
|
audio_files.extend(list(date_dir.glob("*.wav")))
|
|
|
audio_files.extend(list(date_dir.glob("*.mp4")))
|
|
audio_files.extend(list(date_dir.glob("*.mp4")))
|
|
|
|
|
|
|
|
- # 加上 verified_normal 目录
|
|
|
|
|
|
|
+ # 加上 verified_normal 目录(单独收集,不参与采样和质量预筛)
|
|
|
verified_dir = device_dir / "verified_normal"
|
|
verified_dir = device_dir / "verified_normal"
|
|
|
|
|
+ verified_files = []
|
|
|
if verified_dir.exists():
|
|
if verified_dir.exists():
|
|
|
- audio_files.extend(list(verified_dir.glob("*.wav")))
|
|
|
|
|
- audio_files.extend(list(verified_dir.glob("*.mp4")))
|
|
|
|
|
|
|
+ verified_files.extend(list(verified_dir.glob("*.wav")))
|
|
|
|
|
+ verified_files.extend(list(verified_dir.glob("*.mp4")))
|
|
|
|
|
|
|
|
- # 去重
|
|
|
|
|
|
|
+ # 去重(仅日期目录音频)
|
|
|
audio_files = list(set(audio_files))
|
|
audio_files = list(set(audio_files))
|
|
|
|
|
|
|
|
- # 数据质量预筛:过滤能量/频谱异常的音频
|
|
|
|
|
|
|
+ # 数据质量预筛:仅对日期目录音频过滤,verified_normal 已经人工确认,跳过
|
|
|
if audio_files and not self.cold_start_mode:
|
|
if audio_files and not self.cold_start_mode:
|
|
|
before_count = len(audio_files)
|
|
before_count = len(audio_files)
|
|
|
audio_files = self._filter_audio_quality(audio_files, device_code)
|
|
audio_files = self._filter_audio_quality(audio_files, device_code)
|
|
@@ -223,7 +239,7 @@ class IncrementalTrainer:
|
|
|
if filtered > 0:
|
|
if filtered > 0:
|
|
|
logger.info(f" {device_code}: 质量预筛过滤 {filtered} 个异常音频")
|
|
logger.info(f" {device_code}: 质量预筛过滤 {filtered} 个异常音频")
|
|
|
|
|
|
|
|
- # 随机采样(如果配置了采样时长)
|
|
|
|
|
|
|
+ # 随机采样(仅对日期目录音频采样,verified_normal 不参与)
|
|
|
if sample_hours > 0 and audio_files:
|
|
if sample_hours > 0 and audio_files:
|
|
|
files_needed = int(sample_hours * 3600 / 60)
|
|
files_needed = int(sample_hours * 3600 / 60)
|
|
|
if len(audio_files) > files_needed:
|
|
if len(audio_files) > files_needed:
|
|
@@ -234,6 +250,11 @@ class IncrementalTrainer:
|
|
|
else:
|
|
else:
|
|
|
logger.info(f" {device_code}: {len(audio_files)} 个音频")
|
|
logger.info(f" {device_code}: {len(audio_files)} 个音频")
|
|
|
|
|
|
|
|
|
|
+ # 合并 verified_normal(采样后追加,保证全量参与训练)
|
|
|
|
|
+ if verified_files:
|
|
|
|
|
+ audio_files.extend(verified_files)
|
|
|
|
|
+ logger.info(f" {device_code}: +{len(verified_files)} 个核查确认音频(verified_normal)")
|
|
|
|
|
+
|
|
|
if audio_files:
|
|
if audio_files:
|
|
|
device_files[device_code] = audio_files
|
|
device_files[device_code] = audio_files
|
|
|
|
|
|
|
@@ -302,8 +323,7 @@ class IncrementalTrainer:
|
|
|
# ========================================
|
|
# ========================================
|
|
|
|
|
|
|
|
def _extract_mel_for_device(self, device_code: str,
|
|
def _extract_mel_for_device(self, device_code: str,
|
|
|
- wav_files: List[Path],
|
|
|
|
|
- inherit_scale: bool = False
|
|
|
|
|
|
|
+ wav_files: List[Path]
|
|
|
) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
|
|
) -> Tuple[Optional[Path], Optional[Tuple[float, float]]]:
|
|
|
"""
|
|
"""
|
|
|
为单个设备提取 Mel 特征并计算独立的 Min-Max 标准化参数
|
|
为单个设备提取 Mel 特征并计算独立的 Min-Max 标准化参数
|
|
@@ -315,7 +335,6 @@ class IncrementalTrainer:
|
|
|
Args:
|
|
Args:
|
|
|
device_code: 设备编码
|
|
device_code: 设备编码
|
|
|
wav_files: 该设备的音频文件列表
|
|
wav_files: 该设备的音频文件列表
|
|
|
- inherit_scale: 增量训练时是否继承已部署的 scale 参数
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
(mel_dir, (global_min, global_max)),失败返回 (None, None)
|
|
(mel_dir, (global_min, global_max)),失败返回 (None, None)
|
|
@@ -369,18 +388,6 @@ class IncrementalTrainer:
|
|
|
logger.warning(f" {device_code}: 无有效数据")
|
|
logger.warning(f" {device_code}: 无有效数据")
|
|
|
return None, None
|
|
return None, None
|
|
|
|
|
|
|
|
- # 增量训练时,用已部署的 scale 做 EMA 平滑,避免剧烈偏移
|
|
|
|
|
- if inherit_scale:
|
|
|
|
|
- old_scale = self._load_deployed_scale(device_code)
|
|
|
|
|
- if old_scale is not None:
|
|
|
|
|
- ema_alpha = 0.3 # 新数据权重
|
|
|
|
|
- old_min, old_max = old_scale
|
|
|
|
|
- global_min = ema_alpha * global_min + (1 - ema_alpha) * old_min
|
|
|
|
|
- global_max = ema_alpha * global_max + (1 - ema_alpha) * old_max
|
|
|
|
|
- logger.info(f" {device_code}: scale EMA 融合 | "
|
|
|
|
|
- f"old=[{old_min:.4f}, {old_max:.4f}] → "
|
|
|
|
|
- f"new=[{global_min:.4f}, {global_max:.4f}]")
|
|
|
|
|
-
|
|
|
|
|
logger.info(f" {device_code}: {patch_count} patches | "
|
|
logger.info(f" {device_code}: {patch_count} patches | "
|
|
|
f"min={global_min:.4f} max={global_max:.4f}")
|
|
f"min={global_min:.4f} max={global_max:.4f}")
|
|
|
|
|
|
|
@@ -396,20 +403,7 @@ class IncrementalTrainer:
|
|
|
|
|
|
|
|
return device_mel_dir, (global_min, global_max)
|
|
return device_mel_dir, (global_min, global_max)
|
|
|
|
|
|
|
|
- def _load_deployed_scale(self, device_code: str) -> Optional[Tuple[float, float]]:
|
|
|
|
|
- """加载已部署的 global_scale.npy,用于增量训练时的 scale 继承"""
|
|
|
|
|
- scale_path = self.model_root / device_code / "global_scale.npy"
|
|
|
|
|
- if not scale_path.exists():
|
|
|
|
|
- return None
|
|
|
|
|
- try:
|
|
|
|
|
- scale = np.load(scale_path)
|
|
|
|
|
- return float(scale[0]), float(scale[1])
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.warning(f"加载旧 scale 失败: {device_code} | {e}")
|
|
|
|
|
- return None
|
|
|
|
|
-
|
|
|
|
|
- def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]],
|
|
|
|
|
- inherit_scale: bool = False
|
|
|
|
|
|
|
+ def prepare_mel_features_per_device(self, device_files: Dict[str, List[Path]]
|
|
|
) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
|
|
) -> Dict[str, Tuple[Path, Tuple[float, float]]]:
|
|
|
"""
|
|
"""
|
|
|
为每个设备独立提取 Mel 特征
|
|
为每个设备独立提取 Mel 特征
|
|
@@ -418,7 +412,6 @@ class IncrementalTrainer:
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
device_files: {device_code: [wav_files]}
|
|
device_files: {device_code: [wav_files]}
|
|
|
- inherit_scale: 增量训练时传 True,将新旧 scale 做 EMA 融合
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
{device_code: (mel_dir, (global_min, global_max))}
|
|
{device_code: (mel_dir, (global_min, global_max))}
|
|
@@ -433,9 +426,7 @@ class IncrementalTrainer:
|
|
|
device_results = {}
|
|
device_results = {}
|
|
|
|
|
|
|
|
for device_code, wav_files in device_files.items():
|
|
for device_code, wav_files in device_files.items():
|
|
|
- mel_dir, scale = self._extract_mel_for_device(
|
|
|
|
|
- device_code, wav_files, inherit_scale=inherit_scale
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ mel_dir, scale = self._extract_mel_for_device(device_code, wav_files)
|
|
|
if mel_dir is not None:
|
|
if mel_dir is not None:
|
|
|
device_results[device_code] = (mel_dir, scale)
|
|
device_results[device_code] = (mel_dir, scale)
|
|
|
|
|
|
|
@@ -451,8 +442,8 @@ class IncrementalTrainer:
|
|
|
# ========================================
|
|
# ========================================
|
|
|
|
|
|
|
|
def _select_training_device(self) -> torch.device:
|
|
def _select_training_device(self) -> torch.device:
|
|
|
- # 智能选择训练设备:GPU/NPU 显存充足则使用,否则回退 CPU
|
|
|
|
|
- # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda/npu)
|
|
|
|
|
|
|
+ # 智能选择训练设备:GPU 显存充足则使用,否则回退 CPU
|
|
|
|
|
+ # 训练配置中可通过 training_device 强制指定 (auto/cpu/cuda)
|
|
|
training_cfg = self.config['auto_training']['incremental']
|
|
training_cfg = self.config['auto_training']['incremental']
|
|
|
forced_device = training_cfg.get('training_device', 'auto')
|
|
forced_device = training_cfg.get('training_device', 'auto')
|
|
|
|
|
|
|
@@ -465,14 +456,8 @@ class IncrementalTrainer:
|
|
|
return torch.device('cuda')
|
|
return torch.device('cuda')
|
|
|
logger.warning("配置指定 CUDA 但不可用,回退 CPU")
|
|
logger.warning("配置指定 CUDA 但不可用,回退 CPU")
|
|
|
return torch.device('cpu')
|
|
return torch.device('cpu')
|
|
|
- if forced_device == 'npu':
|
|
|
|
|
- if self._npu_available():
|
|
|
|
|
- return torch.device('npu')
|
|
|
|
|
- logger.warning("配置指定 NPU 但不可用,回退 CPU")
|
|
|
|
|
- return torch.device('cpu')
|
|
|
|
|
|
|
|
|
|
- # auto 模式:依次检测 CUDA → NPU → CPU
|
|
|
|
|
- # 1. 检测 CUDA
|
|
|
|
|
|
|
+ # auto 模式:检测 CUDA → CPU
|
|
|
if torch.cuda.is_available():
|
|
if torch.cuda.is_available():
|
|
|
try:
|
|
try:
|
|
|
free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024)
|
|
free_mem = torch.cuda.mem_get_info()[0] / (1024 * 1024)
|
|
@@ -486,32 +471,9 @@ class IncrementalTrainer:
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.warning(f"CUDA 显存检测失败: {e}")
|
|
logger.warning(f"CUDA 显存检测失败: {e}")
|
|
|
|
|
|
|
|
- # 2. 检测 NPU (华为昇腾)
|
|
|
|
|
- if self._npu_available():
|
|
|
|
|
- try:
|
|
|
|
|
- free_mem = torch.npu.mem_get_info()[0] / (1024 * 1024)
|
|
|
|
|
- min_gpu_mem_mb = training_cfg.get('min_gpu_mem_mb', 512)
|
|
|
|
|
- if free_mem >= min_gpu_mem_mb:
|
|
|
|
|
- logger.info(f"训练设备: NPU(空闲显存 {free_mem:.0f}MB)")
|
|
|
|
|
- return torch.device('npu')
|
|
|
|
|
- logger.info(
|
|
|
|
|
- f"NPU 空闲显存不足 ({free_mem:.0f}MB < {min_gpu_mem_mb}MB)"
|
|
|
|
|
- )
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.warning(f"NPU 显存检测失败: {e},回退 CPU")
|
|
|
|
|
-
|
|
|
|
|
logger.info("训练设备: CPU")
|
|
logger.info("训练设备: CPU")
|
|
|
return torch.device('cpu')
|
|
return torch.device('cpu')
|
|
|
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- def _npu_available() -> bool:
|
|
|
|
|
- """检查华为昇腾 NPU 是否可用"""
|
|
|
|
|
- try:
|
|
|
|
|
- import torch_npu # noqa: F401
|
|
|
|
|
- return torch.npu.is_available()
|
|
|
|
|
- except ImportError:
|
|
|
|
|
- return False
|
|
|
|
|
-
|
|
|
|
|
def _run_training_loop(self, device_code: str, model: nn.Module,
|
|
def _run_training_loop(self, device_code: str, model: nn.Module,
|
|
|
train_loader, val_loader, epochs: int, lr: float,
|
|
train_loader, val_loader, epochs: int, lr: float,
|
|
|
device: torch.device) -> Tuple[nn.Module, float]:
|
|
device: torch.device) -> Tuple[nn.Module, float]:
|
|
@@ -522,8 +484,8 @@ class IncrementalTrainer:
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
|
|
criterion = nn.MSELoss()
|
|
criterion = nn.MSELoss()
|
|
|
|
|
|
|
|
- # AMP 混合精度(GPU/NPU 生效,减少约 40% 显存占用)
|
|
|
|
|
- use_amp = device.type in ('cuda', 'npu')
|
|
|
|
|
|
|
+ # AMP 混合精度(GPU 生效,减少约 40% 显存占用)
|
|
|
|
|
+ use_amp = device.type == 'cuda'
|
|
|
scaler = torch.amp.GradScaler(device.type) if use_amp else None
|
|
scaler = torch.amp.GradScaler(device.type) if use_amp else None
|
|
|
|
|
|
|
|
# 早停配置
|
|
# 早停配置
|
|
@@ -608,11 +570,9 @@ class IncrementalTrainer:
|
|
|
f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}")
|
|
f"最终轮数={actual_epochs}/{epochs} | Loss={avg_loss:.6f}")
|
|
|
break
|
|
break
|
|
|
|
|
|
|
|
- # 训练后清理加速器缓存
|
|
|
|
|
|
|
+ # 训练后清理 GPU 缓存
|
|
|
if device.type == 'cuda':
|
|
if device.type == 'cuda':
|
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.empty_cache()
|
|
|
- elif device.type == 'npu':
|
|
|
|
|
- torch.npu.empty_cache()
|
|
|
|
|
|
|
|
|
|
if actual_epochs < epochs:
|
|
if actual_epochs < epochs:
|
|
|
logger.info(f" [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练")
|
|
logger.info(f" [{device_code}] 早停节省 {epochs - actual_epochs} 轮训练")
|
|
@@ -631,15 +591,11 @@ class IncrementalTrainer:
|
|
|
# 智能选择训练设备
|
|
# 智能选择训练设备
|
|
|
device = self._select_training_device()
|
|
device = self._select_training_device()
|
|
|
|
|
|
|
|
- # 训练前清理加速器缓存,释放推理残留的显存碎片
|
|
|
|
|
|
|
+ # 训练前清理 GPU 缓存,释放推理残留的显存碎片
|
|
|
if device.type == 'cuda':
|
|
if device.type == 'cuda':
|
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.empty_cache()
|
|
|
import gc
|
|
import gc
|
|
|
gc.collect()
|
|
gc.collect()
|
|
|
- elif device.type == 'npu':
|
|
|
|
|
- torch.npu.empty_cache()
|
|
|
|
|
- import gc
|
|
|
|
|
- gc.collect()
|
|
|
|
|
|
|
|
|
|
model = ConvAutoencoder()
|
|
model = ConvAutoencoder()
|
|
|
|
|
|
|
@@ -684,26 +640,22 @@ class IncrementalTrainer:
|
|
|
logger.info(f" 数据量不足20,跳过验证集划分(共{len(dataset)}样本)")
|
|
logger.info(f" 数据量不足20,跳过验证集划分(共{len(dataset)}样本)")
|
|
|
|
|
|
|
|
# 尝试在选定设备上训练
|
|
# 尝试在选定设备上训练
|
|
|
- if device.type in ('cuda', 'npu'):
|
|
|
|
|
|
|
+ if device.type == 'cuda':
|
|
|
try:
|
|
try:
|
|
|
return self._run_training_loop(
|
|
return self._run_training_loop(
|
|
|
device_code, model, train_loader, val_loader,
|
|
device_code, model, train_loader, val_loader,
|
|
|
epochs, lr, device
|
|
epochs, lr, device
|
|
|
)
|
|
)
|
|
|
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
|
|
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
|
|
|
- # GPU/NPU OOM -> 清理显存后回退 CPU 重试
|
|
|
|
|
|
|
+ # GPU OOM -> 清理显存后回退 CPU 重试
|
|
|
if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError):
|
|
if 'out of memory' not in str(e).lower() and isinstance(e, RuntimeError):
|
|
|
raise # 非 OOM 的 RuntimeError 不拦截
|
|
raise # 非 OOM 的 RuntimeError 不拦截
|
|
|
logger.warning(
|
|
logger.warning(
|
|
|
- f" [{device_code}] {device.type.upper()} OOM,"
|
|
|
|
|
- f"清理显存后回退 CPU 训练"
|
|
|
|
|
|
|
+ f" [{device_code}] CUDA OOM,清理显存后回退 CPU 训练"
|
|
|
)
|
|
)
|
|
|
import gc
|
|
import gc
|
|
|
gc.collect()
|
|
gc.collect()
|
|
|
- if device.type == 'cuda':
|
|
|
|
|
- torch.cuda.empty_cache()
|
|
|
|
|
- elif device.type == 'npu':
|
|
|
|
|
- torch.npu.empty_cache()
|
|
|
|
|
|
|
+ torch.cuda.empty_cache()
|
|
|
# 模型可能处于脏状态,重新初始化
|
|
# 模型可能处于脏状态,重新初始化
|
|
|
model = ConvAutoencoder()
|
|
model = ConvAutoencoder()
|
|
|
if not from_scratch:
|
|
if not from_scratch:
|
|
@@ -812,6 +764,31 @@ class IncrementalTrainer:
|
|
|
|
|
|
|
|
return threshold
|
|
return threshold
|
|
|
|
|
|
|
|
|
|
+ def _eval_model_error(self, model: nn.Module, mel_dir: Path) -> float:
|
|
|
|
|
+ """在验证数据上计算模型的平均重建误差,用于新旧模型对比"""
|
|
|
|
|
+ device = next(model.parameters()).device
|
|
|
|
|
+ model.eval()
|
|
|
|
|
+
|
|
|
|
|
+ dataset = MelNPYDataset(mel_dir)
|
|
|
|
|
+ if len(dataset) == 0:
|
|
|
|
|
+ return float('inf')
|
|
|
|
|
+
|
|
|
|
|
+ dataloader = torch.utils.data.DataLoader(
|
|
|
|
|
+ dataset, batch_size=64, shuffle=False
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ all_errors = []
|
|
|
|
|
+ with torch.no_grad():
|
|
|
|
|
+ for batch in dataloader:
|
|
|
|
|
+ batch = batch.to(device)
|
|
|
|
|
+ output = model(batch)
|
|
|
|
|
+ output = align_to_target(output, batch)
|
|
|
|
|
+ mse = torch.mean((output - batch) ** 2, dim=[1, 2, 3])
|
|
|
|
|
+ all_errors.append(mse.cpu().numpy())
|
|
|
|
|
+
|
|
|
|
|
+ errors = np.concatenate(all_errors)
|
|
|
|
|
+ return float(np.mean(errors))
|
|
|
|
|
+
|
|
|
# ========================================
|
|
# ========================================
|
|
|
# 全量训练入口
|
|
# 全量训练入口
|
|
|
# ========================================
|
|
# ========================================
|
|
@@ -954,7 +931,6 @@ class IncrementalTrainer:
|
|
|
else self.config['auto_training']['incremental']['learning_rate'])
|
|
else self.config['auto_training']['incremental']['learning_rate'])
|
|
|
|
|
|
|
|
from_scratch = self.cold_start_mode
|
|
from_scratch = self.cold_start_mode
|
|
|
- inherit_scale = not self.cold_start_mode
|
|
|
|
|
|
|
|
|
|
model_cfg = self.config['auto_training']['model']
|
|
model_cfg = self.config['auto_training']['model']
|
|
|
rollback_enabled = model_cfg.get('rollback_on_degradation', True)
|
|
rollback_enabled = model_cfg.get('rollback_on_degradation', True)
|
|
@@ -973,7 +949,7 @@ class IncrementalTrainer:
|
|
|
try:
|
|
try:
|
|
|
# ── 4a. 单设备特征提取 ──
|
|
# ── 4a. 单设备特征提取 ──
|
|
|
mel_dir, scale = self._extract_mel_for_device(
|
|
mel_dir, scale = self._extract_mel_for_device(
|
|
|
- device_code, wav_files, inherit_scale=inherit_scale
|
|
|
|
|
|
|
+ device_code, wav_files
|
|
|
)
|
|
)
|
|
|
if mel_dir is None:
|
|
if mel_dir is None:
|
|
|
logger.warning(f"{device_code}: 特征提取无有效数据,跳过")
|
|
logger.warning(f"{device_code}: 特征提取无有效数据,跳过")
|
|
@@ -989,52 +965,49 @@ class IncrementalTrainer:
|
|
|
logger.error(f"{device_code}: 形状验证失败,跳过部署")
|
|
logger.error(f"{device_code}: 形状验证失败,跳过部署")
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
- # ── 4d. 损失退化检测(增量训练时生效) ──
|
|
|
|
|
|
|
+ # ── 4d. 新旧模型对比(增量训练时生效) ──
|
|
|
|
|
+ # 在相同验证数据上比较新旧模型的重建误差,新模型更差则跳过部署
|
|
|
if rollback_enabled and not self.cold_start_mode:
|
|
if rollback_enabled and not self.cold_start_mode:
|
|
|
- old_threshold = self._get_old_threshold(device_code)
|
|
|
|
|
- if old_threshold and old_threshold > 0:
|
|
|
|
|
- if final_loss > old_threshold * rollback_factor:
|
|
|
|
|
|
|
+ old_model_path = self.model_root / device_code / "ae_model.pth"
|
|
|
|
|
+ if old_model_path.exists():
|
|
|
|
|
+ new_avg_err = self._eval_model_error(model, mel_dir)
|
|
|
|
|
+ old_model = ConvAutoencoder()
|
|
|
|
|
+ old_model.load_state_dict(
|
|
|
|
|
+ torch.load(old_model_path, map_location='cpu')
|
|
|
|
|
+ )
|
|
|
|
|
+ old_avg_err = self._eval_model_error(old_model, mel_dir)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f" {device_code}: 新旧模型对比 | "
|
|
|
|
|
+ f"旧模型误差={old_avg_err:.6f} 新模型误差={new_avg_err:.6f}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if new_avg_err > old_avg_err * rollback_factor:
|
|
|
logger.warning(
|
|
logger.warning(
|
|
|
- f"{device_code}: 损失退化检测触发 | "
|
|
|
|
|
- f"训练损失={final_loss:.6f} > "
|
|
|
|
|
- f"旧阈值={old_threshold:.6f} × {rollback_factor} = "
|
|
|
|
|
- f"{old_threshold * rollback_factor:.6f}"
|
|
|
|
|
|
|
+ f"{device_code}: 新模型退化 | "
|
|
|
|
|
+ f"新={new_avg_err:.6f} > 旧={old_avg_err:.6f} × {rollback_factor},跳过部署"
|
|
|
)
|
|
)
|
|
|
degraded_count += 1
|
|
degraded_count += 1
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
- # ── 4e. 阈值偏移检测 + 部署 ──
|
|
|
|
|
|
|
+ # ── 4e. 部署 ──
|
|
|
if model_cfg.get('auto_deploy', True):
|
|
if model_cfg.get('auto_deploy', True):
|
|
|
- if rollback_enabled and not self.cold_start_mode:
|
|
|
|
|
- old_threshold = self._get_old_threshold(device_code)
|
|
|
|
|
- if old_threshold and old_threshold > 0:
|
|
|
|
|
- new_threshold = self._compute_threshold(model, mel_dir)
|
|
|
|
|
- drift_ratio = abs(new_threshold - old_threshold) / old_threshold
|
|
|
|
|
- # 记录阈值变化趋势(用于长期漂移监控)
|
|
|
|
|
- self._log_threshold_history(
|
|
|
|
|
- device_code, target_date,
|
|
|
|
|
- old_threshold, new_threshold, final_loss
|
|
|
|
|
- )
|
|
|
|
|
- if drift_ratio > 0.3:
|
|
|
|
|
- logger.warning(
|
|
|
|
|
- f"{device_code}: 阈值偏移告警 | "
|
|
|
|
|
- f"旧={old_threshold:.6f} → "
|
|
|
|
|
- f"新={new_threshold:.6f} | "
|
|
|
|
|
- f"偏移={drift_ratio:.1%}"
|
|
|
|
|
- )
|
|
|
|
|
- if drift_ratio > 1.0:
|
|
|
|
|
- logger.warning(
|
|
|
|
|
- f"{device_code}: 阈值偏移过大"
|
|
|
|
|
- f"(>{drift_ratio:.0%}),跳过部署"
|
|
|
|
|
- )
|
|
|
|
|
- degraded_count += 1
|
|
|
|
|
- continue
|
|
|
|
|
self.deploy_device_model(device_code, model, scale, mel_dir)
|
|
self.deploy_device_model(device_code, model, scale, mel_dir)
|
|
|
|
|
|
|
|
success_count += 1
|
|
success_count += 1
|
|
|
logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}")
|
|
logger.info(f"{device_code}: 训练+部署完成 | loss={final_loss:.6f}")
|
|
|
|
|
|
|
|
- # ── 4f. 即时通知该设备模型重载 ──
|
|
|
|
|
|
|
+ # ── 4f. 清理已参与训练的 verified_normal 目录 ──
|
|
|
|
|
+ # 核查确认的音频已被模型吸收,训练后清空释放磁盘空间
|
|
|
|
|
+ verified_dir = self.audio_root / device_code / "verified_normal"
|
|
|
|
|
+ if verified_dir.exists():
|
|
|
|
|
+ v_count = len(list(verified_dir.glob("*")))
|
|
|
|
|
+ if v_count > 0:
|
|
|
|
|
+ shutil.rmtree(verified_dir)
|
|
|
|
|
+ verified_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ logger.info(f"{device_code}: 已清理 verified_normal ({v_count} 个文件)")
|
|
|
|
|
+
|
|
|
|
|
+ # ── 4g. 即时通知该设备模型重载 ──
|
|
|
if on_device_trained:
|
|
if on_device_trained:
|
|
|
try:
|
|
try:
|
|
|
on_device_trained(device_code)
|
|
on_device_trained(device_code)
|
|
@@ -1045,7 +1018,7 @@ class IncrementalTrainer:
|
|
|
logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True)
|
|
logger.error(f"{device_code}: 训练失败 | {e}", exc_info=True)
|
|
|
|
|
|
|
|
finally:
|
|
finally:
|
|
|
- # ── 4g. 清理该设备的临时 Mel 文件,释放磁盘空间 ──
|
|
|
|
|
|
|
+ # ── 4h. 清理该设备的临时 Mel 文件,释放磁盘空间 ──
|
|
|
device_mel_dir = self.temp_mel_dir / device_code
|
|
device_mel_dir = self.temp_mel_dir / device_code
|
|
|
if device_mel_dir.exists():
|
|
if device_mel_dir.exists():
|
|
|
shutil.rmtree(device_mel_dir)
|
|
shutil.rmtree(device_mel_dir)
|
|
@@ -1064,9 +1037,6 @@ class IncrementalTrainer:
|
|
|
f"{success_count} 个设备部署成功"
|
|
f"{success_count} 个设备部署成功"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # 6. 更新分类器基线
|
|
|
|
|
- self._update_classifier_baseline(device_files)
|
|
|
|
|
-
|
|
|
|
|
logger.info("=" * 60)
|
|
logger.info("=" * 60)
|
|
|
logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功")
|
|
logger.info(f"增量训练完成: {success_count}/{device_count} 个设备成功")
|
|
|
if degraded_count > 0:
|
|
if degraded_count > 0:
|
|
@@ -1086,60 +1056,6 @@ class IncrementalTrainer:
|
|
|
# 辅助方法
|
|
# 辅助方法
|
|
|
# ========================================
|
|
# ========================================
|
|
|
|
|
|
|
|
- def _get_old_threshold(self, device_code: str) -> float:
|
|
|
|
|
- """
|
|
|
|
|
- 读取设备当前已部署的阈值(训练前的旧阈值)
|
|
|
|
|
-
|
|
|
|
|
- 用于损失退化校验:新模型的训练损失不应远超旧阈值。
|
|
|
|
|
- 阈值文件路径: models/{device_code}/thresholds/threshold_{device_code}.npy
|
|
|
|
|
-
|
|
|
|
|
- 返回:
|
|
|
|
|
- 阈值浮点数,文件不存在时返回 0.0
|
|
|
|
|
- """
|
|
|
|
|
- threshold_file = self.model_root / device_code / "thresholds" / f"threshold_{device_code}.npy"
|
|
|
|
|
- if not threshold_file.exists():
|
|
|
|
|
- return 0.0
|
|
|
|
|
- try:
|
|
|
|
|
- data = np.load(threshold_file)
|
|
|
|
|
- return float(data.flat[0])
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.warning(f"读取旧阈值失败: {device_code} | {e}")
|
|
|
|
|
- return 0.0
|
|
|
|
|
-
|
|
|
|
|
- def _log_threshold_history(self, device_code: str, date_str: str,
|
|
|
|
|
- old_threshold: float, new_threshold: float,
|
|
|
|
|
- train_loss: float):
|
|
|
|
|
- """
|
|
|
|
|
- 记录阈值变化历史到 CSV,用于监控模型长期漂移趋势
|
|
|
|
|
-
|
|
|
|
|
- 文件路径: logs/threshold_history.csv
|
|
|
|
|
- 格式: date,device_code,old_threshold,new_threshold,drift_ratio,train_loss
|
|
|
|
|
- """
|
|
|
|
|
- import csv
|
|
|
|
|
-
|
|
|
|
|
- log_dir = self.deploy_root / "logs"
|
|
|
|
|
- log_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
- csv_path = log_dir / "threshold_history.csv"
|
|
|
|
|
-
|
|
|
|
|
- drift_ratio = (new_threshold - old_threshold) / old_threshold if old_threshold > 0 else 0.0
|
|
|
|
|
- write_header = not csv_path.exists()
|
|
|
|
|
-
|
|
|
|
|
- try:
|
|
|
|
|
- with open(csv_path, 'a', newline='', encoding='utf-8') as f:
|
|
|
|
|
- writer = csv.writer(f)
|
|
|
|
|
- if write_header:
|
|
|
|
|
- writer.writerow([
|
|
|
|
|
- 'date', 'device_code', 'old_threshold', 'new_threshold',
|
|
|
|
|
- 'drift_ratio', 'train_loss'
|
|
|
|
|
- ])
|
|
|
|
|
- writer.writerow([
|
|
|
|
|
- date_str, device_code,
|
|
|
|
|
- f"{old_threshold:.8f}", f"{new_threshold:.8f}",
|
|
|
|
|
- f"{drift_ratio:.4f}", f"{train_loss:.8f}"
|
|
|
|
|
- ])
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.warning(f"写入阈值历史失败: {e}")
|
|
|
|
|
-
|
|
|
|
|
def _validate_model(self, model: nn.Module) -> bool:
|
|
def _validate_model(self, model: nn.Module) -> bool:
|
|
|
# 验证模型输出形状是否合理
|
|
# 验证模型输出形状是否合理
|
|
|
if not self.config['auto_training']['validation']['enabled']:
|
|
if not self.config['auto_training']['validation']['enabled']:
|
|
@@ -1235,57 +1151,6 @@ class IncrementalTrainer:
|
|
|
logger.info(f"恢复完成: {restored} 个设备")
|
|
logger.info(f"恢复完成: {restored} 个设备")
|
|
|
return restored > 0
|
|
return restored > 0
|
|
|
|
|
|
|
|
- def _update_classifier_baseline(self, device_files: Dict[str, List[Path]]):
|
|
|
|
|
- # 从训练数据计算并更新分类器基线
|
|
|
|
|
- logger.info("更新分类器基线")
|
|
|
|
|
-
|
|
|
|
|
- try:
|
|
|
|
|
- import librosa
|
|
|
|
|
- from core.anomaly_classifier import AnomalyClassifier
|
|
|
|
|
-
|
|
|
|
|
- classifier = AnomalyClassifier()
|
|
|
|
|
-
|
|
|
|
|
- all_files = []
|
|
|
|
|
- for files in device_files.values():
|
|
|
|
|
- all_files.extend(files)
|
|
|
|
|
-
|
|
|
|
|
- if not all_files:
|
|
|
|
|
- logger.warning("无音频文件,跳过基线更新")
|
|
|
|
|
- return
|
|
|
|
|
-
|
|
|
|
|
- sample_files = random.sample(all_files, min(50, len(all_files)))
|
|
|
|
|
-
|
|
|
|
|
- all_features = []
|
|
|
|
|
- for wav_file in sample_files:
|
|
|
|
|
- try:
|
|
|
|
|
- y, _ = librosa.load(str(wav_file), sr=CFG.SR, mono=True)
|
|
|
|
|
- if len(y) < CFG.SR:
|
|
|
|
|
- continue
|
|
|
|
|
- features = classifier.extract_features(y, sr=CFG.SR)
|
|
|
|
|
- if features:
|
|
|
|
|
- all_features.append(features)
|
|
|
|
|
- except Exception:
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
- if not all_features:
|
|
|
|
|
- logger.warning("无法提取特征,跳过基线更新")
|
|
|
|
|
- return
|
|
|
|
|
-
|
|
|
|
|
- baseline = {}
|
|
|
|
|
- keys = all_features[0].keys()
|
|
|
|
|
- for key in keys:
|
|
|
|
|
- if key == 'has_periodic':
|
|
|
|
|
- values = [f[key] for f in all_features]
|
|
|
|
|
- baseline[key] = sum(values) > len(values) / 2
|
|
|
|
|
- else:
|
|
|
|
|
- values = [f[key] for f in all_features]
|
|
|
|
|
- baseline[key] = float(np.mean(values))
|
|
|
|
|
-
|
|
|
|
|
- classifier.save_baseline(baseline)
|
|
|
|
|
- logger.info(f" 基线已更新 (样本数: {len(all_features)})")
|
|
|
|
|
-
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.warning(f"更新基线失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
def main():
|