| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
# -*- coding: utf-8 -*-
"""
multi_model_predictor.py - Multi-device, multi-model predictor
==============================================================
Each device (device_code) loads its own independent model directory.
Supports model hot reloading (automatic reload when files change).
Usage example:
    predictor = MultiModelPredictor()
    predictor.register_device("LT-2", "LT-2")
    predictor.register_device("LT-5", "LT-5")
"""
- import os
- import time
- import logging
- from pathlib import Path
- from typing import Dict, Optional, Tuple
- import numpy as np
- import torch
- from .config import CFG, DeployConfig
- from .model_def import ConvAutoencoder
- from .utils import get_device
- logger = logging.getLogger('MultiModelPredictor')
class DevicePredictor:
    """Per-device predictor.

    Bundles a single device's autoencoder model, its Min-Max scaling
    parameters and its anomaly threshold, and snapshots the source files'
    mtimes so callers can detect when a hot reload is required.
    """

    def __init__(self, device_code: str, model_subdir: str):
        self.device_code = device_code
        self.model_subdir = model_subdir

        # Resolve every artifact path under the configured model root.
        self.model_dir = CFG.MODEL_ROOT / model_subdir
        self.model_path = self.model_dir / "ae_model.pth"
        self.scale_path = self.model_dir / "global_scale.npy"
        self.threshold_dir = self.model_dir / "thresholds"

        # Eagerly load all resources; a missing model/scale file raises here.
        self.torch_device = get_device()
        self.model = self._load_model()
        # Min-Max pair (min, max)
        self.global_min, self.global_max = self._load_scale()
        # Scalar anomaly threshold
        self.threshold = self._load_threshold()

        # Snapshot file mtimes for later hot-reload detection.
        self._model_mtime = self._get_mtime(self.model_path)
        self._scale_mtime = self._get_mtime(self.scale_path)

        logger.info(f"设备 {device_code} 模型加载完成 | 目录: {model_subdir} | "
                    f"阈值: {self.threshold:.6f}")

    def _get_mtime(self, path: Path) -> float:
        """Return *path*'s modification time, or 0.0 when it is unreadable."""
        try:
            return path.stat().st_mtime
        except OSError:
            return 0.0

    def has_files_changed(self) -> bool:
        """True when the model or scale file's mtime differs from the snapshot."""
        snapshot = (self._model_mtime, self._scale_mtime)
        current = (self._get_mtime(self.model_path),
                   self._get_mtime(self.scale_path))
        return current != snapshot

    def _load_model(self) -> ConvAutoencoder:
        """Build the autoencoder, load trained weights, switch to eval mode."""
        if not self.model_path.exists():
            raise FileNotFoundError(f"模型不存在: {self.model_path}")
        net = ConvAutoencoder().to(self.torch_device)
        weights = torch.load(self.model_path, map_location=self.torch_device)
        net.load_state_dict(weights)
        net.eval()
        return net

    def _load_scale(self) -> Tuple[float, float]:
        """Load the Min-Max scaling parameters stored as ``[min, max]``."""
        if not self.scale_path.exists():
            raise FileNotFoundError(f"标准化参数不存在: {self.scale_path}")
        params = np.load(self.scale_path)
        return float(params[0]), float(params[1])

    def _load_threshold(self) -> float:
        """Load the anomaly threshold.

        Looks for a device-specific file first, then a default file; if
        neither exists, falls back to 0.01 with a warning.

        Returns:
            overall_threshold
        """
        candidates = (
            self.threshold_dir / f"threshold_{self.device_code}.npy",
            self.threshold_dir / "threshold_default.npy",
        )
        for candidate in candidates:
            if candidate.exists():
                # Tolerate both scalar and array file formats.
                return float(np.load(candidate).flat[0])
        logger.warning(f"设备 {self.device_code} 无阈值文件,使用默认值 0.01")
        return 0.01
