| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # -*- coding: utf-8 -*-
- """
- utils.py - 部署环境工具函数
- ===========================
- 部署环境使用的工具函数。
- 与训练环境的utils.py功能相同,但去除了训练相关的函数。
- """
- from pathlib import Path
- import re
- import torch
- import torch.nn.functional as F
- import numpy as np
- from .config import CFG
def ensure_dirs():
    """Create the directories needed by the deployment, if missing.

    Looks up ``AUDIO_DIR``, ``MODEL_DIR`` and ``THRESHOLD_DIR`` on ``CFG``
    (in that order). Each one that is defined is created together with any
    missing parents; directories that already exist are left alone. Names
    not defined on ``CFG`` are skipped without error.

    NOTE(review): assumes each attribute is a ``pathlib.Path`` (the code
    calls ``.mkdir(parents=True, exist_ok=True)`` on it) — confirm in config.
    """
    wanted = ('AUDIO_DIR', 'MODEL_DIR', 'THRESHOLD_DIR')
    # Lazy generator: the hasattr check for each name runs right before its
    # mkdir, same ordering as checking inside the loop body.
    defined = (name for name in wanted if hasattr(CFG, name))
    for name in defined:
        directory = getattr(CFG, name)
        directory.mkdir(parents=True, exist_ok=True)
def get_device():
    """Return the name of the compute device to run on.

    Returns:
        str: ``"cuda"`` when PyTorch reports CUDA as available,
        otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def _center_fit(x, size, dim):
    """Center-crop or zero-pad ``x`` along ``dim`` (-2 or -1) to length ``size``."""
    current = x.shape[dim]
    if current > size:
        # Crop: keep the central `size` entries; with an odd surplus the
        # extra entry is dropped from the end (bottom/right).
        start = (current - size) // 2
        window = slice(start, start + size)
        return x[..., window, :] if dim == -2 else x[..., window]
    if current < size:
        diff = size - current
        before, after = diff // 2, diff - diff // 2
        # F.pad takes (left, right) pairs starting from the LAST dimension,
        # so padding dim -2 needs a leading (0, 0) for dim -1.
        pad = (before, after) if dim == -1 else (0, 0, before, after)
        return F.pad(x, pad)
    return x


def align_to_target(pred, target):
    """Center-crop or zero-pad ``pred`` so its last two dims match ``target``.

    Convolutional autoencoders can produce outputs whose spatial size is
    slightly off from the input; this aligns the prediction to the target
    size so the two can be compared element-wise.

    Only the last two dimensions (H, W) are aligned; any leading dimensions
    (e.g. batch and channel in ``[B, C, H, W]``) are passed through as-is.
    When a size differs by an odd amount, the extra row/column is removed
    from, or padded onto, the bottom/right side.

    Args:
        pred: prediction tensor, e.g. ``[B, C, H, W]`` (any rank >= 2).
        target: reference tensor; only its last two dims
            ``(H_target, W_target)`` are read.

    Returns:
        ``pred`` aligned to ``(H_target, W_target)`` in its last two dims.
        Cropping returns a slice (view) of ``pred``; padding returns a new
        tensor with zeros around the original values. If no change is
        needed, ``pred`` itself is returned.

    Raises:
        ValueError: if either tensor has fewer than two dimensions.
    """
    if len(pred.shape) < 2 or len(target.shape) < 2:
        raise ValueError(
            "align_to_target expects tensors with at least 2 dims, got "
            f"{tuple(pred.shape)} and {tuple(target.shape)}"
        )
    h_target, w_target = target.shape[-2:]
    # Align H first, then W (W is measured on the H-aligned tensor).
    x = _center_fit(pred, h_target, dim=-2)
    x = _center_fit(x, w_target, dim=-1)
    return x
def parse_metadata_from_filename(path):
    """Extract metadata from an audio file name.

    Three naming schemes are recognised, tried in this order:

    1. 4-segment: ``{plant}_ch{channel}_{start_time}_{end_time}.wav``
    2. 3-segment: ``{plant}_ch{channel}_{time}.wav``
    3. New format: ``{project_id}_{device_code}_{time}.wav``
       (e.g. ``1450_LT-2_20260115103754.wav``)

    Times are 14-digit strings. Patterns are matched from the start of the
    file stem only (``re.match``, not anchored at the end), so text after
    the last timestamp is tolerated. For this reason the 4-segment pattern
    must be tried before the 3-segment one, which would otherwise match it.

    Args:
        path: file path (str or path-like); only its stem is inspected.

    Returns:
        Tuple ``(plant_id, pump_id, start_time, end_time)``. ``end_time`` is
        ``""`` for the 3-segment and new formats. For the new format, the
        device code is returned as ``pump_id`` (used for threshold lookup).

    Raises:
        ValueError: if the stem matches none of the formats.
    """
    stem = Path(path).stem

    # (pattern, builder) pairs; each builder receives match.groups().
    matchers = (
        (r"(.+?)_ch(\d+)_(\d{14})_(\d{14})",
         lambda g: (g[0].strip(), f"ch{g[1]}", g[2], g[3])),
        (r"(.+?)_ch(\d+)_(\d{14})",
         lambda g: (g[0].strip(), f"ch{g[1]}", g[2], "")),
        (r"(\d+)_([A-Za-z0-9-]+)_(\d{14})",
         lambda g: (g[0], g[1], g[2], "")),
    )
    for pattern, build in matchers:
        found = re.match(pattern, stem)
        if found is not None:
            return build(found.groups())

    raise ValueError(f"文件名格式不符: {stem}")
def load_global_scale():
    """Load the global (shared) standardization parameters — deprecated.

    Reads the single scale file at ``CFG.SCALE_FILE``. Kept only for
    backward compatibility; per the original notes, the current system
    loads scales per device via ``DevicePredictor._load_scale()``.

    Returns:
        Tuple ``(val_0, val_1)``: the first two entries of the stored array,
        converted to ``float``. ``(None, None)`` if the file does not exist.
        NOTE(review): the meaning of the two values (e.g. mean/std or
        min/max) is not visible here — confirm against the writer.
    """
    scale_path = CFG.SCALE_FILE
    if scale_path.exists():
        values = np.load(scale_path)
        return float(values[0]), float(values[1])
    return None, None
|