# -*- coding: utf-8 -*- """ utils.py - 部署环境工具函数 =========================== 部署环境使用的工具函数。 与训练环境的utils.py功能相同,但去除了训练相关的函数。 """ from pathlib import Path import re import torch import torch.nn.functional as F import numpy as np from .config import CFG def ensure_dirs(): """ 确保部署所需目录存在 创建以下目录(如不存在): - AUDIO_DIR: 音频文件 - MODEL_DIR: 模型文件 - THRESHOLD_DIR: 阈值文件 """ for d in ['AUDIO_DIR', 'MODEL_DIR', 'THRESHOLD_DIR']: if hasattr(CFG, d): getattr(CFG, d).mkdir(parents=True, exist_ok=True) def get_device(): """ 获取可用的计算设备 返回: str: "cuda" 或 "cpu" """ return "cuda" if torch.cuda.is_available() else "cpu" def align_to_target(pred, target): """ 将预测tensor对齐到目标tensor的尺寸 处理卷积自编码器可能产生的尺寸偏差。 参数: pred: 预测tensor [B, C, H, W] target: 目标tensor [B, C, H_target, W_target] 返回: 对齐后的tensor """ # 获取目标尺寸 _, _, H_t, W_t = target.shape _, _, H_p, W_p = pred.shape x = pred # H维度对齐 if H_p > H_t: start = (H_p - H_t) // 2 x = x[:, :, start:start + H_t, :] elif H_p < H_t: diff = H_t - H_p x = F.pad(x, (0, 0, diff // 2, diff - diff // 2)) # W维度对齐 _, _, _, W_p2 = x.shape if W_p2 > W_t: start = (W_p2 - W_t) // 2 x = x[:, :, :, start:start + W_t] elif W_p2 < W_t: diff = W_t - W_p2 x = F.pad(x, (diff // 2, diff - diff // 2, 0, 0)) return x def parse_metadata_from_filename(path): """ 从音频文件名解析元数据 支持三种格式: 1. 4段式: {水厂}_ch{通道}_{起始时间}_{结束时间}.wav 2. 3段式: {水厂}_ch{通道}_{时间}.wav 3. 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav) 参数: path: 文件路径 返回: 元组 (plant_id, pump_id, start_time, end_time) """ stem = Path(path).stem # 4段式 m = re.match(r"(.+?)_ch(\d+)_(\d{14})_(\d{14})", stem) if m: return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), m.group(4) # 3段式 m = re.match(r"(.+?)_ch(\d+)_(\d{14})", stem) if m: return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), "" # 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav) m = re.match(r"(\d+)_([A-Za-z0-9-]+)_(\d{14})", stem) if m: project_id = m.group(1) device_code = m.group(2) timestamp = m.group(3) # 返回 (project_id, device_code, timestamp, "") # 其中 device_code 作为 pump_id 用于阈值查找 return project_id, device_code, timestamp, "" raise ValueError(f"文件名格式不符: {stem}") def load_global_scale(): """ 加载全局标准化参数(已过时) 注意:此函数加载全局共享的 scale 文件,仅用于向后兼容。 当前系统使用 DevicePredictor._load_scale() 按设备加载。 返回: 元组 (val_0, val_1) 如果文件不存在返回 (None, None) """ if not CFG.SCALE_FILE.exists(): return None, None scale = np.load(CFG.SCALE_FILE) return float(scale[0]), float(scale[1])