| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
standalone_train.py
-------------------
Standalone CLI entry point for full training.

Usage:
    # Train every device found under the given data directory
    python auto_training/standalone_train.py --data-dir /path/to/plant_audio_data
    # Train only the listed devices
    python auto_training/standalone_train.py --data-dir /path/to/data --devices LT-2 LT-5
    # Override training parameters
    python auto_training/standalone_train.py --data-dir /path/to/data --epochs 100 --lr 0.0001

Expected data directory layout:
    data_dir/
    ├── LT-2/                  <-- sub-folder name = device_code
    │   ├── xxx.wav            <-- flat layout
    │   └── ...
    └── LT-5/
        ├── 20260301/          <-- or a date-nested layout
        │   ├── yyy.wav
        │   └── ...
        └── 20260302/
            └── ...

Output directory layout:
    models/
    ├── LT-2/
    │   ├── ae_model.pth
    │   ├── global_scale.npy
    │   └── thresholds/
    │       └── threshold_LT-2.npy
    └── LT-5/
        ├── ae_model.pth
        ├── global_scale.npy
        └── thresholds/
            └── threshold_LT-5.npy
- """
- import sys
- import argparse
- import logging
- from pathlib import Path
- # 添加项目根目录到路径
- sys.path.insert(0, str(Path(__file__).parent.parent))
def _positive_int(text):
    """Argparse ``type`` callable: parse *text* as a strictly positive int.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is <= 0.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"必须是正整数: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"必须是正整数: {text!r}")
    return value


def _positive_float(text):
    """Argparse ``type`` callable: parse *text* as a finite float > 0.

    Raises:
        argparse.ArgumentTypeError: if *text* is not a number, is NaN/inf,
            or is <= 0.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"必须是正数: {text!r}") from None
    # The chained comparison is False for NaN as well as for +/-inf.
    if not (0 < value < float("inf")):
        raise argparse.ArgumentTypeError(f"必须是正数: {text!r}")
    return value


def _build_parser():
    """Create the command-line parser for the full-training tool."""
    parser = argparse.ArgumentParser(
        description="泵异响检测模型 - 全量训练工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
python standalone_train.py --data-dir /data/xishan_audio
python standalone_train.py --data-dir /data/xishan_audio --devices LT-2 LT-5
python standalone_train.py --data-dir /data/xishan_audio --epochs 100 --lr 0.00005
""",
    )
    parser.add_argument(
        "--data-dir", type=str, required=True,
        help="训练数据目录路径(子文件夹名=设备编码)",
    )
    parser.add_argument(
        "--epochs", type=_positive_int, default=None,
        help="训练轮数(默认使用 auto_training.yaml 中的配置)",
    )
    parser.add_argument(
        "--lr", type=_positive_float, default=None,
        help="学习率(默认使用 auto_training.yaml 中的配置)",
    )
    parser.add_argument(
        "--batch-size", type=_positive_int, default=None,
        help="批大小(默认使用 auto_training.yaml 中的配置)",
    )
    parser.add_argument(
        "--devices", nargs="+", default=None,
        help="只训练指定设备(空格分隔,如 --devices LT-2 LT-5)",
    )
    parser.add_argument(
        "--log-level", type=str, default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="日志级别(默认 INFO)",
    )
    return parser


def _log_artifacts(model_root, logger):
    """Log one summary line per device directory that holds ``ae_model.pth``.

    The scale status is derived from whether ``global_scale.npy`` is actually
    present, and a missing model root is reported instead of raising.
    """
    model_root = Path(model_root)
    if not model_root.is_dir():
        logger.warning(f"模型目录不存在: {model_root}")
        return
    for device_dir in sorted(model_root.iterdir()):
        if not (device_dir.is_dir() and (device_dir / "ae_model.pth").exists()):
            continue
        threshold_dir = device_dir / "thresholds"
        threshold_files = list(threshold_dir.glob("*.npy")) if threshold_dir.exists() else []
        scale_status = "OK" if (device_dir / "global_scale.npy").exists() else "MISSING"
        logger.info(f" {device_dir.name}/: "
                    f"model=OK, scale={scale_status}, thresholds={len(threshold_files)}")


def main(argv=None):
    """Run the full-training CLI and terminate the process with its status.

    Parses the command line, configures logging, validates the data directory
    and the ``config/auto_training.yaml`` file located one level above this
    file's directory, then delegates to
    ``IncrementalTrainer.run_full_training``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what a plain ``main()``
            call has always done.

    Raises:
        SystemExit: always. Code 0 when training reports success, 1 when the
            data directory or config file is invalid or training reports
            failure, 2 when argparse rejects the command line.
    """
    args = _build_parser().parse_args(argv)

    # Configure logging.
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    logger = logging.getLogger('StandaloneTrain')

    # Validate the data directory; a regular file must not be accepted.
    data_dir = Path(args.data_dir)
    if not data_dir.exists():
        logger.error(f"数据目录不存在: {data_dir}")
        sys.exit(1)
    if not data_dir.is_dir():
        logger.error(f"数据目录不是文件夹: {data_dir}")
        sys.exit(1)

    # Locate the training configuration.
    config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml"
    if not config_file.exists():
        logger.error(f"配置文件不存在: {config_file}")
        sys.exit(1)

    # Imported only after the cheap validations above have passed.
    from auto_training.incremental_trainer import IncrementalTrainer
    trainer = IncrementalTrainer(config_file)

    # Override batch_size only when it was given on the command line.
    if args.batch_size is not None:
        trainer.config['auto_training']['incremental']['batch_size'] = args.batch_size

    # Run the full training.
    logger.info("=" * 60)
    logger.info("泵异响检测模型 - 全量训练")
    logger.info(f"数据目录: {data_dir}")
    logger.info(f"设备过滤: {args.devices or '全部'}")
    logger.info(f"训练轮数: {args.epochs or '配置默认'}")
    logger.info(f"学习率: {args.lr or '配置默认'}")
    logger.info(f"批大小: {args.batch_size or '配置默认'}")
    logger.info("=" * 60)
    success = trainer.run_full_training(
        data_dir=data_dir,
        epochs=args.epochs,
        lr=args.lr,
        device_filter=args.devices,
    )

    if success:
        logger.info("训练全部完成,模型已部署到 models/ 目录")
        # List the produced artifacts.
        _log_artifacts(trainer.model_root, logger)
    else:
        logger.error("训练存在失败,请检查日志")
    sys.exit(0 if success else 1)
# Script entry point: run the CLI only when this file is executed directly,
# so that importing the module never starts a training run.
if __name__ == "__main__":
    main()
|