run_with_auto_training.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. run_with_auto_training.py
  5. -------------------------
  6. 带自动训练的主程序
  7. 功能:
  8. 1. RTSP拾音器监控(通过 PickupMonitoringSystem)
  9. 2. 冷启动检测(无模型时等待数据收集后自动训练)
  10. 3. 每日增量训练(默认02:00,可配置)
  11. 4. 每日数据清理(默认00:00,可配置)
  12. 5. 训练完成后主动通知 MultiModelPredictor 重载,无需等待60秒轮询
  13. """
  14. import sys
  15. import signal
  16. import logging
  17. import threading
  18. import time
  19. from pathlib import Path
  20. from datetime import datetime
  21. # 确保项目根目录在 sys.path 中
  22. sys.path.insert(0, str(Path(__file__).parent))
  23. try:
  24. from apscheduler.schedulers.background import BackgroundScheduler
  25. from apscheduler.triggers.cron import CronTrigger
  26. import yaml
  27. except ImportError:
  28. print("错误:缺少依赖库")
  29. print("请运行:pip install apscheduler pyyaml")
  30. sys.exit(1)
  31. def setup_logging():
  32. # 配置日志系统(按文件大小轮转),与 run_pickup_monitor.py 共用同一日志文件
  33. from logging.handlers import RotatingFileHandler
  34. # 如果根 logger 已经被上层调用者配置过,则直接复用
  35. root = logging.getLogger()
  36. if root.handlers:
  37. return
  38. log_dir = Path(__file__).parent / "logs"
  39. log_dir.mkdir(parents=True, exist_ok=True)
  40. log_file = log_dir / "system.log"
  41. formatter = logging.Formatter(
  42. '%(asctime)s | %(levelname)-8s | %(name)-20s | %(message)s',
  43. datefmt='%Y-%m-%d %H:%M:%S'
  44. )
  45. # 10MB 单文件上限,保留 2 个备份(总计约 30MB)
  46. file_handler = RotatingFileHandler(
  47. log_file,
  48. maxBytes=10 * 1024 * 1024,
  49. backupCount=2,
  50. encoding='utf-8'
  51. )
  52. file_handler.setFormatter(formatter)
  53. console_handler = logging.StreamHandler(sys.stdout)
  54. console_handler.setFormatter(formatter)
  55. logging.basicConfig(
  56. level=logging.INFO,
  57. handlers=[file_handler, console_handler]
  58. )
  59. logger = logging.getLogger('MainSystem')
  60. class ColdStartManager:
  61. """
  62. 冷启动管理器
  63. 检查注册的设备是否都有对应的模型文件,
  64. 缺少模型的设备会被标记为冷启动状态,等待数据收集后自动训练。
  65. """
  66. def __init__(self, deploy_root: Path, config: dict):
  67. self.deploy_root = deploy_root
  68. self.config = config
  69. # 模型根目录(每设备子目录)
  70. self.model_root = deploy_root / "models"
  71. # 音频数据根目录
  72. self.audio_root = deploy_root / "data" / "audio"
  73. # 冷启动配置
  74. cold_start_cfg = config.get('auto_training', {}).get('cold_start', {})
  75. self.enabled = cold_start_cfg.get('enabled', True)
  76. # 等待收集数据的最短时长(小时),避免数据量过少就开始训练
  77. self.wait_hours = cold_start_cfg.get('wait_hours', 2)
  78. # 每设备最少样本数
  79. self.min_samples = cold_start_cfg.get('min_samples', 100)
  80. # 运行时状态
  81. self.is_cold_start = False
  82. self.cold_start_time = None
  83. # 缺少模型的设备列表
  84. self.missing_devices = []
  85. def check_cold_start_needed(self, registered_devices: dict) -> bool:
  86. """
  87. 检查是否有设备缺少模型
  88. 参数:
  89. registered_devices: {device_code: model_subdir} 已注册的设备映射
  90. 返回:
  91. True 表示至少有一个设备缺少模型,需要冷启动
  92. """
  93. if not self.enabled:
  94. return False
  95. self.missing_devices = []
  96. for device_code, model_subdir in registered_devices.items():
  97. model_path = self.model_root / model_subdir / "ae_model.pth"
  98. if not model_path.exists():
  99. self.missing_devices.append(device_code)
  100. logger.warning(f"模型不存在: {device_code} -> {model_path}")
  101. if self.missing_devices:
  102. logger.warning(f"以下设备缺少模型,进入冷启动模式: {self.missing_devices}")
  103. return True
  104. return False
  105. def start_cold_start_mode(self):
  106. # 进入冷启动模式,开始计时
  107. self.is_cold_start = True
  108. self.cold_start_time = datetime.now()
  109. logger.info("=" * 70)
  110. logger.info("冷启动模式")
  111. logger.info("=" * 70)
  112. logger.info(f"缺少模型的设备: {self.missing_devices}")
  113. logger.info(f"等待时长: {self.wait_hours} 小时")
  114. logger.info(f"每设备最少样本: {self.min_samples}")
  115. logger.info("系统将只采集音频,不进行检测(缺模型的设备)")
  116. logger.info("=" * 70)
  117. def check_ready_for_training(self) -> bool:
  118. """
  119. 检查是否已收集到足够的数据可以开始首次训练
  120. 条件:
  121. 1. 已等待超过 wait_hours 小时
  122. 2. 所有缺模型设备的音频文件数 >= min_samples
  123. 返回:
  124. True 表示数据已就绪
  125. """
  126. if not self.is_cold_start:
  127. return False
  128. # 条件1:等待时间
  129. elapsed = (datetime.now() - self.cold_start_time).total_seconds() / 3600
  130. if elapsed < self.wait_hours:
  131. return False
  132. # 条件2:逐设备检查数据量
  133. for device_code in self.missing_devices:
  134. device_dir = self.audio_root / device_code
  135. if not device_dir.exists():
  136. logger.info(f"数据目录不存在: {device_code}")
  137. return False
  138. # 统计 current 目录和日期目录下的 wav 文件数
  139. total_samples = 0
  140. for sub_dir in device_dir.iterdir():
  141. if sub_dir.is_dir():
  142. total_samples += len(list(sub_dir.glob("*.wav")))
  143. if total_samples < self.min_samples:
  144. logger.info(f"数据量不足: {device_code} -> {total_samples}/{self.min_samples}")
  145. return False
  146. logger.info("所有冷启动设备数据收集完成")
  147. return True
  148. def run_initial_training(self, on_device_trained=None) -> bool:
  149. """
  150. 执行首次训练(冷启动)
  151. 使用 IncrementalTrainer 的 cold_start_mode,收集所有可用数据进行全量训练。
  152. Args:
  153. on_device_trained: 可选回调 fn(device_code),单设备训练完成后调用
  154. 返回:
  155. True 表示训练成功
  156. """
  157. logger.info("=" * 70)
  158. logger.info("开始首次训练(冷启动)")
  159. logger.info("=" * 70)
  160. try:
  161. from auto_training.incremental_trainer import IncrementalTrainer
  162. config_file = self.deploy_root / "config" / "auto_training.yaml"
  163. trainer = IncrementalTrainer(config_file)
  164. # 冷启动模式:收集所有目录的数据,用全量训练
  165. trainer.cold_start_mode = True
  166. trainer.use_days_ago = 0
  167. trainer.sample_hours = 0
  168. # 首次训练用更多轮数和稍高学习率
  169. trainer.epochs = 90
  170. trainer.learning_rate = 0.001
  171. success = trainer.run_daily_training(
  172. on_device_trained=on_device_trained
  173. )
  174. if success:
  175. logger.info("=" * 70)
  176. logger.info("首次训练完成,切换到正常检测模式")
  177. logger.info("=" * 70)
  178. self.is_cold_start = False
  179. return success
  180. except Exception as e:
  181. logger.error(f"首次训练失败: {e}", exc_info=True)
  182. return False
  183. class IntegratedSystem:
  184. """
  185. 集成系统:冷启动 + 拾音器监控 + 自动增量训练 + 数据清理
  186. 作为 run_pickup_monitor.py 的上层封装,在其基础上增加:
  187. 1. 冷启动自动训练
  188. 2. APScheduler 定时增量训练和数据清理
  189. 3. 训练完成后主动触发模型热更新
  190. """
  191. def __init__(self):
  192. self.deploy_root = Path(__file__).parent
  193. self.auto_config_file = self.deploy_root / "config" / "auto_training.yaml"
  194. # 加载自动训练配置
  195. self.auto_config = self._load_yaml(self.auto_config_file)
  196. # 运行时对象
  197. self.scheduler = None
  198. self.pickup_system = None
  199. self.cold_start_manager = None
  200. self.cold_start_thread = None
  201. def _load_yaml(self, config_file: Path) -> dict:
  202. # 加载 YAML 配置文件,不存在时返回空字典
  203. if not config_file.exists():
  204. logger.warning(f"配置文件不存在: {config_file}")
  205. return {}
  206. with open(config_file, 'r', encoding='utf-8') as f:
  207. return yaml.safe_load(f) or {}
  208. def _check_and_handle_cold_start(self) -> bool:
  209. """
  210. 检查并处理冷启动
  211. 利用 PickupMonitoringSystem 中已注册的设备列表,
  212. 检查哪些设备缺少模型文件,对缺失设备启动冷启动训练。
  213. 返回:
  214. True 表示处于冷启动模式
  215. """
  216. self.cold_start_manager = ColdStartManager(
  217. self.deploy_root,
  218. self.auto_config
  219. )
  220. # 从 pickup_system 的 multi_predictor 获取已注册设备
  221. registered = self.pickup_system.multi_predictor.device_model_map
  222. if not self.cold_start_manager.check_cold_start_needed(registered):
  223. logger.info("所有设备模型完整,进入正常检测模式")
  224. return False
  225. # 进入冷启动模式
  226. self.cold_start_manager.start_cold_start_mode()
  227. # 启动后台线程监控冷启动状态
  228. self.cold_start_thread = threading.Thread(
  229. target=self._cold_start_monitor_loop,
  230. daemon=True,
  231. name="cold-start-monitor"
  232. )
  233. self.cold_start_thread.start()
  234. return True
  235. def _cold_start_monitor_loop(self):
  236. # 后台线程:每分钟检查冷启动数据是否就绪,就绪后触发首次训练
  237. while self.cold_start_manager.is_cold_start:
  238. time.sleep(60)
  239. if self.cold_start_manager.check_ready_for_training():
  240. success = self.cold_start_manager.run_initial_training(
  241. on_device_trained=self._reload_single_device
  242. )
  243. if success:
  244. break
  245. else:
  246. # 训练失败,等待10分钟后重试
  247. logger.warning("首次训练失败,将在10分钟后重试")
  248. time.sleep(600)
  249. def _reload_single_device(self, device_code: str):
  250. """训练回调:单个设备训练完成后即时重载该设备模型"""
  251. if not self.pickup_system:
  252. return
  253. mp = self.pickup_system.multi_predictor
  254. try:
  255. if mp.reload_device(device_code):
  256. logger.info(f"模型即时重载成功: {device_code}")
  257. else:
  258. logger.warning(f"模型即时重载失败: {device_code}")
  259. except Exception as e:
  260. logger.error(f"模型即时重载异常: {device_code} | {e}")
  261. def _reload_models_after_training(self):
  262. """
  263. 训练完成后主动触发 MultiModelPredictor 重载所有设备模型
  264. 相比被动等待60秒轮询检测 mtime 变化,主动调用 reload_device() 可以:
  265. 1. 即时生效(零延迟)
  266. 2. 有明确的成功/失败反馈
  267. """
  268. if not self.pickup_system:
  269. return
  270. mp = self.pickup_system.multi_predictor
  271. success_count = 0
  272. fail_count = 0
  273. for device_code in mp.registered_devices:
  274. try:
  275. if mp.reload_device(device_code):
  276. success_count += 1
  277. else:
  278. fail_count += 1
  279. except Exception as e:
  280. logger.error(f"重载设备模型失败: {device_code} | {e}")
  281. fail_count += 1
  282. logger.info(f"模型重载完成: 成功={success_count}, 失败={fail_count}")
  283. def _setup_auto_training_tasks(self):
  284. # 配置 APScheduler 定时任务:增量训练 + 数据清理
  285. if not self.auto_config.get('auto_training', {}).get('enabled', False):
  286. logger.info("自动训练已禁用(auto_training.enabled=false),跳过定时任务配置")
  287. return
  288. logger.info("=" * 70)
  289. logger.info("配置定时任务")
  290. logger.info("=" * 70)
  291. self.scheduler = BackgroundScheduler()
  292. # 定时增量训练
  293. incremental_cfg = self.auto_config['auto_training'].get('incremental', {})
  294. if incremental_cfg.get('enabled', False):
  295. schedule_time = incremental_cfg.get('schedule_time', '02:00')
  296. hour, minute = map(int, schedule_time.split(':'))
  297. self.scheduler.add_job(
  298. self._run_incremental_training,
  299. trigger=CronTrigger(hour=hour, minute=minute),
  300. id='incremental_training',
  301. name='每日增量训练',
  302. misfire_grace_time=3600 # 错过1小时内仍执行
  303. )
  304. logger.info(f"每日增量训练: 每天 {schedule_time}")
  305. # 定时数据清理
  306. data_cfg = self.auto_config['auto_training'].get('data', {})
  307. cleanup_time = data_cfg.get('cleanup_time', '00:00')
  308. hour, minute = map(int, cleanup_time.split(':'))
  309. self.scheduler.add_job(
  310. self._run_data_cleanup,
  311. trigger=CronTrigger(hour=hour, minute=minute),
  312. id='data_cleanup',
  313. name='每日数据清理',
  314. misfire_grace_time=3600
  315. )
  316. logger.info(f"每日数据清理: 每天 {cleanup_time}")
  317. self.scheduler.start()
  318. logger.info("定时任务调度器已启动")
  319. logger.info("=" * 70)
  320. def _run_incremental_training(self):
  321. # 定时任务回调:执行增量训练(逐设备串行,每完成一个即时重载)
  322. try:
  323. logger.info("定时任务触发:增量训练开始")
  324. from auto_training.incremental_trainer import IncrementalTrainer
  325. trainer = IncrementalTrainer(self.auto_config_file)
  326. success = trainer.run_daily_training(
  327. on_device_trained=self._reload_single_device
  328. )
  329. if not success:
  330. logger.warning("增量训练返回失败,保持当前推理模型不变")
  331. except Exception as e:
  332. logger.error(f"增量训练异常: {e}", exc_info=True)
  333. def _run_data_cleanup(self):
  334. # 定时任务回调:执行每日数据清理
  335. try:
  336. logger.info("定时任务触发:数据清理开始")
  337. from auto_training.data_cleanup import DataCleaner
  338. cleaner = DataCleaner(self.auto_config_file)
  339. cleaner.run_cleanup()
  340. except Exception as e:
  341. logger.error(f"数据清理异常: {e}", exc_info=True)
  342. def start(self):
  343. # 主启动流程
  344. logger.info("=" * 70)
  345. logger.info("拾音器异响检测系统(带自动训练)")
  346. logger.info("=" * 70)
  347. # 1. 创建 PickupMonitoringSystem(会初始化 multi_predictor + 注册设备)
  348. logger.info("初始化监控系统...")
  349. from run_pickup_monitor import PickupMonitoringSystem
  350. self.pickup_system = PickupMonitoringSystem()
  351. # 2. 检查冷启动(需要在 pickup_system 初始化之后,因为需要设备注册信息)
  352. is_cold_start = self._check_and_handle_cold_start()
  353. # 3. 设置定时任务
  354. self._setup_auto_training_tasks()
  355. # 4. 覆盖信号处理(确保优雅关闭 scheduler)
  356. signal.signal(signal.SIGINT, self._signal_handler)
  357. signal.signal(signal.SIGTERM, self._signal_handler)
  358. # 5. 启动拾音器监控(这是阻塞调用,包含主循环)
  359. logger.info("启动拾音器监控...")
  360. self.pickup_system.start()
  361. def stop(self):
  362. # 关闭所有组件
  363. logger.info("停止系统...")
  364. # 先关 scheduler,避免训练任务在关停过程中触发
  365. if self.scheduler and self.scheduler.running:
  366. self.scheduler.shutdown(wait=False)
  367. logger.info("定时任务调度器已停止")
  368. if self.pickup_system:
  369. self.pickup_system.stop()
  370. logger.info("系统已停止")
  371. def _signal_handler(self, signum, frame):
  372. logger.info(f"收到信号 {signum}")
  373. self.stop()
  374. sys.exit(0)
  375. def main():
  376. setup_logging()
  377. system = IntegratedSystem()
  378. system.start()
  379. if __name__ == "__main__":
  380. main()