#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
run_with_auto_training.py
-------------------------
Main entry point with automatic training.

Features:
1. RTSP pickup monitoring (through PickupMonitoringSystem)
2. Cold-start detection (when a model is missing, wait for data collection
   and then train automatically)
3. Daily incremental training (default 02:00, configurable)
4. Daily data cleanup (default 00:00, configurable)
5. After training, MultiModelPredictor is asked to reload immediately
   instead of waiting for its 60-second polling
"""

import sys
import signal
import logging
import threading
import time
from pathlib import Path
from datetime import datetime
from typing import Callable, Optional, Tuple

# Make sure the project root is on sys.path.
sys.path.insert(0, str(Path(__file__).parent))

try:
    from apscheduler.schedulers.background import BackgroundScheduler
    from apscheduler.triggers.cron import CronTrigger
except ImportError:
    print("错误:缺少依赖库")
    print("请运行:pip install apscheduler")
    sys.exit(1)

from config.config_manager import ConfigManager


def setup_logging():
    """Configure size-rotated logging to ``logs/system.log`` and stdout.

    The log file is shared with run_pickup_monitor.py. When the root logger
    already has handlers (a caller configured logging earlier) the existing
    configuration is reused and nothing is changed.
    """
    from logging.handlers import RotatingFileHandler

    root = logging.getLogger()
    if root.handlers:
        return

    log_dir = Path(__file__).parent / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / "system.log"

    formatter = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(name)-20s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # 10 MB per file, 2 backups kept (about 30 MB in total).
    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10 * 1024 * 1024,
        backupCount=2,
        encoding='utf-8'
    )
    file_handler.setFormatter(formatter)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)

    logging.basicConfig(
        level=logging.INFO,
        handlers=[file_handler, console_handler]
    )


logger = logging.getLogger('MainSystem')


def _parse_schedule_time(value, default: str) -> Tuple[int, int]:
    """Turn an ``HH:MM`` setting into an ``(hour, minute)`` pair.

    Args:
        value: Configured time. Normally an ``"HH:MM"`` string. An int is
            also accepted and read as minutes since midnight, because
            PyYAML (YAML 1.1) resolves an unquoted value such as ``12:30``
            to the base-60 integer ``750``.
        default: ``"HH:MM"`` string used when ``value`` is malformed or out
            of range. Must itself be valid.

    Returns:
        ``(hour, minute)`` with ``0 <= hour <= 23`` and ``0 <= minute <= 59``.
    """
    parsed = None
    if isinstance(value, int) and not isinstance(value, bool):
        parsed = divmod(value, 60)
    else:
        try:
            hour_text, minute_text = str(value).strip().split(':')
            parsed = (int(hour_text), int(minute_text))
        except ValueError:
            # Wrong number of fields or non-numeric parts.
            parsed = None

    if parsed is not None and 0 <= parsed[0] <= 23 and 0 <= parsed[1] <= 59:
        return parsed

    logger.warning(f"无效的时间配置 {value!r},使用默认值 {default}")
    fallback_hour, fallback_minute = map(int, default.split(':'))
    return fallback_hour, fallback_minute


class ColdStartManager:
    """
    Cold-start manager.

    Checks whether every registered device has a model file. Devices without
    a model are put into cold-start state: data is collected first and the
    initial training is triggered once enough of it is available.
    """

    def __init__(self, deploy_root: Path, config: dict):
        """
        Args:
            deploy_root: Deployment root; models are looked up under
                ``models/`` and audio under ``data/audio/`` below it.
            config: Configuration dict; only ``auto_training.cold_start``
                is read here, the whole dict is later handed to the trainer.
        """
        self.deploy_root = deploy_root
        self.config = config

        # Model root (one sub-directory per device).
        self.model_root = deploy_root / "models"
        # Audio data root.
        self.audio_root = deploy_root / "data" / "audio"

        # Cold-start settings. `or {}` tolerates sections that are present
        # but explicitly null in the configuration source.
        auto_cfg = config.get('auto_training') or {}
        cold_start_cfg = auto_cfg.get('cold_start') or {}
        self.enabled = cold_start_cfg.get('enabled', True)
        # Minimum collection time in hours, so training does not start on
        # too little data.
        self.wait_hours = cold_start_cfg.get('wait_hours', 2)
        # Minimum number of samples per device.
        self.min_samples = cold_start_cfg.get('min_samples', 100)

        # Runtime state.
        self.is_cold_start = False
        self.cold_start_time = None
        # Devices that have no model file.
        self.missing_devices = []

    def check_cold_start_needed(self, registered_devices: dict) -> bool:
        """
        Check whether any registered device lacks a model file.

        Args:
            registered_devices: Mapping ``{device_code: model_subdir}``.

        Returns:
            True when at least one device has no ``ae_model.pth`` and cold
            start is enabled; the affected codes are stored in
            ``self.missing_devices``.
        """
        if not self.enabled:
            return False

        self.missing_devices = []
        for device_code, model_subdir in registered_devices.items():
            model_path = self.model_root / model_subdir / "ae_model.pth"
            if not model_path.exists():
                self.missing_devices.append(device_code)
                logger.warning(f"模型不存在: {device_code} -> {model_path}")

        if self.missing_devices:
            logger.warning(f"以下设备缺少模型,进入冷启动模式: {self.missing_devices}")
            return True
        return False

    def start_cold_start_mode(self):
        """Enter cold-start mode and start the waiting-time clock."""
        self.is_cold_start = True
        self.cold_start_time = datetime.now()

        logger.info("=" * 70)
        logger.info("冷启动模式")
        logger.info("=" * 70)
        logger.info(f"缺少模型的设备: {self.missing_devices}")
        logger.info(f"等待时长: {self.wait_hours} 小时")
        logger.info(f"每设备最少样本: {self.min_samples}")
        logger.info("系统将只采集音频,不进行检测(缺模型的设备)")
        logger.info("=" * 70)

    @staticmethod
    def _count_wav_files(device_dir: Path) -> int:
        """Count ``*.wav`` files one and two levels below ``device_dir``.

        For every sub-directory both ``<sub>/normal/*.wav`` (new layout:
        ``{date}/normal/``) and ``<sub>/*.wav`` (old layout and the
        ``current`` directory) are counted.
        """
        total = 0
        for sub_dir in device_dir.iterdir():
            if not sub_dir.is_dir():
                continue
            normal_dir = sub_dir / "normal"
            if normal_dir.exists():
                total += sum(1 for _ in normal_dir.glob("*.wav"))
            total += sum(1 for _ in sub_dir.glob("*.wav"))
        return total

    def check_ready_for_training(self) -> bool:
        """
        Check whether enough data has been collected for the first training.

        Conditions:
            1. More than ``wait_hours`` hours have passed since cold start.
            2. Every device without a model has at least ``min_samples``
               wav files.

        Returns:
            True when both conditions hold for all missing devices.
        """
        if not self.is_cold_start:
            return False

        # Condition 1: waiting time.
        elapsed = (datetime.now() - self.cold_start_time).total_seconds() / 3600
        if elapsed < self.wait_hours:
            return False

        # Condition 2: per-device sample count.
        for device_code in self.missing_devices:
            device_dir = self.audio_root / device_code
            if not device_dir.exists():
                logger.info(f"数据目录不存在: {device_code}")
                return False

            total_samples = self._count_wav_files(device_dir)
            if total_samples < self.min_samples:
                logger.info(f"数据量不足: {device_code} -> {total_samples}/{self.min_samples}")
                return False

        logger.info("所有冷启动设备数据收集完成")
        return True

    def run_initial_training(
        self,
        on_device_trained: Optional[Callable[[str], None]] = None
    ) -> bool:
        """
        Run the first (cold-start) training.

        Uses IncrementalTrainer with ``cold_start_mode`` enabled so that all
        available data is used for a full training run.

        Args:
            on_device_trained: Optional callback ``fn(device_code)``; it is
                forwarded to the trainer's ``run_daily_training``.

        Returns:
            The trainer's result; False when the training raised.
        """
        logger.info("=" * 70)
        logger.info("开始首次训练(冷启动)")
        logger.info("=" * 70)

        try:
            from auto_training.incremental_trainer import IncrementalTrainer

            # Build the trainer from the in-memory configuration dict.
            trainer = IncrementalTrainer(config=self.config)

            # Cold-start mode: gather data from all directories, train fully.
            trainer.cold_start_mode = True
            trainer.use_days_ago = 0
            trainer.sample_hours = 0

            # First training: more epochs and a slightly higher learning rate.
            trainer.epochs = 90
            trainer.learning_rate = 0.001

            success = trainer.run_daily_training(
                on_device_trained=on_device_trained
            )

            if success:
                logger.info("=" * 70)
                logger.info("首次训练完成,切换到正常检测模式")
                logger.info("=" * 70)
                self.is_cold_start = False

            return success

        except Exception as e:
            # Top-level boundary of the training job: log with traceback and
            # report failure so the caller can retry.
            logger.exception(f"首次训练失败: {e}")
            return False


class IntegratedSystem:
    """
    Integrated system: cold start + pickup monitoring + automatic
    incremental training + data cleanup.

    Wraps run_pickup_monitor.py and adds:
    1. Automatic cold-start training
    2. APScheduler jobs for incremental training and data cleanup
    3. Immediate model hot-reload after training
    """

    def __init__(self):
        """Load the ``auto_training`` configuration.

        ``config/rtsp_config.yaml`` is used when it exists; otherwise the
        configuration comes from ConfigManager.
        """
        self.deploy_root = Path(__file__).parent

        yaml_path = self.deploy_root / "config" / "rtsp_config.yaml"
        if yaml_path.exists():
            import yaml
            with open(yaml_path, 'r', encoding='utf-8') as f:
                full_config = yaml.safe_load(f)
            # safe_load returns None for an empty document; anything that is
            # not a mapping cannot be queried with .get() below.
            if not isinstance(full_config, dict):
                if full_config is not None:
                    logger.warning(f"YAML 顶层不是映射,按空配置处理: {yaml_path.name}")
                full_config = {}
            self.full_yaml_config = full_config
            self.auto_config = {
                'auto_training': full_config.get('auto_training') or {}
            }
            logger.info(f"已从 YAML ({yaml_path.name}) 加载配置")
        else:
            self.full_yaml_config = None
            mgr = ConfigManager()
            try:
                auto_training_cfg = mgr.get_system_config('auto_training') or {}
            finally:
                # Release the manager even when the lookup fails.
                mgr.close()
            self.auto_config = {'auto_training': auto_training_cfg}
            logger.info(f"已从数据库加载 auto_training 配置 ({len(self.auto_config.get('auto_training', {}))} 项)")

        # Runtime objects.
        self.scheduler = None
        self.pickup_system = None
        self.cold_start_manager = None
        self.cold_start_thread = None

    def _check_and_handle_cold_start(self) -> bool:
        """
        Detect and handle cold start.

        Reads the registered devices from the pickup system's
        ``multi_predictor`` and starts the cold-start monitor thread when
        any of them lacks a model file.

        Returns:
            True when the system entered cold-start mode.
        """
        self.cold_start_manager = ColdStartManager(
            self.deploy_root,
            self.auto_config
        )

        registered = self.pickup_system.multi_predictor.device_model_map

        if not self.cold_start_manager.check_cold_start_needed(registered):
            logger.info("所有设备模型完整,进入正常检测模式")
            return False

        self.cold_start_manager.start_cold_start_mode()

        # Background thread that watches the cold-start state.
        self.cold_start_thread = threading.Thread(
            target=self._cold_start_monitor_loop,
            daemon=True,
            name="cold-start-monitor"
        )
        self.cold_start_thread.start()

        return True

    def _cold_start_monitor_loop(self):
        """Background loop: poll every minute, train once data is ready.

        After a failed training the loop sleeps 10 minutes and then resumes
        its regular 60-second polling.
        """
        manager = self.cold_start_manager
        while manager.is_cold_start:
            time.sleep(60)

            try:
                ready = manager.check_ready_for_training()
            except OSError as exc:
                # A filesystem error during the scan must not kill this
                # thread, otherwise cold start would never finish.
                logger.warning(f"冷启动数据检查失败,将在下一轮重试: {exc}")
                continue

            if not ready:
                continue

            if manager.run_initial_training(
                on_device_trained=self._reload_single_device
            ):
                break

            logger.warning("首次训练失败,将在10分钟后重试")
            time.sleep(600)

    def _reload_single_device(self, device_code: str):
        """Training callback: reload one device's model right after training."""
        if not self.pickup_system:
            return
        mp = self.pickup_system.multi_predictor
        try:
            if mp.reload_device(device_code):
                logger.info(f"模型即时重载成功: {device_code}")
            else:
                logger.warning(f"模型即时重载失败: {device_code}")
        except Exception as e:
            logger.error(f"模型即时重载异常: {device_code} | {e}")

    def _reload_models_after_training(self):
        """
        Reload the models of all registered devices after training.

        Calls ``reload_device()`` for every entry of
        ``multi_predictor.registered_devices`` and logs how many reloads
        succeeded and failed. A reload that raises is counted as failed.
        """
        if not self.pickup_system:
            return

        mp = self.pickup_system.multi_predictor
        success_count = 0
        fail_count = 0

        for device_code in mp.registered_devices:
            try:
                if mp.reload_device(device_code):
                    success_count += 1
                else:
                    fail_count += 1
            except Exception as e:
                logger.error(f"重载设备模型失败: {device_code} | {e}")
                fail_count += 1

        logger.info(f"模型重载完成: 成功={success_count}, 失败={fail_count}")

    def _setup_auto_training_tasks(self):
        """Create and start the APScheduler jobs.

        Nothing is scheduled when ``auto_training.enabled`` is false. The
        incremental-training job is added only when
        ``auto_training.incremental.enabled`` is true; the cleanup job is
        always added. Invalid times fall back to the defaults.
        """
        auto_cfg = self.auto_config.get('auto_training') or {}
        if not auto_cfg.get('enabled', False):
            logger.info("自动训练已禁用(auto_training.enabled=false),跳过定时任务配置")
            return

        logger.info("=" * 70)
        logger.info("配置定时任务")
        logger.info("=" * 70)

        self.scheduler = BackgroundScheduler()

        # Scheduled incremental training.
        incremental_cfg = auto_cfg.get('incremental') or {}
        if incremental_cfg.get('enabled', False):
            hour, minute = _parse_schedule_time(
                incremental_cfg.get('schedule_time', '02:00'), '02:00'
            )
            schedule_time = f"{hour:02d}:{minute:02d}"

            self.scheduler.add_job(
                self._run_incremental_training,
                trigger=CronTrigger(hour=hour, minute=minute),
                id='incremental_training',
                name='每日增量训练',
                misfire_grace_time=3600  # still run when missed by up to 1 hour
            )
            logger.info(f"每日增量训练: 每天 {schedule_time}")

        # Scheduled data cleanup.
        data_cfg = auto_cfg.get('data') or {}
        hour, minute = _parse_schedule_time(
            data_cfg.get('cleanup_time', '00:00'), '00:00'
        )
        cleanup_time = f"{hour:02d}:{minute:02d}"

        self.scheduler.add_job(
            self._run_data_cleanup,
            trigger=CronTrigger(hour=hour, minute=minute),
            id='data_cleanup',
            name='每日数据清理',
            misfire_grace_time=3600
        )
        logger.info(f"每日数据清理: 每天 {cleanup_time}")

        self.scheduler.start()
        logger.info("定时任务调度器已启动")
        logger.info("=" * 70)

    def _run_incremental_training(self):
        """Scheduler callback: run incremental training.

        ``_reload_single_device`` is passed as the per-device callback.
        Exceptions are logged and not propagated to the scheduler.
        """
        try:
            logger.info("定时任务触发:增量训练开始")

            from auto_training.incremental_trainer import IncrementalTrainer

            # Pass the config dict rather than a YAML path.
            trainer = IncrementalTrainer(config=self.auto_config)

            success = trainer.run_daily_training(
                on_device_trained=self._reload_single_device
            )

            if not success:
                logger.warning("增量训练返回失败,保持当前推理模型不变")

        except Exception as e:
            logger.exception(f"增量训练异常: {e}")

    def _run_data_cleanup(self):
        """Scheduler callback: run the daily data cleanup.

        Exceptions are logged and not propagated to the scheduler.
        """
        try:
            logger.info("定时任务触发:数据清理开始")

            from auto_training.data_cleanup import DataCleaner

            # Pass the config dict rather than a YAML path.
            cleaner = DataCleaner(config=self.auto_config)
            cleaner.run_cleanup()

        except Exception as e:
            logger.exception(f"数据清理异常: {e}")

    def start(self):
        """Main start-up sequence; blocks inside the pickup monitor."""
        logger.info("=" * 70)
        logger.info("拾音器异响检测系统(带自动训练)")
        logger.info("=" * 70)

        # 1. Create PickupMonitoringSystem (initialises multi_predictor and
        #    registers the devices).
        logger.info("初始化监控系统...")
        from run_pickup_monitor import PickupMonitoringSystem
        self.pickup_system = PickupMonitoringSystem(yaml_config=self.full_yaml_config)

        # 2. Cold-start check; needs the device registration from step 1.
        self._check_and_handle_cold_start()

        # 3. Scheduled jobs.
        self._setup_auto_training_tasks()

        # 4. Override the signal handlers so the scheduler is shut down too.
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        # 5. Start pickup monitoring (blocking call containing the main loop).
        logger.info("启动拾音器监控...")
        self.pickup_system.start()

    def stop(self):
        """Shut down the scheduler and the pickup system."""
        logger.info("停止系统...")

        # Scheduler first, so no training job fires during shutdown.
        if self.scheduler and self.scheduler.running:
            self.scheduler.shutdown(wait=False)
            logger.info("定时任务调度器已停止")

        if self.pickup_system:
            self.pickup_system.stop()

        logger.info("系统已停止")

    def _signal_handler(self, signum, frame):
        """Handle SIGINT/SIGTERM: stop all components and exit with code 0."""
        logger.info(f"收到信号 {signum}")
        self.stop()
        sys.exit(0)


def main():
    """Configure logging, build the integrated system and start it."""
    setup_logging()
    system = IntegratedSystem()
    system.start()


if __name__ == "__main__":
    main()