run_with_auto_training.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. run_with_auto_training.py
  5. -------------------------
  6. 带自动训练的主程序
  7. 功能:
  8. 1. RTSP拾音器监控(通过 PickupMonitoringSystem)
  9. 2. 冷启动检测(无模型时等待数据收集后自动训练)
  10. 3. 每日增量训练(默认02:00,可配置)
  11. 4. 每日数据清理(默认00:00,可配置)
  12. 5. 训练完成后主动通知 MultiModelPredictor 重载,无需等待60秒轮询
  13. """
  14. import sys
  15. import signal
  16. import logging
  17. import threading
  18. import time
  19. from pathlib import Path
  20. from datetime import datetime
  21. # 确保项目根目录在 sys.path 中
  22. sys.path.insert(0, str(Path(__file__).parent))
  23. try:
  24. from apscheduler.schedulers.background import BackgroundScheduler
  25. from apscheduler.triggers.cron import CronTrigger
  26. except ImportError:
  27. print("错误:缺少依赖库")
  28. print("请运行:pip install apscheduler")
  29. sys.exit(1)
  30. from config.config_manager import ConfigManager
  31. def setup_logging():
  32. # 配置日志系统(按文件大小轮转),与 run_pickup_monitor.py 共用同一日志文件
  33. from logging.handlers import RotatingFileHandler
  34. # 如果根 logger 已经被上层调用者配置过,则直接复用
  35. root = logging.getLogger()
  36. if root.handlers:
  37. return
  38. log_dir = Path(__file__).parent / "logs"
  39. log_dir.mkdir(parents=True, exist_ok=True)
  40. log_file = log_dir / "system.log"
  41. formatter = logging.Formatter(
  42. '%(asctime)s | %(levelname)-8s | %(name)-20s | %(message)s',
  43. datefmt='%Y-%m-%d %H:%M:%S'
  44. )
  45. # 10MB 单文件上限,保留 2 个备份(总计约 30MB)
  46. file_handler = RotatingFileHandler(
  47. log_file,
  48. maxBytes=10 * 1024 * 1024,
  49. backupCount=2,
  50. encoding='utf-8'
  51. )
  52. file_handler.setFormatter(formatter)
  53. console_handler = logging.StreamHandler(sys.stdout)
  54. console_handler.setFormatter(formatter)
  55. logging.basicConfig(
  56. level=logging.INFO,
  57. handlers=[file_handler, console_handler]
  58. )
# Module-level logger used by ColdStartManager and IntegratedSystem below.
logger = logging.getLogger('MainSystem')
  60. class ColdStartManager:
  61. """
  62. 冷启动管理器
  63. 检查注册的设备是否都有对应的模型文件,
  64. 缺少模型的设备会被标记为冷启动状态,等待数据收集后自动训练。
  65. """
  66. def __init__(self, deploy_root: Path, config: dict):
  67. self.deploy_root = deploy_root
  68. self.config = config
  69. # 模型根目录(每设备子目录)
  70. self.model_root = deploy_root / "models"
  71. # 音频数据根目录
  72. self.audio_root = deploy_root / "data" / "audio"
  73. # self.audio_root = "/Volumes/mo/水厂正常音频/龙亭"
  74. # 冷启动配置
  75. cold_start_cfg = config.get('auto_training', {}).get('cold_start', {})
  76. self.enabled = cold_start_cfg.get('enabled', True)
  77. # 等待收集数据的最短时长(小时),避免数据量过少就开始训练
  78. self.wait_hours = cold_start_cfg.get('wait_hours', 2)
  79. # 每设备最少样本数
  80. self.min_samples = cold_start_cfg.get('min_samples', 100)
  81. # 运行时状态
  82. self.is_cold_start = False
  83. self.cold_start_time = None
  84. # 缺少模型的设备列表
  85. self.missing_devices = []
  86. def check_cold_start_needed(self, registered_devices: dict) -> bool:
  87. """
  88. 检查是否有设备缺少模型
  89. 参数:
  90. registered_devices: {device_code: model_subdir} 已注册的设备映射
  91. 返回:
  92. True 表示至少有一个设备缺少模型,需要冷启动
  93. """
  94. if not self.enabled:
  95. return False
  96. self.missing_devices = []
  97. for device_code, model_subdir in registered_devices.items():
  98. model_path = self.model_root / model_subdir / "ae_model.pth"
  99. if not model_path.exists():
  100. self.missing_devices.append(device_code)
  101. logger.warning(f"模型不存在: {device_code} -> {model_path}")
  102. if self.missing_devices:
  103. logger.warning(f"以下设备缺少模型,进入冷启动模式: {self.missing_devices}")
  104. return True
  105. return False
  106. def start_cold_start_mode(self):
  107. # 进入冷启动模式,开始计时
  108. self.is_cold_start = True
  109. self.cold_start_time = datetime.now()
  110. logger.info("=" * 70)
  111. logger.info("冷启动模式")
  112. logger.info("=" * 70)
  113. logger.info(f"缺少模型的设备: {self.missing_devices}")
  114. logger.info(f"等待时长: {self.wait_hours} 小时")
  115. logger.info(f"每设备最少样本: {self.min_samples}")
  116. logger.info("系统将只采集音频,不进行检测(缺模型的设备)")
  117. logger.info("=" * 70)
  118. def check_ready_for_training(self) -> bool:
  119. """
  120. 检查是否已收集到足够的数据可以开始首次训练
  121. 条件:
  122. 1. 已等待超过 wait_hours 小时
  123. 2. 所有缺模型设备的音频文件数 >= min_samples
  124. 返回:
  125. True 表示数据已就绪
  126. """
  127. if not self.is_cold_start:
  128. return False
  129. # 条件1:等待时间
  130. elapsed = (datetime.now() - self.cold_start_time).total_seconds() / 3600
  131. if elapsed < self.wait_hours:
  132. return False
  133. # 条件2:逐设备检查数据量
  134. for device_code in self.missing_devices:
  135. device_dir = self.audio_root / device_code
  136. if not device_dir.exists():
  137. logger.info(f"数据目录不存在: {device_code}")
  138. return False
  139. # 统计 current 目录和日期目录下的 wav 文件数
  140. total_samples = 0
  141. for sub_dir in device_dir.iterdir():
  142. if sub_dir.is_dir():
  143. # 新结构:{date}/normal/ 子目录
  144. normal_dir = sub_dir / "normal"
  145. if normal_dir.exists():
  146. total_samples += len(list(normal_dir.glob("*.wav")))
  147. # 兼容旧结构 + current 目录:直接存放的 wav
  148. total_samples += len(list(sub_dir.glob("*.wav")))
  149. if total_samples < self.min_samples:
  150. logger.info(f"数据量不足: {device_code} -> {total_samples}/{self.min_samples}")
  151. return False
  152. logger.info("所有冷启动设备数据收集完成")
  153. return True
  154. def run_initial_training(self, on_device_trained=None) -> bool:
  155. """
  156. 执行首次训练(冷启动)
  157. 使用 IncrementalTrainer 的 cold_start_mode,收集所有可用数据进行全量训练。
  158. Args:
  159. on_device_trained: 可选回调 fn(device_code),单设备训练完成后调用
  160. 返回:
  161. True 表示训练成功
  162. """
  163. logger.info("=" * 70)
  164. logger.info("开始首次训练(冷启动)")
  165. logger.info("=" * 70)
  166. try:
  167. from auto_training.incremental_trainer import IncrementalTrainer
  168. # 从当前内存中的配置 dict 初始化训练器(配置来源为数据库)
  169. trainer = IncrementalTrainer(config=self.config)
  170. # 冷启动模式:收集所有目录的数据,用全量训练
  171. trainer.cold_start_mode = True
  172. trainer.use_days_ago = 0
  173. trainer.sample_hours = 0
  174. # 首次训练用更多轮数和稍高学习率
  175. trainer.epochs = 90
  176. trainer.learning_rate = 0.001
  177. success = trainer.run_daily_training(
  178. on_device_trained=on_device_trained
  179. )
  180. if success:
  181. logger.info("=" * 70)
  182. logger.info("首次训练完成,切换到正常检测模式")
  183. logger.info("=" * 70)
  184. self.is_cold_start = False
  185. return success
  186. except Exception as e:
  187. logger.error(f"首次训练失败: {e}", exc_info=True)
  188. return False
  189. class IntegratedSystem:
  190. """
  191. 集成系统:冷启动 + 拾音器监控 + 自动增量训练 + 数据清理
  192. 作为 run_pickup_monitor.py 的上层封装,在其基础上增加:
  193. 1. 冷启动自动训练
  194. 2. APScheduler 定时增量训练和数据清理
  195. 3. 训练完成后主动触发模型热更新
  196. """
  197. def __init__(self):
  198. self.deploy_root = Path(__file__).parent
  199. # =========================================================================
  200. # 配置加载来源:自动检测
  201. # 优先使用 YAML(config/rtsp_config.yaml 存在时)
  202. # 否则使用 SQLite 数据库
  203. # =========================================================================
  204. yaml_path = self.deploy_root / "config" / "rtsp_config.yaml"
  205. if yaml_path.exists():
  206. import yaml
  207. with open(yaml_path, 'r', encoding='utf-8') as f:
  208. full_config = yaml.safe_load(f)
  209. self.full_yaml_config = full_config
  210. self.auto_config = {'auto_training': full_config.get('auto_training', {})}
  211. logger.info(f"已从 YAML ({yaml_path.name}) 加载配置")
  212. else:
  213. self.full_yaml_config = None
  214. mgr = ConfigManager()
  215. self.auto_config = {'auto_training': mgr.get_system_config('auto_training')}
  216. mgr.close()
  217. logger.info(f"已从数据库加载 auto_training 配置 ({len(self.auto_config.get('auto_training', {}))} 项)")
  218. # 运行时对象
  219. self.scheduler = None
  220. self.pickup_system = None
  221. self.cold_start_manager = None
  222. self.cold_start_thread = None
  223. def _check_and_handle_cold_start(self) -> bool:
  224. """
  225. 检查并处理冷启动
  226. 利用 PickupMonitoringSystem 中已注册的设备列表,
  227. 检查哪些设备缺少模型文件,对缺失设备启动冷启动训练。
  228. 返回:
  229. True 表示处于冷启动模式
  230. """
  231. self.cold_start_manager = ColdStartManager(
  232. self.deploy_root,
  233. self.auto_config
  234. )
  235. # 从 pickup_system 的 multi_predictor 获取已注册设备
  236. registered = self.pickup_system.multi_predictor.device_model_map
  237. if not self.cold_start_manager.check_cold_start_needed(registered):
  238. logger.info("所有设备模型完整,进入正常检测模式")
  239. return False
  240. # 进入冷启动模式
  241. self.cold_start_manager.start_cold_start_mode()
  242. # 启动后台线程监控冷启动状态
  243. self.cold_start_thread = threading.Thread(
  244. target=self._cold_start_monitor_loop,
  245. daemon=True,
  246. name="cold-start-monitor"
  247. )
  248. self.cold_start_thread.start()
  249. return True
  250. def _cold_start_monitor_loop(self):
  251. # 后台线程:每分钟检查冷启动数据是否就绪,就绪后触发首次训练
  252. while self.cold_start_manager.is_cold_start:
  253. time.sleep(60)
  254. if self.cold_start_manager.check_ready_for_training():
  255. success = self.cold_start_manager.run_initial_training(
  256. on_device_trained=self._reload_single_device
  257. )
  258. if success:
  259. break
  260. else:
  261. # 训练失败,等待10分钟后重试
  262. logger.warning("首次训练失败,将在10分钟后重试")
  263. time.sleep(600)
  264. def _reload_single_device(self, device_code: str):
  265. """训练回调:单个设备训练完成后即时重载该设备模型"""
  266. if not self.pickup_system:
  267. return
  268. mp = self.pickup_system.multi_predictor
  269. try:
  270. if mp.reload_device(device_code):
  271. logger.info(f"模型即时重载成功: {device_code}")
  272. else:
  273. logger.warning(f"模型即时重载失败: {device_code}")
  274. except Exception as e:
  275. logger.error(f"模型即时重载异常: {device_code} | {e}")
  276. def _reload_models_after_training(self):
  277. """
  278. 训练完成后主动触发 MultiModelPredictor 重载所有设备模型
  279. 相比被动等待60秒轮询检测 mtime 变化,主动调用 reload_device() 可以:
  280. 1. 即时生效(零延迟)
  281. 2. 有明确的成功/失败反馈
  282. """
  283. if not self.pickup_system:
  284. return
  285. mp = self.pickup_system.multi_predictor
  286. success_count = 0
  287. fail_count = 0
  288. for device_code in mp.registered_devices:
  289. try:
  290. if mp.reload_device(device_code):
  291. success_count += 1
  292. else:
  293. fail_count += 1
  294. except Exception as e:
  295. logger.error(f"重载设备模型失败: {device_code} | {e}")
  296. fail_count += 1
  297. logger.info(f"模型重载完成: 成功={success_count}, 失败={fail_count}")
  298. def _setup_auto_training_tasks(self):
  299. # 配置 APScheduler 定时任务:增量训练 + 数据清理
  300. if not self.auto_config.get('auto_training', {}).get('enabled', False):
  301. logger.info("自动训练已禁用(auto_training.enabled=false),跳过定时任务配置")
  302. return
  303. logger.info("=" * 70)
  304. logger.info("配置定时任务")
  305. logger.info("=" * 70)
  306. self.scheduler = BackgroundScheduler()
  307. # 定时增量训练
  308. incremental_cfg = self.auto_config['auto_training'].get('incremental', {})
  309. if incremental_cfg.get('enabled', False):
  310. schedule_time = incremental_cfg.get('schedule_time', '02:00')
  311. hour, minute = map(int, schedule_time.split(':'))
  312. self.scheduler.add_job(
  313. self._run_incremental_training,
  314. trigger=CronTrigger(hour=hour, minute=minute),
  315. id='incremental_training',
  316. name='每日增量训练',
  317. misfire_grace_time=3600 # 错过1小时内仍执行
  318. )
  319. logger.info(f"每日增量训练: 每天 {schedule_time}")
  320. # 定时数据清理
  321. data_cfg = self.auto_config['auto_training'].get('data', {})
  322. cleanup_time = data_cfg.get('cleanup_time', '00:00')
  323. hour, minute = map(int, cleanup_time.split(':'))
  324. self.scheduler.add_job(
  325. self._run_data_cleanup,
  326. trigger=CronTrigger(hour=hour, minute=minute),
  327. id='data_cleanup',
  328. name='每日数据清理',
  329. misfire_grace_time=3600
  330. )
  331. logger.info(f"每日数据清理: 每天 {cleanup_time}")
  332. self.scheduler.start()
  333. logger.info("定时任务调度器已启动")
  334. logger.info("=" * 70)
  335. def _run_incremental_training(self):
  336. # 定时任务回调:执行增量训练(逐设备串行,每完成一个即时重载)
  337. try:
  338. logger.info("定时任务触发:增量训练开始")
  339. from auto_training.incremental_trainer import IncrementalTrainer
  340. # 传 config dict 而非 YAML 路径,配置来源为数据库
  341. trainer = IncrementalTrainer(config=self.auto_config)
  342. success = trainer.run_daily_training(
  343. on_device_trained=self._reload_single_device
  344. )
  345. if not success:
  346. logger.warning("增量训练返回失败,保持当前推理模型不变")
  347. except Exception as e:
  348. logger.error(f"增量训练异常: {e}", exc_info=True)
  349. def _run_data_cleanup(self):
  350. # 定时任务回调:执行每日数据清理
  351. try:
  352. logger.info("定时任务触发:数据清理开始")
  353. from auto_training.data_cleanup import DataCleaner
  354. # 传 config dict 而非 YAML 路径,配置来源为数据库
  355. cleaner = DataCleaner(config=self.auto_config)
  356. cleaner.run_cleanup()
  357. except Exception as e:
  358. logger.error(f"数据清理异常: {e}", exc_info=True)
  359. def start(self):
  360. # 主启动流程
  361. logger.info("=" * 70)
  362. logger.info("拾音器异响检测系统(带自动训练)")
  363. logger.info("=" * 70)
  364. # 1. 创建 PickupMonitoringSystem(会初始化 multi_predictor + 注册设备)
  365. logger.info("初始化监控系统...")
  366. from run_pickup_monitor import PickupMonitoringSystem
  367. self.pickup_system = PickupMonitoringSystem(yaml_config=self.full_yaml_config)
  368. # 2. 检查冷启动(需要在 pickup_system 初始化之后,因为需要设备注册信息)
  369. is_cold_start = self._check_and_handle_cold_start()
  370. # 3. 设置定时任务
  371. self._setup_auto_training_tasks()
  372. # 4. 覆盖信号处理(确保关闭 scheduler)
  373. signal.signal(signal.SIGINT, self._signal_handler)
  374. signal.signal(signal.SIGTERM, self._signal_handler)
  375. # 5. 启动拾音器监控(这是阻塞调用,包含主循环)
  376. logger.info("启动拾音器监控...")
  377. self.pickup_system.start()
  378. def stop(self):
  379. # 关闭所有组件
  380. logger.info("停止系统...")
  381. # 先关 scheduler,避免训练任务在关停过程中触发
  382. if self.scheduler and self.scheduler.running:
  383. self.scheduler.shutdown(wait=False)
  384. logger.info("定时任务调度器已停止")
  385. if self.pickup_system:
  386. self.pickup_system.stop()
  387. logger.info("系统已停止")
  388. def _signal_handler(self, signum, frame):
  389. logger.info(f"收到信号 {signum}")
  390. self.stop()
  391. sys.exit(0)
  392. def main():
  393. setup_logging()
  394. system = IntegratedSystem()
  395. system.start()
  396. if __name__ == "__main__":
  397. main()