utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # -*- coding: utf-8 -*-
  2. """
  3. utils.py - 部署环境工具函数
  4. ===========================
  5. 部署环境使用的工具函数。
  6. 与训练环境的utils.py功能相同,但去除了训练相关的函数。
  7. """
  8. from pathlib import Path
  9. import re
  10. import torch
  11. import torch.nn.functional as F
  12. import numpy as np
  13. from .config import CFG
  14. def ensure_dirs():
  15. """
  16. 确保部署所需目录存在
  17. 创建以下目录(如不存在):
  18. - AUDIO_DIR: 音频文件
  19. - MODEL_DIR: 模型文件
  20. - THRESHOLD_DIR: 阈值文件
  21. """
  22. for d in ['AUDIO_DIR', 'MODEL_DIR', 'THRESHOLD_DIR']:
  23. if hasattr(CFG, d):
  24. getattr(CFG, d).mkdir(parents=True, exist_ok=True)
  25. def get_device():
  26. """
  27. 获取可用的计算设备
  28. 返回:
  29. str: "cuda" 或 "cpu"
  30. """
  31. return "cuda" if torch.cuda.is_available() else "cpu"
  32. def align_to_target(pred, target):
  33. """
  34. 将预测tensor对齐到目标tensor的尺寸
  35. 处理卷积自编码器可能产生的尺寸偏差。
  36. 参数:
  37. pred: 预测tensor [B, C, H, W]
  38. target: 目标tensor [B, C, H_target, W_target]
  39. 返回:
  40. 对齐后的tensor
  41. """
  42. # 获取目标尺寸
  43. _, _, H_t, W_t = target.shape
  44. _, _, H_p, W_p = pred.shape
  45. x = pred
  46. # H维度对齐
  47. if H_p > H_t:
  48. start = (H_p - H_t) // 2
  49. x = x[:, :, start:start + H_t, :]
  50. elif H_p < H_t:
  51. diff = H_t - H_p
  52. x = F.pad(x, (0, 0, diff // 2, diff - diff // 2))
  53. # W维度对齐
  54. _, _, _, W_p2 = x.shape
  55. if W_p2 > W_t:
  56. start = (W_p2 - W_t) // 2
  57. x = x[:, :, :, start:start + W_t]
  58. elif W_p2 < W_t:
  59. diff = W_t - W_p2
  60. x = F.pad(x, (diff // 2, diff - diff // 2, 0, 0))
  61. return x
  62. def parse_metadata_from_filename(path):
  63. """
  64. 从音频文件名解析元数据
  65. 支持三种格式:
  66. 1. 4段式: {水厂}_ch{通道}_{起始时间}_{结束时间}.wav
  67. 2. 3段式: {水厂}_ch{通道}_{时间}.wav
  68. 3. 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav)
  69. 参数:
  70. path: 文件路径
  71. 返回:
  72. 元组 (plant_id, pump_id, start_time, end_time)
  73. """
  74. stem = Path(path).stem
  75. # 4段式
  76. m = re.match(r"(.+?)_ch(\d+)_(\d{14})_(\d{14})", stem)
  77. if m:
  78. return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), m.group(4)
  79. # 3段式
  80. m = re.match(r"(.+?)_ch(\d+)_(\d{14})", stem)
  81. if m:
  82. return m.group(1).strip(), f"ch{m.group(2)}", m.group(3), ""
  83. # 新格式: {project_id}_{device_code}_{时间}.wav (如 1450_LT-2_20260115103754.wav)
  84. m = re.match(r"(\d+)_([A-Za-z0-9-]+)_(\d{14})", stem)
  85. if m:
  86. project_id = m.group(1)
  87. device_code = m.group(2)
  88. timestamp = m.group(3)
  89. # 返回 (project_id, device_code, timestamp, "")
  90. # 其中 device_code 作为 pump_id 用于阈值查找
  91. return project_id, device_code, timestamp, ""
  92. raise ValueError(f"文件名格式不符: {stem}")
  93. def load_global_scale():
  94. """
  95. 加载全局标准化参数(已过时)
  96. 注意:此函数加载全局共享的 scale 文件,仅用于向后兼容。
  97. 当前系统使用 DevicePredictor._load_scale() 按设备加载。
  98. 返回:
  99. 元组 (val_0, val_1)
  100. 如果文件不存在返回 (None, None)
  101. """
  102. if not CFG.SCALE_FILE.exists():
  103. return None, None
  104. scale = np.load(CFG.SCALE_FILE)
  105. return float(scale[0]), float(scale[1])