utils.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # -*- coding: utf-8 -*-
  2. """
  3. utils.py - 部署环境工具函数
  4. ===========================
  5. 部署环境使用的工具函数。
  6. 与训练环境的utils.py功能相同,但去除了训练相关的函数。
  7. """
  8. from pathlib import Path
  9. import re
  10. import torch
  11. import torch.nn.functional as F
  12. import numpy as np
  13. from .config import CFG
  14. def ensure_dirs():
  15. """
  16. 确保部署所需目录存在
  17. 创建以下目录(如不存在):
  18. - AUDIO_DIR: 音频文件
  19. - MODEL_DIR: 模型文件
  20. - THRESHOLD_DIR: 阈值文件
  21. """
  22. for d in ['AUDIO_DIR', 'MODEL_DIR', 'THRESHOLD_DIR']:
  23. if hasattr(CFG, d):
  24. getattr(CFG, d).mkdir(parents=True, exist_ok=True)
  25. def get_device():
  26. """
  27. 获取可用的 PyTorch 计算设备
  28. 优先级: CUDA > CPU
  29. 注意: BM1684X (算能) 不走此函数,它不是 PyTorch 后端。
  30. BM1684X 推理由 bm1684x_engine.py 中的 sophon.sail 独立处理。
  31. 返回:
  32. str: "cuda" 或 "cpu"
  33. """
  34. if torch.cuda.is_available():
  35. return "cuda"
  36. return "cpu"
  37. # --- BM1684X NPU 推理适配(预埋,暂未启用) ---
  38. # BM1684X 不是 PyTorch 后端,不能通过 get_device() 返回。
  39. # 以下函数用于检测 BM1684X 硬件是否可用,供 multi_model_predictor 判断。
  40. # 启用时从 bm1684x_engine.py import is_bm1684x_available 即可。
  41. #
  42. # def is_bm1684x_available() -> bool:
  43. # """检测 BM1684X 硬件是否可用(SDK + 设备节点)"""
  44. # try:
  45. # import sophon.sail # noqa: F401
  46. # except ImportError:
  47. # return False
  48. # import glob
  49. # return len(glob.glob("/dev/bm-sophon*")) > 0
  50. def align_to_target(pred, target):
  51. """
  52. 将预测tensor对齐到目标tensor的尺寸
  53. 处理卷积自编码器可能产生的尺寸偏差。
  54. 参数:
  55. pred: 预测tensor [B, C, H, W]
  56. target: 目标tensor [B, C, H_target, W_target]
  57. 返回:
  58. 对齐后的tensor
  59. """
  60. # 获取目标尺寸
  61. _, _, H_t, W_t = target.shape
  62. _, _, H_p, W_p = pred.shape
  63. x = pred
  64. # H维度对齐
  65. if H_p > H_t:
  66. start = (H_p - H_t) // 2
  67. x = x[:, :, start:start + H_t, :]
  68. elif H_p < H_t:
  69. diff = H_t - H_p
  70. x = F.pad(x, (0, 0, diff // 2, diff - diff // 2))
  71. # W维度对齐
  72. _, _, _, W_p2 = x.shape
  73. if W_p2 > W_t:
  74. start = (W_p2 - W_t) // 2
  75. x = x[:, :, :, start:start + W_t]
  76. elif W_p2 < W_t:
  77. diff = W_t - W_p2
  78. x = F.pad(x, (diff // 2, diff - diff // 2, 0, 0))
  79. return x
  80. def parse_metadata_from_filename(path):
  81. """
  82. 从音频文件名解析元数据
  83. 支持三种格式:
  84. 1. 4段式: {水厂}_ch{通道}_{起始时间}_{结束时间}.wav
  85. 2. 3段式: {水厂}_ch{通道}_{时间}.wav
  86. 3. 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav)
  87. 参数:
  88. path: 文件路径
  89. 返回:
  90. 元组 (plant_id, pump_id, start_time, end_time)
  91. """
  92. stem = Path(path).stem
  93. # 4段式
  94. m = re.match(r"(.+?)_ch(\d+)_(\d{14})_(\d{14})", stem)
  95. if m:
  96. return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), m.group(4)
  97. # 3段式
  98. m = re.match(r"(.+?)_ch(\d+)_(\d{14})", stem)
  99. if m:
  100. return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), ""
  101. # 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav)
  102. m = re.match(r"(\d+)_([A-Za-z0-9-]+)_(\d{14})", stem)
  103. if m:
  104. project_id = m.group(1)
  105. device_code = m.group(2)
  106. timestamp = m.group(3)
  107. # 返回 (project_id, device_code, timestamp, "")
  108. # 其中 device_code 作为 pump_id 用于阈值查找
  109. return project_id, device_code, timestamp, ""
  110. raise ValueError(f"文件名格式不符: {stem}")
  111. def load_global_scale():
  112. """
  113. 加载全局标准化参数(已过时)
  114. 注意:此函数加载全局共享的 scale 文件,仅用于向后兼容。
  115. 当前系统使用 DevicePredictor._load_scale() 按设备加载。
  116. 返回:
  117. 元组 (val_0, val_1)
  118. 如果文件不存在返回 (None, None)
  119. """
  120. if not CFG.SCALE_FILE.exists():
  121. return None, None
  122. scale = np.load(CFG.SCALE_FILE)
  123. return float(scale[0]), float(scale[1])