#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
anomaly_classifier.py
---------------------
Rule-based classifier used to determine the anomaly type.

Anomaly type codes (level_two):
- 6: unclassified / general anomaly
- 7: bearing problem
- 8: cavitation problem
- 9: loosening / resonance
- 10: impeller problem
- 11: valve / impact
"""
  15. import numpy as np
  16. import logging
  17. logger = logging.getLogger('AnomalyClassifier')
  18. # 异常类型与编码映射
  19. ANOMALY_TYPES = {
  20. 'unknown': {'code': 6, 'name': '未分类异常'},
  21. 'bearing': {'code': 7, 'name': '轴承问题'},
  22. 'cavitation': {'code': 8, 'name': '气蚀问题'},
  23. 'loosening': {'code': 9, 'name': '松动/共振'},
  24. 'impeller': {'code': 10, 'name': '叶轮问题'},
  25. 'valve': {'code': 11, 'name': '阀件/冲击'},
  26. }
  27. from pathlib import Path
  28. # 基线文件路径
  29. BASELINE_FILE = Path(__file__).parent / "models" / "classifier_baseline.npy"
  30. # 默认基线(泵正常运行时的典型特征值)
  31. DEFAULT_BASELINE = {
  32. 'rms': 0.02,
  33. 'zcr': 0.05,
  34. 'energy_std': 0.3,
  35. 'spectral_centroid': 2000.0,
  36. 'spectral_bandwidth': 1500.0,
  37. 'spectral_flatness': 0.01,
  38. 'low_energy': 0.1,
  39. 'mid_energy': 0.05,
  40. 'high_energy': 0.02,
  41. 'has_periodic': False,
  42. }
  43. class AnomalyClassifier:
  44. """
  45. 规则分类器
  46. 基于音频特征判断异常类型
  47. """
  48. def __init__(self):
  49. # 正常基线特征
  50. self.normal_baseline = None
  51. # 自动加载基线
  52. self._load_baseline()
  53. def _load_baseline(self):
  54. """
  55. 从文件加载基线,如不存在则使用默认值
  56. """
  57. if BASELINE_FILE.exists():
  58. try:
  59. data = np.load(BASELINE_FILE, allow_pickle=True).item()
  60. self.normal_baseline = data
  61. logger.info(f"已加载分类器基线: {BASELINE_FILE.name}")
  62. except Exception as e:
  63. logger.warning(f"加载基线失败: {e},使用默认基线")
  64. self.normal_baseline = DEFAULT_BASELINE.copy()
  65. else:
  66. # 使用默认基线
  67. self.normal_baseline = DEFAULT_BASELINE.copy()
  68. logger.info("使用默认分类器基线")
  69. def set_baseline(self, baseline_features: dict):
  70. """
  71. 设置正常基线
  72. 参数:
  73. baseline_features: 正常状态的平均特征
  74. """
  75. self.normal_baseline = baseline_features
  76. def save_baseline(self, baseline_features: dict = None):
  77. """
  78. 保存基线到文件
  79. 参数:
  80. baseline_features: 基线特征字典,None则保存当前基线
  81. """
  82. if baseline_features is not None:
  83. self.normal_baseline = baseline_features
  84. if self.normal_baseline is None:
  85. logger.warning("无基线可保存")
  86. return False
  87. try:
  88. BASELINE_FILE.parent.mkdir(parents=True, exist_ok=True)
  89. np.save(BASELINE_FILE, self.normal_baseline)
  90. logger.info(f"已保存分类器基线: {BASELINE_FILE.name}")
  91. return True
  92. except Exception as e:
  93. logger.error(f"保存基线失败: {e}")
  94. return False
  95. def extract_features(self, y: np.ndarray, sr: int = 16000) -> dict:
  96. """
  97. 提取音频特征
  98. 参数:
  99. y: 音频信号
  100. sr: 采样率
  101. 返回:
  102. 特征字典
  103. """
  104. try:
  105. import librosa
  106. except ImportError:
  107. logger.warning("librosa未安装,无法提取特征")
  108. return {}
  109. # 时域特征
  110. rms = float(np.sqrt(np.mean(y**2)))
  111. zcr = float(np.mean(librosa.feature.zero_crossing_rate(y)))
  112. # 能量波动
  113. frame_length = int(0.025 * sr)
  114. hop = int(0.010 * sr)
  115. frames = librosa.util.frame(y, frame_length=frame_length, hop_length=hop)
  116. frame_energies = np.sum(frames**2, axis=0)
  117. energy_std = float(np.std(frame_energies) / (np.mean(frame_energies) + 1e-10))
  118. # 频域特征
  119. spectral_centroid = float(np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)))
  120. spectral_bandwidth = float(np.mean(librosa.feature.spectral_bandwidth(y=y, sr=sr)))
  121. spectral_flatness = float(np.mean(librosa.feature.spectral_flatness(y=y)))
  122. # 频段能量
  123. D = np.abs(librosa.stft(y, n_fft=1024, hop_length=256))
  124. freqs = librosa.fft_frequencies(sr=sr, n_fft=1024)
  125. low_mask = freqs < 1000
  126. mid_mask = (freqs >= 1000) & (freqs < 3000)
  127. high_mask = freqs >= 3000
  128. low_energy = float(np.mean(D[low_mask, :]))
  129. mid_energy = float(np.mean(D[mid_mask, :]))
  130. high_energy = float(np.mean(D[high_mask, :]))
  131. # 周期性检测(简化版)
  132. autocorr = np.correlate(y[:min(len(y), sr)], y[:min(len(y), sr)], mode='full')
  133. autocorr = autocorr[len(autocorr)//2:]
  134. autocorr = autocorr / (autocorr[0] + 1e-10)
  135. min_lag = int(0.01 * sr)
  136. max_lag = int(0.5 * sr)
  137. search_range = autocorr[min_lag:max_lag]
  138. has_periodic = False
  139. if len(search_range) > 0:
  140. peak_value = np.max(search_range)
  141. has_periodic = peak_value > 0.3
  142. return {
  143. 'rms': rms,
  144. 'zcr': zcr,
  145. 'energy_std': energy_std,
  146. 'spectral_centroid': spectral_centroid,
  147. 'spectral_bandwidth': spectral_bandwidth,
  148. 'spectral_flatness': spectral_flatness,
  149. 'low_energy': low_energy,
  150. 'mid_energy': mid_energy,
  151. 'high_energy': high_energy,
  152. 'has_periodic': has_periodic,
  153. }
  154. def classify(self, current_features: dict) -> tuple:
  155. """
  156. 分类异常类型
  157. 参数:
  158. current_features: 当前音频特征
  159. 返回:
  160. (level_two编码, 异常类型名称, 置信度)
  161. """
  162. if self.normal_baseline is None:
  163. # 无基线,返回未分类
  164. return 6, '未分类异常', 0.0
  165. # 计算变化率
  166. def safe_change(curr, base):
  167. if base == 0:
  168. return 0.0
  169. return (curr - base) / base
  170. changes = {
  171. 'high_freq': safe_change(current_features.get('high_energy', 0), self.normal_baseline.get('high_energy', 1)),
  172. 'low_freq': safe_change(current_features.get('low_energy', 0), self.normal_baseline.get('low_energy', 1)),
  173. 'zcr': safe_change(current_features.get('zcr', 0), self.normal_baseline.get('zcr', 1)),
  174. 'centroid': safe_change(current_features.get('spectral_centroid', 0), self.normal_baseline.get('spectral_centroid', 1)),
  175. 'bandwidth': safe_change(current_features.get('spectral_bandwidth', 0), self.normal_baseline.get('spectral_bandwidth', 1)),
  176. 'energy_std': safe_change(current_features.get('energy_std', 0), self.normal_baseline.get('energy_std', 1)),
  177. 'flatness': safe_change(current_features.get('spectral_flatness', 0), self.normal_baseline.get('spectral_flatness', 1)),
  178. 'rms': safe_change(current_features.get('rms', 0), self.normal_baseline.get('rms', 1)),
  179. }
  180. has_periodic = current_features.get('has_periodic', False)
  181. # 规则判断
  182. scores = {
  183. 'bearing': 0.0,
  184. 'cavitation': 0.0,
  185. 'loosening': 0.0,
  186. 'impeller': 0.0,
  187. 'valve': 0.0,
  188. }
  189. # 轴承问题:高频增加、过零率变高、频谱质心上移
  190. if changes['high_freq'] > 0.3:
  191. scores['bearing'] += 0.4
  192. if changes['zcr'] > 0.2:
  193. scores['bearing'] += 0.3
  194. if changes['centroid'] > 0.15:
  195. scores['bearing'] += 0.3
  196. # 气蚀问题:噪声增加、频谱变宽、能量波动大
  197. if changes['bandwidth'] > 0.25:
  198. scores['cavitation'] += 0.35
  199. if changes['energy_std'] > 0.4:
  200. scores['cavitation'] += 0.35
  201. if changes['flatness'] > 0.2:
  202. scores['cavitation'] += 0.3
  203. # 松动/共振:低频增加、周期性冲击
  204. if changes['low_freq'] > 0.35:
  205. scores['loosening'] += 0.5
  206. if has_periodic:
  207. scores['loosening'] += 0.5
  208. # 叶轮问题:能量变化、周期性
  209. if changes['rms'] > 0.3:
  210. scores['impeller'] += 0.35
  211. if abs(changes['centroid']) > 0.25:
  212. scores['impeller'] += 0.35
  213. if has_periodic:
  214. scores['impeller'] += 0.3
  215. # 阀件/冲击:能量波动大、低频有增加
  216. if changes['energy_std'] > 0.5:
  217. scores['valve'] += 0.6
  218. if changes['low_freq'] > 0.2:
  219. scores['valve'] += 0.4
  220. # 选择最高分
  221. if max(scores.values()) < 0.3:
  222. return 6, '未分类异常', 0.0
  223. best_type = max(scores, key=scores.get)
  224. confidence = min(scores[best_type], 1.0)
  225. code = ANOMALY_TYPES[best_type]['code']
  226. name = ANOMALY_TYPES[best_type]['name']
  227. return code, name, confidence
  228. def classify_audio(self, y: np.ndarray, sr: int = 16000) -> tuple:
  229. """
  230. 直接从音频分类
  231. 参数:
  232. y: 音频信号
  233. sr: 采样率
  234. 返回:
  235. (level_two编码, 异常类型名称, 置信度)
  236. """
  237. features = self.extract_features(y, sr)
  238. return self.classify(features)
  239. # 全局分类器实例
  240. _classifier = None
  241. def get_classifier() -> AnomalyClassifier:
  242. """获取全局分类器实例"""
  243. global _classifier
  244. if _classifier is None:
  245. _classifier = AnomalyClassifier()
  246. return _classifier
  247. def classify_anomaly(y: np.ndarray, sr: int = 16000) -> tuple:
  248. """
  249. 便捷函数:分类异常类型
  250. 参数:
  251. y: 音频信号
  252. sr: 采样率
  253. 返回:
  254. (level_two编码, 异常类型名称, 置信度)
  255. """
  256. classifier = get_classifier()
  257. return classifier.classify_audio(y, sr)