#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ standalone_train.py ------------------- 独立全量训练 CLI 入口 使用方法: # 训练指定数据目录下的所有设备 python auto_training/standalone_train.py --data-dir /path/to/plant_audio_data # 只训练指定设备 python auto_training/standalone_train.py --data-dir /path/to/data --devices LT-2 LT-5 # 自定义训练参数 python auto_training/standalone_train.py --data-dir /path/to/data --epochs 100 --lr 0.0001 数据目录结构约定: data_dir/ ├── LT-2/ <-- 子文件夹名 = device_code │ ├── xxx.wav <-- 扁平结构 │ └── ... └── LT-5/ ├── 20260301/ <-- 或日期嵌套结构 │ ├── yyy.wav │ └── ... └── 20260302/ └── ... 产出目录结构: models/ ├── LT-2/ │ ├── ae_model.pth │ ├── global_scale.npy │ └── thresholds/ │ └── threshold_LT-2.npy └── LT-5/ ├── ae_model.pth ├── global_scale.npy └── thresholds/ └── threshold_LT-5.npy """ import sys import argparse import logging from pathlib import Path # 添加项目根目录到路径 sys.path.insert(0, str(Path(__file__).parent.parent)) def main(): parser = argparse.ArgumentParser( description="泵异响检测模型 - 全量训练工具", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 示例: python standalone_train.py --data-dir /data/xishan_audio python standalone_train.py --data-dir /data/xishan_audio --devices LT-2 LT-5 python standalone_train.py --data-dir /data/xishan_audio --epochs 100 --lr 0.00005 """ ) parser.add_argument( "--data-dir", type=str, required=True, help="训练数据目录路径(子文件夹名=设备编码)" ) parser.add_argument( "--epochs", type=int, default=None, help="训练轮数(默认使用 auto_training.yaml 中的配置)" ) parser.add_argument( "--lr", type=float, default=None, help="学习率(默认使用 auto_training.yaml 中的配置)" ) parser.add_argument( "--batch-size", type=int, default=None, help="批大小(默认使用 auto_training.yaml 中的配置)" ) parser.add_argument( "--devices", nargs="+", default=None, help="只训练指定设备(空格分隔,如 --devices LT-2 LT-5)" ) parser.add_argument( "--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="日志级别(默认 INFO)" ) args = parser.parse_args() # 配置日志 logging.basicConfig( level=getattr(logging, args.log_level), format='%(asctime)s | %(levelname)-8s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) logger = logging.getLogger('StandaloneTrain') # 验证数据目录 data_dir = Path(args.data_dir) if not data_dir.exists(): logger.error(f"数据目录不存在: {data_dir}") sys.exit(1) # 加载配置 config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml" if not config_file.exists(): logger.error(f"配置文件不存在: {config_file}") sys.exit(1) # 初始化训练器 from auto_training.incremental_trainer import IncrementalTrainer trainer = IncrementalTrainer(config_file) # 覆盖 batch_size(如果指定) if args.batch_size: trainer.config['auto_training']['incremental']['batch_size'] = args.batch_size # 执行全量训练 logger.info("=" * 60) logger.info("泵异响检测模型 - 全量训练") logger.info(f"数据目录: {data_dir}") logger.info(f"设备过滤: {args.devices or '全部'}") logger.info(f"训练轮数: {args.epochs or '配置默认'}") logger.info(f"学习率: {args.lr or '配置默认'}") logger.info(f"批大小: {args.batch_size or '配置默认'}") logger.info("=" * 60) success = trainer.run_full_training( data_dir=data_dir, epochs=args.epochs, lr=args.lr, device_filter=args.devices ) if success: logger.info("训练全部完成,模型已部署到 models/ 目录") # 列出产出文件 model_root = trainer.model_root for device_dir in sorted(model_root.iterdir()): if device_dir.is_dir() and (device_dir / "ae_model.pth").exists(): threshold_dir = device_dir / "thresholds" threshold_files = list(threshold_dir.glob("*.npy")) if threshold_dir.exists() else [] logger.info(f" {device_dir.name}/: " f"model=OK, scale=OK, thresholds={len(threshold_files)}") else: logger.error("训练存在失败,请检查日志") sys.exit(0 if success else 1) if __name__ == "__main__": main()