class MultiModelPredictor:
    """Multi-device, multi-model predictor.

    Supports:
      - an independent model, scaling parameters and threshold per device
      - hot reloading (models are reloaded when their files change)
      - periodic retry for devices that failed to load (cold start)
    """

    # Seconds between hot-reload file-change scans.
    HOT_RELOAD_INTERVAL = 60
    # Seconds before retrying a device whose load previously failed.
    FAILED_RETRY_INTERVAL = 300

    def __init__(self):
        # device_code -> model_subdir mapping
        self.device_model_map: Dict[str, str] = {}
        # device_code -> DevicePredictor instance (lazily created)
        self.predictors: Dict[str, DevicePredictor] = {}
        # devices whose load failed: {device_code: failure timestamp}
        self._failed_devices: Dict[str, float] = {}
        # timestamp of the last hot-reload scan
        self._last_reload_check: float = 0.0
        logger.info("MultiModelPredictor 初始化完成")

    def register_device(self, device_code: str, model_subdir: str):
        """Register the model subdirectory used for *device_code*."""
        self.device_model_map[device_code] = model_subdir
        logger.debug(f"注册设备: {device_code} -> models/{model_subdir}/")

    def _subdir_for(self, device_code: str) -> str:
        """Resolve a device's model subdirectory, defaulting to the code itself."""
        return self.device_model_map.get(device_code) or device_code

    def _check_hot_reload(self):
        """Scan loaded devices for updated model files and retry failed loads.

        Scan frequency is throttled by HOT_RELOAD_INTERVAL (default 60 s) to
        avoid frequent stat calls. When a change is detected, a fresh
        DevicePredictor replaces the old one; on failure the old predictor is
        kept (its stale mtimes mean the reload is retried on the next scan).
        Failed devices are retried every FAILED_RETRY_INTERVAL seconds.
        """
        now = time.time()
        if now - self._last_reload_check < self.HOT_RELOAD_INTERVAL:
            return
        self._last_reload_check = now
        # 1. Reload any loaded device whose files changed on disk.
        for device_code in list(self.predictors.keys()):
            predictor = self.predictors[device_code]
            if predictor.has_files_changed():
                logger.info(f"检测到模型文件更新: {device_code},执行热加载")
                try:
                    self.predictors[device_code] = DevicePredictor(
                        device_code, self._subdir_for(device_code))
                    logger.info(f"热加载成功: {device_code}")
                except Exception as e:
                    logger.error(f"热加载失败: {device_code} | {e}")
                    # Keep the old predictor in service.
        # 2. Periodically retry devices that failed to load.
        for device_code in list(self._failed_devices.keys()):
            fail_time = self._failed_devices[device_code]
            if now - fail_time > self.FAILED_RETRY_INTERVAL:
                try:
                    predictor = DevicePredictor(
                        device_code, self._subdir_for(device_code))
                    self.predictors[device_code] = predictor
                    del self._failed_devices[device_code]
                    logger.info(f"失败设备重试成功: {device_code}")
                except Exception:
                    # Still failing: push the retry window forward.
                    self._failed_devices[device_code] = now

    def get_predictor(self, device_code: str) -> Optional[DevicePredictor]:
        """Return the device's predictor, lazily loading it on first use.

        Returns None when the device failed to load and is waiting for retry.
        """
        # Run the throttled hot-reload scan first.
        self._check_hot_reload()
        # Already loaded.
        if device_code in self.predictors:
            return self.predictors[device_code]
        # Failed earlier and not yet due for retry.
        if device_code in self._failed_devices:
            return None
        # First-time lazy load.
        try:
            predictor = DevicePredictor(device_code, self._subdir_for(device_code))
            self.predictors[device_code] = predictor
            return predictor
        except Exception as e:
            logger.error(f"加载设备 {device_code} 模型失败: {e}")
            self._failed_devices[device_code] = time.time()
            return None

    def get_threshold(self, device_code: str) -> Optional[float]:
        """Return the device's anomaly threshold, or None if unavailable."""
        predictor = self.get_predictor(device_code)
        if predictor:
            return predictor.threshold
        return None

    def get_scale(self, device_code: str) -> Tuple[Optional[float], Optional[float]]:
        """Return the device's Min-Max scaling parameters ``(min, max)``.

        BUGFIX: previously referenced nonexistent ``global_mean``/``global_std``
        (DevicePredictor exposes ``global_min``/``global_max``), which raised
        AttributeError on every call.
        """
        predictor = self.get_predictor(device_code)
        if predictor:
            return predictor.global_min, predictor.global_max
        return None, None

    def get_model(self, device_code: str) -> Optional[ConvAutoencoder]:
        """Return the device's loaded model, or None if unavailable."""
        predictor = self.get_predictor(device_code)
        if predictor:
            return predictor.model
        return None

    def reload_device(self, device_code: str) -> bool:
        """Manually force a reload of one device's model.

        Intended for external callers (e.g. an API endpoint that pushes a new
        model and then triggers a reload).

        Returns:
            bool: True when the reload succeeded.
        """
        try:
            self.predictors[device_code] = DevicePredictor(
                device_code, self._subdir_for(device_code))
            # Clear any failure marker so lazy loading resumes normally.
            self._failed_devices.pop(device_code, None)
            logger.info(f"手动重载成功: {device_code}")
            return True
        except Exception as e:
            logger.error(f"手动重载失败: {device_code} | {e}")
            return False

    @property
    def registered_devices(self) -> list:
        """Device codes registered via register_device()."""
        return list(self.device_model_map.keys())

    @property
    def loaded_devices(self) -> list:
        """Device codes with a currently loaded predictor."""
        return list(self.predictors.keys())
|