utils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # -*- coding: utf-8 -*-
  2. """
  3. utils.py - 部署环境工具函数
  4. ===========================
  5. 部署环境使用的工具函数。
  6. 与训练环境的utils.py功能相同,但去除了训练相关的函数。
  7. """
  8. from pathlib import Path
  9. import re
  10. import torch
  11. import torch.nn.functional as F
  12. import numpy as np
  13. from .config import CFG
  14. def ensure_dirs():
  15. """
  16. 确保部署所需目录存在
  17. 创建以下目录(如不存在):
  18. - AUDIO_DIR: 音频文件
  19. - MODEL_DIR: 模型文件
  20. - THRESHOLD_DIR: 阈值文件
  21. """
  22. for d in ['AUDIO_DIR', 'MODEL_DIR', 'THRESHOLD_DIR']:
  23. if hasattr(CFG, d):
  24. getattr(CFG, d).mkdir(parents=True, exist_ok=True)
  25. def get_device():
  26. """
  27. 获取可用的计算设备
  28. 优先级: CUDA > NPU (华为昇腾) > CPU
  29. 返回:
  30. str: "cuda", "npu" 或 "cpu"
  31. """
  32. if torch.cuda.is_available():
  33. return "cuda"
  34. # 华为昇腾 NPU
  35. try:
  36. import torch_npu # noqa: F401
  37. if torch.npu.is_available():
  38. return "npu"
  39. except ImportError:
  40. pass
  41. return "cpu"
  42. def align_to_target(pred, target):
  43. """
  44. 将预测tensor对齐到目标tensor的尺寸
  45. 处理卷积自编码器可能产生的尺寸偏差。
  46. 参数:
  47. pred: 预测tensor [B, C, H, W]
  48. target: 目标tensor [B, C, H_target, W_target]
  49. 返回:
  50. 对齐后的tensor
  51. """
  52. # 获取目标尺寸
  53. _, _, H_t, W_t = target.shape
  54. _, _, H_p, W_p = pred.shape
  55. x = pred
  56. # H维度对齐
  57. if H_p > H_t:
  58. start = (H_p - H_t) // 2
  59. x = x[:, :, start:start + H_t, :]
  60. elif H_p < H_t:
  61. diff = H_t - H_p
  62. x = F.pad(x, (0, 0, diff // 2, diff - diff // 2))
  63. # W维度对齐
  64. _, _, _, W_p2 = x.shape
  65. if W_p2 > W_t:
  66. start = (W_p2 - W_t) // 2
  67. x = x[:, :, :, start:start + W_t]
  68. elif W_p2 < W_t:
  69. diff = W_t - W_p2
  70. x = F.pad(x, (diff // 2, diff - diff // 2, 0, 0))
  71. return x
  72. def parse_metadata_from_filename(path):
  73. """
  74. 从音频文件名解析元数据
  75. 支持三种格式:
  76. 1. 4段式: {水厂}_ch{通道}_{起始时间}_{结束时间}.wav
  77. 2. 3段式: {水厂}_ch{通道}_{时间}.wav
  78. 3. 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav)
  79. 参数:
  80. path: 文件路径
  81. 返回:
  82. 元组 (plant_id, pump_id, start_time, end_time)
  83. """
  84. stem = Path(path).stem
  85. # 4段式
  86. m = re.match(r"(.+?)_ch(\d+)_(\d{14})_(\d{14})", stem)
  87. if m:
  88. return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), m.group(4)
  89. # 3段式
  90. m = re.match(r"(.+?)_ch(\d+)_(\d{14})", stem)
  91. if m:
  92. return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), ""
  93. # 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav)
  94. m = re.match(r"(\d+)_([A-Za-z0-9-]+)_(\d{14})", stem)
  95. if m:
  96. project_id = m.group(1)
  97. device_code = m.group(2)
  98. timestamp = m.group(3)
  99. # 返回 (project_id, device_code, timestamp, "")
  100. # 其中 device_code 作为 pump_id 用于阈值查找
  101. return project_id, device_code, timestamp, ""
  102. raise ValueError(f"文件名格式不符: {stem}")
  103. def load_global_scale():
  104. """
  105. 加载全局标准化参数(已过时)
  106. 注意:此函数加载全局共享的 scale 文件,仅用于向后兼容。
  107. 当前系统使用 DevicePredictor._load_scale() 按设备加载。
  108. 返回:
  109. 元组 (val_0, val_1)
  110. 如果文件不存在返回 (None, None)
  111. """
  112. if not CFG.SCALE_FILE.exists():
  113. return None, None
  114. scale = np.load(CFG.SCALE_FILE)
  115. return float(scale[0]), float(scale[1])