multi_model_predictor.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # -*- coding: utf-8 -*-
  2. """
  3. multi_model_predictor.py - 多设备多模型预测器
  4. =============================================
  5. 支持每个设备(device_code)加载独立的模型目录。
  6. 支持模型热加载(检测文件变化后自动重载)。
  7. 使用示例:
  8. predictor = MultiModelPredictor()
  9. predictor.register_device("LT-2", "LT-2")
  10. predictor.register_device("LT-5", "LT-5")
  11. """
  12. import os
  13. import time
  14. import logging
  15. from pathlib import Path
  16. from typing import Dict, Optional, Tuple
  17. import numpy as np
  18. import torch
  19. from .config import CFG, DeployConfig
  20. from .model_def import ConvAutoencoder
  21. from .utils import get_device
# --- BM1684X NPU inference support ---
# This import is active: DevicePredictor prefers a .bmodel engine when one
# exists and the BM1684X runtime is available, and falls back to PyTorch
# otherwise (generate .bmodel files with convert_to_bmodel.py).
  24. from .bm1684x_engine import BM1684XEngine, is_bm1684x_available
  25. logger = logging.getLogger('MultiModelPredictor')
  26. class DevicePredictor:
  27. """
  28. 单设备预测器
  29. 封装一个设备的模型、Min-Max 标准化参数和阈值。
  30. """
  31. def __init__(self, device_code: str, model_subdir: str):
  32. self.device_code = device_code
  33. self.model_subdir = model_subdir
  34. # 模型目录路径
  35. self.model_dir = CFG.MODEL_ROOT / model_subdir
  36. self.model_path = self.model_dir / "ae_model.pth"
  37. self.scale_path = self.model_dir / "global_scale.npy"
  38. self.threshold_dir = self.model_dir / "thresholds"
  39. # 加载资源
  40. self.torch_device = get_device()
  41. self.model = self._load_model()
  42. # Min-Max 参数 (min, max)
  43. self.global_min, self.global_max = self._load_scale()
  44. # 阈值(标量)
  45. self.threshold = self._load_threshold()
  46. # --- BM1684X NPU 推理适配(预埋,暂未启用) ---
  47. # 如果 .bmodel 文件存在且 BM1684X 硬件可用,优先使用 NPU 推理
  48. # 启用后 self.bm_engine 不为 None,推理时调用 bm_engine.predict() 而非 self.model()
  49. self.bm_engine = None
  50. self.bm_engine = self._load_bmodel()
  51. # 记录文件 mtime(用于热加载检测)
  52. self._model_mtime = self._get_mtime(self.model_path)
  53. self._scale_mtime = self._get_mtime(self.scale_path)
  54. engine_type = "BM1684X BModel" if self.bm_engine else "PyTorch (.pth)"
  55. logger.info(f"设备 {device_code} 模型加载完成 | 引擎: {engine_type} | 目录: {model_subdir} | "
  56. f"阈值: {self.threshold:.6f}")
  57. def _get_mtime(self, path: Path) -> float:
  58. # 获取文件修改时间,不存在返回 0
  59. try:
  60. return os.path.getmtime(path)
  61. except OSError:
  62. return 0.0
  63. def has_files_changed(self) -> bool:
  64. # 检查模型或标准化参数文件是否有更新
  65. new_model_mtime = self._get_mtime(self.model_path)
  66. new_scale_mtime = self._get_mtime(self.scale_path)
  67. return (new_model_mtime != self._model_mtime or
  68. new_scale_mtime != self._scale_mtime)
  69. def _load_model(self) -> ConvAutoencoder:
  70. if not self.model_path.exists():
  71. raise FileNotFoundError(f"模型不存在: {self.model_path}")
  72. model = ConvAutoencoder().to(self.torch_device)
  73. state = torch.load(self.model_path, map_location=self.torch_device)
  74. model.load_state_dict(state)
  75. model.eval()
  76. return model
  77. # --- BM1684X NPU 推理适配(预埋,暂未启用) ---
  78. # 取消注释以启用 BModel 加载逻辑
  79. def _load_bmodel(self):
  80. """尝试加载 BModel,成功返回 BM1684XEngine 实例,否则返回 None"""
  81. bmodel_path = self.model_dir / "ae_model.bmodel"
  82. if not bmodel_path.exists():
  83. return None
  84. if not is_bm1684x_available():
  85. logger.info(f"设备 {self.device_code}: 存在 .bmodel 但 BM1684X 不可用,回退 PyTorch")
  86. return None
  87. try:
  88. engine = BM1684XEngine(str(bmodel_path))
  89. logger.info(f"设备 {self.device_code}: 已加载 BM1684X BModel 推理引擎")
  90. return engine
  91. except Exception as e:
  92. logger.warning(f"设备 {self.device_code}: BModel 加载失败,回退 PyTorch | {e}")
  93. return None
  94. def _load_scale(self) -> Tuple[float, float]:
  95. # 加载 Min-Max 标准化参数 [min, max]
  96. if not self.scale_path.exists():
  97. raise FileNotFoundError(f"标准化参数不存在: {self.scale_path}")
  98. scale = np.load(self.scale_path)
  99. return float(scale[0]), float(scale[1])
  100. def _load_threshold(self) -> float:
  101. """
  102. 加载阈值文件
  103. Returns:
  104. overall_threshold
  105. """
  106. # 按 device_code 查找
  107. threshold_file = self.threshold_dir / f"threshold_{self.device_code}.npy"
  108. if not threshold_file.exists():
  109. # 尝试 default
  110. threshold_file = self.threshold_dir / "threshold_default.npy"
  111. if threshold_file.exists():
  112. data = np.load(threshold_file)
  113. # 兼容标量和数组格式
  114. return float(data.flat[0])
  115. logger.warning(f"设备 {self.device_code} 无阈值文件,使用默认值 0.01")
  116. return 0.01
  117. class MultiModelPredictor:
  118. """
  119. 多设备多模型预测器
  120. 支持:
  121. - 每设备独立模型、标准化参数和阈值
  122. - 模型热加载(检测文件更新后自动重载)
  123. - 冷启动设备定期重试
  124. """
  125. # 热加载检查间隔(秒)
  126. HOT_RELOAD_INTERVAL = 60
  127. # 失败设备重试间隔(秒)
  128. FAILED_RETRY_INTERVAL = 300
  129. def __init__(self):
  130. # device_code -> model_subdir 映射
  131. self.device_model_map: Dict[str, str] = {}
  132. # device_code -> DevicePredictor 实例(懒加载)
  133. self.predictors: Dict[str, DevicePredictor] = {}
  134. # 加载失败的设备记录 {device_code: 失败时间}
  135. self._failed_devices: Dict[str, float] = {}
  136. # 上次热加载检查时间
  137. self._last_reload_check: float = 0.0
  138. logger.info("MultiModelPredictor 初始化完成")
  139. def register_device(self, device_code: str, model_subdir: str):
  140. self.device_model_map[device_code] = model_subdir
  141. logger.debug(f"注册设备: {device_code} -> models/{model_subdir}/")
  142. def _check_hot_reload(self):
  143. """
  144. 检查所有已加载设备的模型文件是否有更新
  145. 检查频率由 HOT_RELOAD_INTERVAL 控制(默认60秒),避免频繁 stat 调用。
  146. 如果检测到文件变化,重新创建 DevicePredictor 实例替换旧的。
  147. 同时对失败设备定期重试加载。
  148. """
  149. now = time.time()
  150. if now - self._last_reload_check < self.HOT_RELOAD_INTERVAL:
  151. return
  152. self._last_reload_check = now
  153. # 1. 检查已加载设备的模型是否更新
  154. for device_code in list(self.predictors.keys()):
  155. predictor = self.predictors[device_code]
  156. if predictor.has_files_changed():
  157. logger.info(f"检测到模型文件更新: {device_code},执行热加载")
  158. model_subdir = self.device_model_map.get(device_code, device_code)
  159. try:
  160. new_predictor = DevicePredictor(device_code, model_subdir)
  161. self.predictors[device_code] = new_predictor
  162. logger.info(f"热加载成功: {device_code}")
  163. except Exception as e:
  164. logger.error(f"热加载失败: {device_code} | {e}")
  165. # 保留旧的 predictor 继续使用
  166. # 2. 对失败设备定期重试
  167. for device_code in list(self._failed_devices.keys()):
  168. fail_time = self._failed_devices[device_code]
  169. if now - fail_time > self.FAILED_RETRY_INTERVAL:
  170. model_subdir = self.device_model_map.get(device_code, device_code)
  171. try:
  172. predictor = DevicePredictor(device_code, model_subdir)
  173. self.predictors[device_code] = predictor
  174. del self._failed_devices[device_code]
  175. logger.info(f"失败设备重试成功: {device_code}")
  176. except Exception:
  177. # 重试仍然失败,更新时间戳
  178. self._failed_devices[device_code] = now
  179. def get_predictor(self, device_code: str) -> Optional[DevicePredictor]:
  180. # 先执行热加载检查
  181. self._check_hot_reload()
  182. # 已加载
  183. if device_code in self.predictors:
  184. return self.predictors[device_code]
  185. # 已失败且未到重试时间
  186. if device_code in self._failed_devices:
  187. return None
  188. # 首次懒加载
  189. model_subdir = self.device_model_map.get(device_code)
  190. if not model_subdir:
  191. model_subdir = device_code
  192. try:
  193. predictor = DevicePredictor(device_code, model_subdir)
  194. self.predictors[device_code] = predictor
  195. return predictor
  196. except Exception as e:
  197. logger.error(f"加载设备 {device_code} 模型失败: {e}")
  198. self._failed_devices[device_code] = time.time()
  199. return None
  200. def get_threshold(self, device_code: str) -> Optional[float]:
  201. predictor = self.get_predictor(device_code)
  202. if predictor:
  203. return predictor.threshold
  204. return None
  205. def get_scale(self, device_code: str) -> Tuple[Optional[float], Optional[float]]:
  206. # 获取 Z-score 参数 (mean, std)
  207. predictor = self.get_predictor(device_code)
  208. if predictor:
  209. return predictor.global_mean, predictor.global_std
  210. return None, None
  211. def get_model(self, device_code: str) -> Optional[ConvAutoencoder]:
  212. predictor = self.get_predictor(device_code)
  213. if predictor:
  214. return predictor.model
  215. return None
  216. def reload_device(self, device_code: str) -> bool:
  217. """
  218. 手动触发指定设备的模型重载
  219. 用于外部调用(如 API 接口更新模型后通知重载)
  220. Returns:
  221. bool: 是否重载成功
  222. """
  223. model_subdir = self.device_model_map.get(device_code, device_code)
  224. try:
  225. new_predictor = DevicePredictor(device_code, model_subdir)
  226. self.predictors[device_code] = new_predictor
  227. # 清除失败标记
  228. self._failed_devices.pop(device_code, None)
  229. logger.info(f"手动重载成功: {device_code}")
  230. return True
  231. except Exception as e:
  232. logger.error(f"手动重载失败: {device_code} | {e}")
  233. return False
  234. @property
  235. def registered_devices(self) -> list:
  236. return list(self.device_model_map.keys())
  237. @property
  238. def loaded_devices(self) -> list:
  239. return list(self.predictors.keys())