| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- run_with_auto_training.py
- -------------------------
- 带自动训练的主程序
- 功能:
- 1. RTSP拾音器监控(通过 PickupMonitoringSystem)
- 2. 冷启动检测(无模型时等待数据收集后自动训练)
- 3. 每日增量训练(默认02:00,可配置)
- 4. 每日数据清理(默认00:00,可配置)
- 5. 训练完成后主动通知 MultiModelPredictor 重载,无需等待60秒轮询
- """
- import sys
- import signal
- import logging
- import threading
- import time
- from pathlib import Path
- from datetime import datetime
- # 确保项目根目录在 sys.path 中
- sys.path.insert(0, str(Path(__file__).parent))
- try:
- from apscheduler.schedulers.background import BackgroundScheduler
- from apscheduler.triggers.cron import CronTrigger
- except ImportError:
- print("错误:缺少依赖库")
- print("请运行:pip install apscheduler")
- sys.exit(1)
- from config.config_manager import ConfigManager
def setup_logging():
    """
    Configure root logging for the process.

    Installs two handlers on the root logger at INFO level, both using the
    same pipe-separated format:
      * a size-rotated file handler writing ``logs/system.log`` next to this
        script (10 MB per file, 2 backups, ~30 MB total) -- the same file
        run_pickup_monitor.py logs to;
      * a stream handler writing to stdout.

    If the root logger already has handlers (i.e. a caller configured logging
    first), the existing configuration is reused and nothing is changed.
    """
    from logging.handlers import RotatingFileHandler

    # Respect a configuration installed by an upstream caller
    if logging.getLogger().handlers:
        return

    logs_path = Path(__file__).parent / "logs"
    logs_path.mkdir(parents=True, exist_ok=True)

    fmt = logging.Formatter(
        fmt='%(asctime)s | %(levelname)-8s | %(name)-20s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    handlers = [
        # 10 MB cap per file, keep 2 rotated backups
        RotatingFileHandler(
            logs_path / "system.log",
            maxBytes=10 * 1024 * 1024,
            backupCount=2,
            encoding='utf-8',
        ),
        logging.StreamHandler(sys.stdout),
    ]
    for handler in handlers:
        handler.setFormatter(fmt)

    logging.basicConfig(level=logging.INFO, handlers=handlers)
# Shared module-level logger used by the classes in this file.
logger = logging.getLogger(name="MainSystem")
class ColdStartManager:
    """
    Cold-start manager.

    Checks whether every registered device has a model file. Devices without
    one are put into cold-start state; once enough audio has been collected
    for them, an initial full training run is triggered.
    """

    def __init__(self, deploy_root: Path, config: dict):
        """
        Args:
            deploy_root: deployment root directory. Models are looked up under
                ``<deploy_root>/models/<model_subdir>/`` and audio under
                ``<deploy_root>/data/audio/<device_code>/``.
            config: configuration dict; ``auto_training.cold_start`` is read
                here and the whole dict is handed to IncrementalTrainer.
        """
        self.deploy_root = deploy_root
        self.config = config
        # Model root (one sub-directory per device)
        self.model_root = deploy_root / "models"
        # Audio data root (one sub-directory per device code)
        self.audio_root = deploy_root / "data" / "audio"

        # Cold-start settings. `or {}` also covers sections that exist but
        # are empty/None (e.g. a bare `auto_training:` key in YAML loads as
        # None), which a plain `.get(key, {})` would pass through and crash on.
        auto_cfg = config.get('auto_training') or {}
        cold_start_cfg = auto_cfg.get('cold_start') or {}
        self.enabled = cold_start_cfg.get('enabled', True)
        # Minimum collection time (hours) before training, so training does
        # not start on too little data
        self.wait_hours = cold_start_cfg.get('wait_hours', 2)
        # Minimum number of audio samples required per device
        self.min_samples = cold_start_cfg.get('min_samples', 100)

        # Runtime state
        self.is_cold_start = False
        self.cold_start_time = None
        # Device codes that have no model file
        self.missing_devices = []

    def check_cold_start_needed(self, registered_devices: dict) -> bool:
        """
        Check whether any registered device is missing its model file.

        Args:
            registered_devices: {device_code: model_subdir} mapping of the
                registered devices.

        Returns:
            True if at least one device lacks ``ae_model.pth`` (cold start is
            needed). Always False when cold start is disabled.
        """
        if not self.enabled:
            return False
        self.missing_devices = []
        for device_code, model_subdir in registered_devices.items():
            model_path = self.model_root / model_subdir / "ae_model.pth"
            if not model_path.exists():
                self.missing_devices.append(device_code)
                logger.warning(f"模型不存在: {device_code} -> {model_path}")
        if self.missing_devices:
            logger.warning(f"以下设备缺少模型,进入冷启动模式: {self.missing_devices}")
            return True
        return False

    def start_cold_start_mode(self):
        """Enter cold-start mode and start the wait timer."""
        self.is_cold_start = True
        self.cold_start_time = datetime.now()
        logger.info("=" * 70)
        logger.info("冷启动模式")
        logger.info("=" * 70)
        logger.info(f"缺少模型的设备: {self.missing_devices}")
        logger.info(f"等待时长: {self.wait_hours} 小时")
        logger.info(f"每设备最少样本: {self.min_samples}")
        logger.info("系统将只采集音频,不进行检测(缺模型的设备)")
        logger.info("=" * 70)

    def check_ready_for_training(self) -> bool:
        """
        Check whether enough data has been collected for the first training.

        Conditions:
            1. at least ``wait_hours`` hours have passed since cold start began;
            2. every device missing a model has >= ``min_samples`` wav files.

        Returns:
            True when the data is ready; False otherwise (including when not
            in cold-start mode).
        """
        if not self.is_cold_start:
            return False
        # Condition 1: wait time
        elapsed_hours = (datetime.now() - self.cold_start_time).total_seconds() / 3600
        if elapsed_hours < self.wait_hours:
            return False
        # Condition 2: per-device sample count
        for device_code in self.missing_devices:
            device_dir = self.audio_root / device_code
            # is_dir() rather than exists(): iterdir() on a regular file would
            # raise NotADirectoryError and kill the monitoring thread
            if not device_dir.is_dir():
                logger.info(f"数据目录不存在: {device_code}")
                return False
            total_samples = self._count_samples(device_dir)
            if total_samples < self.min_samples:
                logger.info(f"数据量不足: {device_code} -> {total_samples}/{self.min_samples}")
                return False
        logger.info("所有冷启动设备数据收集完成")
        return True

    @staticmethod
    def _count_samples(device_dir: Path) -> int:
        """
        Count the .wav files collected for one device.

        For every sub-directory of ``device_dir`` this counts both
        ``<sub_dir>/normal/*.wav`` (dated layout) and ``<sub_dir>/*.wav``
        (legacy layout and the ``current`` directory). Files placed directly
        in ``device_dir`` are not counted.
        """
        total = 0
        for sub_dir in device_dir.iterdir():
            if not sub_dir.is_dir():
                continue
            normal_dir = sub_dir / "normal"
            if normal_dir.exists():
                total += sum(1 for _ in normal_dir.glob("*.wav"))
            total += sum(1 for _ in sub_dir.glob("*.wav"))
        return total

    def run_initial_training(self, on_device_trained=None) -> bool:
        """
        Run the first (cold-start) training.

        Uses IncrementalTrainer in cold_start_mode so that all available data
        is collected for a full training run.

        Args:
            on_device_trained: optional callback fn(device_code), passed to
                the trainer to be called after each device finishes.

        Returns:
            True if training succeeded; False on failure or on any exception
            (which is logged). On success cold-start mode is left.
        """
        logger.info("=" * 70)
        logger.info("开始首次训练(冷启动)")
        logger.info("=" * 70)
        try:
            from auto_training.incremental_trainer import IncrementalTrainer
            # Build the trainer from the in-memory config dict
            trainer = IncrementalTrainer(config=self.config)
            # Cold-start mode: gather data from all directories, full training
            trainer.cold_start_mode = True
            trainer.use_days_ago = 0
            trainer.sample_hours = 0
            # First training uses more epochs and a slightly higher learning rate
            trainer.epochs = 90
            trainer.learning_rate = 0.001
            success = trainer.run_daily_training(
                on_device_trained=on_device_trained
            )
            if success:
                logger.info("=" * 70)
                logger.info("首次训练完成,切换到正常检测模式")
                logger.info("=" * 70)
                self.is_cold_start = False
            return success
        except Exception as e:
            logger.error(f"首次训练失败: {e}", exc_info=True)
            return False
class IntegratedSystem:
    """
    Integrated system: cold start + pickup monitoring + scheduled incremental
    training + data cleanup.

    Wraps PickupMonitoringSystem from run_pickup_monitor.py and adds:
    1. automatic first training for devices without a model (cold start)
    2. APScheduler jobs for daily incremental training and data cleanup
    3. active model hot-reload once training finishes
    """

    def __init__(self):
        self.deploy_root = Path(__file__).parent
        # =====================================================================
        # Config source is auto-detected:
        #   - YAML when config/rtsp_config.yaml exists
        #   - otherwise the SQLite database via ConfigManager
        # =====================================================================
        yaml_path = self.deploy_root / "config" / "rtsp_config.yaml"
        if yaml_path.exists():
            import yaml
            with open(yaml_path, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document
                full_config = yaml.safe_load(f) or {}
            self.full_yaml_config = full_config
            # `or {}`: a bare `auto_training:` key loads as None
            self.auto_config = {'auto_training': full_config.get('auto_training') or {}}
            logger.info(f"已从 YAML ({yaml_path.name}) 加载配置")
        else:
            self.full_yaml_config = None
            mgr = ConfigManager()
            try:
                auto_training = mgr.get_system_config('auto_training')
            finally:
                # Release the DB handle even if the query raises
                mgr.close()
            self.auto_config = {'auto_training': auto_training or {}}
            logger.info(f"已从数据库加载 auto_training 配置 ({len(self.auto_config.get('auto_training', {}))} 项)")

        # Runtime objects
        self.scheduler = None
        self.pickup_system = None
        self.cold_start_manager = None
        self.cold_start_thread = None
        # Serializes training runs: the cold-start thread and the APScheduler
        # incremental job both train models and must not overlap.
        self._training_lock = threading.Lock()
        # Set by stop(); lets the cold-start monitor thread exit promptly.
        self._stop_event = threading.Event()

    @staticmethod
    def _parse_hhmm(value) -> tuple:
        """
        Parse a schedule time into (hour, minute).

        Accepts "HH:MM" strings. Also accepts ints because YAML 1.1 loaders
        (PyYAML) read an unquoted value like ``12:30`` or ``2:00`` as a
        base-60 integer (12 * 60 + 30 = 750).

        Raises:
            ValueError: malformed value, or hour/minute out of range.
        """
        if isinstance(value, int) and not isinstance(value, bool):
            hour, minute = divmod(value, 60)
        else:
            parts = str(value).strip().split(':')
            if len(parts) != 2:
                raise ValueError(f"无效的时间格式(应为 HH:MM): {value!r}")
            hour, minute = (int(p) for p in parts)
        if not (0 <= hour <= 23 and 0 <= minute <= 59):
            raise ValueError(f"时间超出范围(应为 00:00-23:59): {value!r}")
        return hour, minute

    def _check_and_handle_cold_start(self) -> bool:
        """
        Check for and handle cold start.

        Uses the devices registered in PickupMonitoringSystem's multi_predictor
        to find devices without a model file; if any are missing, cold-start
        mode is entered and a background monitor thread is started.

        Returns:
            True if the system is in cold-start mode.
        """
        self.cold_start_manager = ColdStartManager(
            self.deploy_root,
            self.auto_config
        )
        # Registered devices come from pickup_system's multi_predictor
        registered = self.pickup_system.multi_predictor.device_model_map
        if not self.cold_start_manager.check_cold_start_needed(registered):
            logger.info("所有设备模型完整,进入正常检测模式")
            return False
        # Enter cold-start mode
        self.cold_start_manager.start_cold_start_mode()
        # Background thread watching cold-start readiness
        self.cold_start_thread = threading.Thread(
            target=self._cold_start_monitor_loop,
            daemon=True,
            name="cold-start-monitor"
        )
        self.cold_start_thread.start()
        return True

    def _cold_start_monitor_loop(self):
        """
        Background thread: check once a minute whether cold-start data is
        ready and run the first training when it is. Retries 10 minutes after
        a failed training. Exits on success or as soon as stop() is called.
        """
        manager = self.cold_start_manager
        while manager.is_cold_start and not self._stop_event.is_set():
            # Event.wait is an interruptible sleep: True once stop() is called
            if self._stop_event.wait(60):
                break
            if not manager.check_ready_for_training():
                continue
            with self._training_lock:
                # Shutdown may have been requested while waiting for the lock
                if self._stop_event.is_set():
                    break
                success = manager.run_initial_training(
                    on_device_trained=self._reload_single_device
                )
            if success:
                break
            # Training failed: retry after 10 minutes
            logger.warning("首次训练失败,将在10分钟后重试")
            if self._stop_event.wait(600):
                break

    def _reload_single_device(self, device_code: str):
        """Training callback: reload one device's model right after it is trained."""
        if not self.pickup_system:
            return
        mp = self.pickup_system.multi_predictor
        try:
            if mp.reload_device(device_code):
                logger.info(f"模型即时重载成功: {device_code}")
            else:
                logger.warning(f"模型即时重载失败: {device_code}")
        except Exception as e:
            logger.error(f"模型即时重载异常: {device_code} | {e}")

    def _reload_models_after_training(self):
        """
        Actively ask MultiModelPredictor to reload every registered device's
        model after training.

        Compared with waiting for the 60-second mtime poll, calling
        reload_device() directly:
        1. takes effect immediately
        2. gives explicit success/failure feedback (counted and logged)
        """
        if not self.pickup_system:
            return
        mp = self.pickup_system.multi_predictor
        success_count = 0
        fail_count = 0
        for device_code in mp.registered_devices:
            try:
                if mp.reload_device(device_code):
                    success_count += 1
                else:
                    fail_count += 1
            except Exception as e:
                logger.error(f"重载设备模型失败: {device_code} | {e}")
                fail_count += 1
        logger.info(f"模型重载完成: 成功={success_count}, 失败={fail_count}")

    def _setup_auto_training_tasks(self):
        """
        Configure APScheduler jobs (incremental training + data cleanup) and
        start the scheduler. Does nothing when auto_training.enabled is false.

        Raises:
            ValueError: a configured schedule time is malformed.
        """
        auto_cfg = self.auto_config.get('auto_training') or {}
        if not auto_cfg.get('enabled', False):
            logger.info("自动训练已禁用(auto_training.enabled=false),跳过定时任务配置")
            return
        logger.info("=" * 70)
        logger.info("配置定时任务")
        logger.info("=" * 70)
        self.scheduler = BackgroundScheduler()

        # Daily incremental training
        incremental_cfg = auto_cfg.get('incremental') or {}
        if incremental_cfg.get('enabled', False):
            hour, minute = self._parse_hhmm(incremental_cfg.get('schedule_time', '02:00'))
            self.scheduler.add_job(
                self._run_incremental_training,
                trigger=CronTrigger(hour=hour, minute=minute),
                id='incremental_training',
                name='每日增量训练',
                misfire_grace_time=3600  # still run if missed by < 1 hour
            )
            logger.info(f"每日增量训练: 每天 {hour:02d}:{minute:02d}")

        # Daily data cleanup
        data_cfg = auto_cfg.get('data') or {}
        hour, minute = self._parse_hhmm(data_cfg.get('cleanup_time', '00:00'))
        self.scheduler.add_job(
            self._run_data_cleanup,
            trigger=CronTrigger(hour=hour, minute=minute),
            id='data_cleanup',
            name='每日数据清理',
            misfire_grace_time=3600
        )
        logger.info(f"每日数据清理: 每天 {hour:02d}:{minute:02d}")

        self.scheduler.start()
        logger.info("定时任务调度器已启动")
        logger.info("=" * 70)

    def _run_incremental_training(self):
        """
        Scheduled job: incremental training (devices trained serially, each
        reloaded as soon as it finishes). Skipped when another training run
        (e.g. the cold-start first training) currently holds the training lock.
        """
        if not self._training_lock.acquire(blocking=False):
            logger.warning("已有训练任务在运行,跳过本次增量训练")
            return
        try:
            logger.info("定时任务触发:增量训练开始")
            from auto_training.incremental_trainer import IncrementalTrainer
            # Pass the in-memory config dict rather than a YAML path
            trainer = IncrementalTrainer(config=self.auto_config)
            success = trainer.run_daily_training(
                on_device_trained=self._reload_single_device
            )
            if not success:
                logger.warning("增量训练返回失败,保持当前推理模型不变")
        except Exception as e:
            logger.error(f"增量训练异常: {e}", exc_info=True)
        finally:
            self._training_lock.release()

    def _run_data_cleanup(self):
        """Scheduled job: daily data cleanup; exceptions are logged, not raised."""
        try:
            logger.info("定时任务触发:数据清理开始")
            from auto_training.data_cleanup import DataCleaner
            # Pass the in-memory config dict rather than a YAML path
            cleaner = DataCleaner(config=self.auto_config)
            cleaner.run_cleanup()
        except Exception as e:
            logger.error(f"数据清理异常: {e}", exc_info=True)

    def start(self):
        """
        Main start-up sequence. The final step, PickupMonitoringSystem.start(),
        is a blocking call that contains the monitoring main loop.
        """
        logger.info("=" * 70)
        logger.info("拾音器异响检测系统(带自动训练)")
        logger.info("=" * 70)
        # 1. Create PickupMonitoringSystem (initializes multi_predictor and
        #    registers devices)
        logger.info("初始化监控系统...")
        from run_pickup_monitor import PickupMonitoringSystem
        self.pickup_system = PickupMonitoringSystem(yaml_config=self.full_yaml_config)
        # 2. Cold-start check; must follow pickup_system creation because it
        #    needs the registered-device map
        self._check_and_handle_cold_start()
        # 3. Scheduled jobs
        self._setup_auto_training_tasks()
        # 4. Install our own signal handlers so the scheduler is shut down too
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)
        # 5. Start pickup monitoring (blocking)
        logger.info("启动拾音器监控...")
        self.pickup_system.start()

    def stop(self):
        """Shut down all components."""
        logger.info("停止系统...")
        # Wake the cold-start monitor so it does not start training now
        self._stop_event.set()
        # Stop the scheduler first so no training job fires during shutdown
        if self.scheduler and self.scheduler.running:
            self.scheduler.shutdown(wait=False)
            logger.info("定时任务调度器已停止")
        if self.pickup_system:
            self.pickup_system.stop()
        logger.info("系统已停止")

    def _signal_handler(self, signum, frame):
        """SIGINT/SIGTERM handler: stop everything, then exit with status 0."""
        logger.info(f"收到信号 {signum}")
        self.stop()
        sys.exit(0)
def main():
    """
    Script entry point: configure logging, then build and start the
    integrated system. start() blocks inside the monitoring main loop.
    """
    setup_logging()
    IntegratedSystem().start()


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|