multi_model_predictor.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # -*- coding: utf-8 -*-
  2. """
  3. multi_model_predictor.py - 多设备多模型预测器
  4. =============================================
  5. 支持每个设备(device_code)加载独立的模型目录。
  6. 支持模型热加载(检测文件变化后自动重载)。
  7. 使用示例:
  8. predictor = MultiModelPredictor()
  9. predictor.register_device("LT-2", "LT-2")
  10. predictor.register_device("LT-5", "LT-5")
  11. """
  12. import os
  13. import time
  14. import logging
  15. from pathlib import Path
  16. from typing import Dict, Optional, Tuple
  17. import numpy as np
  18. import torch
  19. from .config import CFG, DeployConfig
  20. from .model_def import ConvAutoencoder
  21. from .utils import get_device
  22. logger = logging.getLogger('MultiModelPredictor')
  23. class DevicePredictor:
  24. """
  25. 单设备预测器
  26. 封装一个设备的模型、Min-Max 标准化参数和阈值。
  27. """
  28. def __init__(self, device_code: str, model_subdir: str):
  29. self.device_code = device_code
  30. self.model_subdir = model_subdir
  31. # 模型目录路径
  32. self.model_dir = CFG.MODEL_ROOT / model_subdir
  33. self.model_path = self.model_dir / "ae_model.pth"
  34. self.scale_path = self.model_dir / "global_scale.npy"
  35. self.threshold_dir = self.model_dir / "thresholds"
  36. # 加载资源
  37. self.torch_device = get_device()
  38. self.model = self._load_model()
  39. # Min-Max 参数 (min, max)
  40. self.global_min, self.global_max = self._load_scale()
  41. # 阈值(标量)
  42. self.threshold = self._load_threshold()
  43. # 记录文件 mtime(用于热加载检测)
  44. self._model_mtime = self._get_mtime(self.model_path)
  45. self._scale_mtime = self._get_mtime(self.scale_path)
  46. logger.info(f"设备 {device_code} 模型加载完成 | 目录: {model_subdir} | "
  47. f"阈值: {self.threshold:.6f}")
  48. def _get_mtime(self, path: Path) -> float:
  49. # 获取文件修改时间,不存在返回 0
  50. try:
  51. return os.path.getmtime(path)
  52. except OSError:
  53. return 0.0
  54. def has_files_changed(self) -> bool:
  55. # 检查模型或标准化参数文件是否有更新
  56. new_model_mtime = self._get_mtime(self.model_path)
  57. new_scale_mtime = self._get_mtime(self.scale_path)
  58. return (new_model_mtime != self._model_mtime or
  59. new_scale_mtime != self._scale_mtime)
  60. def _load_model(self) -> ConvAutoencoder:
  61. if not self.model_path.exists():
  62. raise FileNotFoundError(f"模型不存在: {self.model_path}")
  63. model = ConvAutoencoder().to(self.torch_device)
  64. state = torch.load(self.model_path, map_location=self.torch_device)
  65. model.load_state_dict(state)
  66. model.eval()
  67. return model
  68. def _load_scale(self) -> Tuple[float, float]:
  69. # 加载 Min-Max 标准化参数 [min, max]
  70. if not self.scale_path.exists():
  71. raise FileNotFoundError(f"标准化参数不存在: {self.scale_path}")
  72. scale = np.load(self.scale_path)
  73. return float(scale[0]), float(scale[1])
  74. def _load_threshold(self) -> float:
  75. """
  76. 加载阈值文件
  77. Returns:
  78. overall_threshold
  79. """
  80. # 按 device_code 查找
  81. threshold_file = self.threshold_dir / f"threshold_{self.device_code}.npy"
  82. if not threshold_file.exists():
  83. # 尝试 default
  84. threshold_file = self.threshold_dir / "threshold_default.npy"
  85. if threshold_file.exists():
  86. data = np.load(threshold_file)
  87. # 兼容标量和数组格式
  88. return float(data.flat[0])
  89. logger.warning(f"设备 {self.device_code} 无阈值文件,使用默认值 0.01")
  90. return 0.01
  91. class MultiModelPredictor:
  92. """
  93. 多设备多模型预测器
  94. 支持:
  95. - 每设备独立模型、标准化参数和阈值
  96. - 模型热加载(检测文件更新后自动重载)
  97. - 冷启动设备定期重试
  98. """
  99. # 热加载检查间隔(秒)
  100. HOT_RELOAD_INTERVAL = 60
  101. # 失败设备重试间隔(秒)
  102. FAILED_RETRY_INTERVAL = 300
  103. def __init__(self):
  104. # device_code -> model_subdir 映射
  105. self.device_model_map: Dict[str, str] = {}
  106. # device_code -> DevicePredictor 实例(懒加载)
  107. self.predictors: Dict[str, DevicePredictor] = {}
  108. # 加载失败的设备记录 {device_code: 失败时间}
  109. self._failed_devices: Dict[str, float] = {}
  110. # 上次热加载检查时间
  111. self._last_reload_check: float = 0.0
  112. logger.info("MultiModelPredictor 初始化完成")
  113. def register_device(self, device_code: str, model_subdir: str):
  114. self.device_model_map[device_code] = model_subdir
  115. logger.debug(f"注册设备: {device_code} -> models/{model_subdir}/")
  116. def _check_hot_reload(self):
  117. """
  118. 检查所有已加载设备的模型文件是否有更新
  119. 检查频率由 HOT_RELOAD_INTERVAL 控制(默认60秒),避免频繁 stat 调用。
  120. 如果检测到文件变化,重新创建 DevicePredictor 实例替换旧的。
  121. 同时对失败设备定期重试加载。
  122. """
  123. now = time.time()
  124. if now - self._last_reload_check < self.HOT_RELOAD_INTERVAL:
  125. return
  126. self._last_reload_check = now
  127. # 1. 检查已加载设备的模型是否更新
  128. for device_code in list(self.predictors.keys()):
  129. predictor = self.predictors[device_code]
  130. if predictor.has_files_changed():
  131. logger.info(f"检测到模型文件更新: {device_code},执行热加载")
  132. model_subdir = self.device_model_map.get(device_code, device_code)
  133. try:
  134. new_predictor = DevicePredictor(device_code, model_subdir)
  135. self.predictors[device_code] = new_predictor
  136. logger.info(f"热加载成功: {device_code}")
  137. except Exception as e:
  138. logger.error(f"热加载失败: {device_code} | {e}")
  139. # 保留旧的 predictor 继续使用
  140. # 2. 对失败设备定期重试
  141. for device_code in list(self._failed_devices.keys()):
  142. fail_time = self._failed_devices[device_code]
  143. if now - fail_time > self.FAILED_RETRY_INTERVAL:
  144. model_subdir = self.device_model_map.get(device_code, device_code)
  145. try:
  146. predictor = DevicePredictor(device_code, model_subdir)
  147. self.predictors[device_code] = predictor
  148. del self._failed_devices[device_code]
  149. logger.info(f"失败设备重试成功: {device_code}")
  150. except Exception:
  151. # 重试仍然失败,更新时间戳
  152. self._failed_devices[device_code] = now
  153. def get_predictor(self, device_code: str) -> Optional[DevicePredictor]:
  154. # 先执行热加载检查
  155. self._check_hot_reload()
  156. # 已加载
  157. if device_code in self.predictors:
  158. return self.predictors[device_code]
  159. # 已失败且未到重试时间
  160. if device_code in self._failed_devices:
  161. return None
  162. # 首次懒加载
  163. model_subdir = self.device_model_map.get(device_code)
  164. if not model_subdir:
  165. model_subdir = device_code
  166. try:
  167. predictor = DevicePredictor(device_code, model_subdir)
  168. self.predictors[device_code] = predictor
  169. return predictor
  170. except Exception as e:
  171. logger.error(f"加载设备 {device_code} 模型失败: {e}")
  172. self._failed_devices[device_code] = time.time()
  173. return None
  174. def get_threshold(self, device_code: str) -> Optional[float]:
  175. predictor = self.get_predictor(device_code)
  176. if predictor:
  177. return predictor.threshold
  178. return None
  179. def get_scale(self, device_code: str) -> Tuple[Optional[float], Optional[float]]:
  180. # 获取 Z-score 参数 (mean, std)
  181. predictor = self.get_predictor(device_code)
  182. if predictor:
  183. return predictor.global_mean, predictor.global_std
  184. return None, None
  185. def get_model(self, device_code: str) -> Optional[ConvAutoencoder]:
  186. predictor = self.get_predictor(device_code)
  187. if predictor:
  188. return predictor.model
  189. return None
  190. def reload_device(self, device_code: str) -> bool:
  191. """
  192. 手动触发指定设备的模型重载
  193. 用于外部调用(如 API 接口更新模型后通知重载)
  194. Returns:
  195. bool: 是否重载成功
  196. """
  197. model_subdir = self.device_model_map.get(device_code, device_code)
  198. try:
  199. new_predictor = DevicePredictor(device_code, model_subdir)
  200. self.predictors[device_code] = new_predictor
  201. # 清除失败标记
  202. self._failed_devices.pop(device_code, None)
  203. logger.info(f"手动重载成功: {device_code}")
  204. return True
  205. except Exception as e:
  206. logger.error(f"手动重载失败: {device_code} | {e}")
  207. return False
  208. @property
  209. def registered_devices(self) -> list:
  210. return list(self.device_model_map.keys())
  211. @property
  212. def loaded_devices(self) -> list:
  213. return list(self.predictors.keys())