# -*- coding: utf-8 -*-
"""
multi_model_predictor.py - multi-device, multi-model predictor
=============================================

Each device (device_code) loads its model from its own model directory.
Models are hot-reloaded: when the files on disk change they are reloaded
automatically.

Usage example:
    predictor = MultiModelPredictor()
    predictor.register_device("LT-2", "LT-2")
    predictor.register_device("LT-5", "LT-5")
"""

import os
import time
import logging
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import torch

from .config import CFG, DeployConfig
from .model_def import ConvAutoencoder
from .utils import get_device

logger = logging.getLogger('MultiModelPredictor')


class DevicePredictor:
    """
    Predictor for a single device.

    Bundles one device's model, its Min-Max scaling parameters and its
    anomaly threshold. All resources are loaded eagerly in ``__init__``.
    """

    def __init__(self, device_code: str, model_subdir: str):
        """
        Load every resource of one device.

        Args:
            device_code: Identifier of the device; also used to pick the
                threshold file ``threshold_<device_code>.npy``.
            model_subdir: Directory name below ``CFG.MODEL_ROOT`` that holds
                the device's files.

        Raises:
            FileNotFoundError: If the model file or the scale file is missing.
        """
        self.device_code = device_code
        self.model_subdir = model_subdir

        # Paths inside the model directory
        self.model_dir = CFG.MODEL_ROOT / model_subdir
        self.model_path = self.model_dir / "ae_model.pth"
        self.scale_path = self.model_dir / "global_scale.npy"
        self.threshold_dir = self.model_dir / "thresholds"

        # Record the mtimes BEFORE loading: if a file is replaced while it is
        # being loaded, the stale timestamp makes the next hot-reload check
        # pick the new file up instead of missing the update.
        self._model_mtime = self._get_mtime(self.model_path)
        self._scale_mtime = self._get_mtime(self.scale_path)

        # Load resources
        self.torch_device = get_device()
        self.model = self._load_model()

        # Min-Max parameters (min, max)
        self.global_min, self.global_max = self._load_scale()

        # Threshold (scalar)
        self.threshold = self._load_threshold()

        logger.info(f"设备 {device_code} 模型加载完成 | 目录: {model_subdir} | "
                    f"阈值: {self.threshold:.6f}")

    def _get_mtime(self, path: Path) -> float:
        """Return the modification time of *path*, or 0.0 if it cannot be read."""
        try:
            return os.path.getmtime(path)
        except OSError:
            return 0.0

    def has_files_changed(self) -> bool:
        """
        Tell whether the model file or the scale file changed since loading.

        Only ``ae_model.pth`` and ``global_scale.npy`` are compared; threshold
        files are not tracked.
        """
        new_model_mtime = self._get_mtime(self.model_path)
        new_scale_mtime = self._get_mtime(self.scale_path)
        return (new_model_mtime != self._model_mtime
                or new_scale_mtime != self._scale_mtime)

    def _load_model(self) -> ConvAutoencoder:
        """
        Build a ``ConvAutoencoder``, load its weights and switch it to eval mode.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"模型不存在: {self.model_path}")

        model = ConvAutoencoder().to(self.torch_device)
        # NOTE(review): depending on the installed torch version, torch.load
        # may unpickle arbitrary objects - only load checkpoints that come
        # from a trusted location.
        state = torch.load(self.model_path, map_location=self.torch_device)
        model.load_state_dict(state)
        model.eval()
        return model

    def _load_scale(self) -> Tuple[float, float]:
        """
        Load the Min-Max scaling parameters ``[min, max]``.

        Returns:
            The first two entries of the stored array as ``(min, max)``.

        Raises:
            FileNotFoundError: If the scale file does not exist.
        """
        if not self.scale_path.exists():
            raise FileNotFoundError(f"标准化参数不存在: {self.scale_path}")

        scale = np.load(self.scale_path)
        return float(scale[0]), float(scale[1])

    def _load_threshold(self) -> float:
        """
        Load the threshold for this device.

        Looks for ``threshold_<device_code>.npy`` first, then for
        ``threshold_default.npy``; if neither exists, 0.01 is used.

        Returns:
            The overall threshold as a float.
        """
        # Device-specific file first
        threshold_file = self.threshold_dir / f"threshold_{self.device_code}.npy"

        if not threshold_file.exists():
            # Fall back to the shared default file
            threshold_file = self.threshold_dir / "threshold_default.npy"

        if threshold_file.exists():
            data = np.load(threshold_file)
            # Works for both scalar and array files: take the first element
            return float(data.flat[0])

        logger.warning(f"设备 {self.device_code} 无阈值文件,使用默认值 0.01")
        return 0.01


class MultiModelPredictor:
    """
    Multi-device, multi-model predictor.

    Features:
    - a separate model, scaling parameters and threshold per device
    - hot reload (files are re-read automatically after they change)
    - periodic retry for devices whose first load failed
    """

    # Minimum interval between two hot-reload checks (seconds)
    HOT_RELOAD_INTERVAL = 60

    # Minimum wait before a failed device is retried (seconds)
    FAILED_RETRY_INTERVAL = 300

    def __init__(self):
        # device_code -> model_subdir mapping
        self.device_model_map: Dict[str, str] = {}

        # device_code -> DevicePredictor instance (created lazily)
        self.predictors: Dict[str, DevicePredictor] = {}

        # Devices whose load failed: {device_code: time of the failure}
        self._failed_devices: Dict[str, float] = {}

        # Time of the last hot-reload check
        self._last_reload_check: float = 0.0

        logger.info("MultiModelPredictor 初始化完成")

    def register_device(self, device_code: str, model_subdir: str):
        """Map *device_code* to the model sub-directory *model_subdir*."""
        self.device_model_map[device_code] = model_subdir
        logger.debug(f"注册设备: {device_code} -> models/{model_subdir}/")

    def _resolve_subdir(self, device_code: str) -> str:
        """Return the registered sub-directory, or *device_code* if none/empty."""
        return self.device_model_map.get(device_code) or device_code

    def _check_hot_reload(self):
        """
        Reload devices whose files changed and retry devices that failed.

        The check runs at most once per ``HOT_RELOAD_INTERVAL`` seconds to
        avoid frequent stat calls. When a change is detected, a new
        ``DevicePredictor`` replaces the old one; if building it fails, the
        old predictor stays in use. Failed devices are retried once their
        last failure is older than ``FAILED_RETRY_INTERVAL`` seconds.
        """
        now = time.time()
        if now - self._last_reload_check < self.HOT_RELOAD_INTERVAL:
            return
        self._last_reload_check = now

        # 1. Reload already-loaded devices whose files were updated
        for device_code in list(self.predictors.keys()):
            predictor = self.predictors[device_code]
            if predictor.has_files_changed():
                logger.info(f"检测到模型文件更新: {device_code},执行热加载")
                model_subdir = self._resolve_subdir(device_code)
                try:
                    new_predictor = DevicePredictor(device_code, model_subdir)
                    self.predictors[device_code] = new_predictor
                    logger.info(f"热加载成功: {device_code}")
                except Exception as e:
                    logger.error(f"热加载失败: {device_code} | {e}")
                    # Keep serving with the old predictor

        # 2. Periodically retry devices that failed to load
        for device_code in list(self._failed_devices.keys()):
            fail_time = self._failed_devices[device_code]
            if now - fail_time > self.FAILED_RETRY_INTERVAL:
                model_subdir = self._resolve_subdir(device_code)
                try:
                    predictor = DevicePredictor(device_code, model_subdir)
                    self.predictors[device_code] = predictor
                    del self._failed_devices[device_code]
                    logger.info(f"失败设备重试成功: {device_code}")
                except Exception as e:
                    # Still failing: log it and restart the retry timer
                    logger.warning(f"失败设备重试仍失败: {device_code} | {e}")
                    self._failed_devices[device_code] = now

    def get_predictor(self, device_code: str) -> Optional[DevicePredictor]:
        """
        Return the predictor for *device_code*, loading it on first use.

        Returns:
            The ``DevicePredictor``, or ``None`` if loading failed (now or
            earlier and the device is still marked as failed).
        """
        # Run the hot-reload check first
        self._check_hot_reload()

        # Already loaded
        if device_code in self.predictors:
            return self.predictors[device_code]

        # Failed before and not successfully retried yet
        if device_code in self._failed_devices:
            return None

        # First use: lazy load
        model_subdir = self._resolve_subdir(device_code)

        try:
            predictor = DevicePredictor(device_code, model_subdir)
            self.predictors[device_code] = predictor
            return predictor
        except Exception as e:
            logger.error(f"加载设备 {device_code} 模型失败: {e}")
            self._failed_devices[device_code] = time.time()
            return None

    def get_threshold(self, device_code: str) -> Optional[float]:
        """Return the device's threshold, or ``None`` if it has no predictor."""
        predictor = self.get_predictor(device_code)
        if predictor:
            return predictor.threshold
        return None

    def get_scale(self, device_code: str) -> Tuple[Optional[float], Optional[float]]:
        """
        Return the device's Min-Max parameters.

        Returns:
            ``(min, max)``, or ``(None, None)`` if the device has no predictor.
        """
        predictor = self.get_predictor(device_code)
        if predictor:
            # DevicePredictor stores Min-Max values; it has no
            # global_mean / global_std attributes.
            return predictor.global_min, predictor.global_max
        return None, None

    def get_model(self, device_code: str) -> Optional[ConvAutoencoder]:
        """Return the device's model, or ``None`` if it has no predictor."""
        predictor = self.get_predictor(device_code)
        if predictor:
            return predictor.model
        return None

    def reload_device(self, device_code: str) -> bool:
        """
        Manually reload the model of one device.

        Intended for external callers (e.g. an API endpoint that updated the
        model files and then requests a reload). On failure the previously
        loaded predictor, if any, is left in place.

        Returns:
            bool: Whether the reload succeeded.
        """
        model_subdir = self._resolve_subdir(device_code)
        try:
            new_predictor = DevicePredictor(device_code, model_subdir)
            self.predictors[device_code] = new_predictor
            # Clear the failure mark
            self._failed_devices.pop(device_code, None)
            logger.info(f"手动重载成功: {device_code}")
            return True
        except Exception as e:
            logger.error(f"手动重载失败: {device_code} | {e}")
            return False

    @property
    def registered_devices(self) -> list:
        """Device codes that have been registered."""
        return list(self.device_model_map.keys())

    @property
    def loaded_devices(self) -> list:
        """Device codes whose predictor is currently loaded."""
        return list(self.predictors.keys())