standalone_train.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. standalone_train.py
  5. -------------------
  6. 独立全量训练 CLI 入口
  7. 使用方法:
  8. # 训练指定数据目录下的所有设备
  9. python auto_training/standalone_train.py --data-dir /path/to/plant_audio_data
  10. # 只训练指定设备
  11. python auto_training/standalone_train.py --data-dir /path/to/data --devices LT-2 LT-5
  12. # 自定义训练参数
  13. python auto_training/standalone_train.py --data-dir /path/to/data --epochs 100 --lr 0.0001
  14. 数据目录结构约定:
  15. data_dir/
  16. ├── LT-2/ <-- 子文件夹名 = device_code
  17. │ ├── xxx.wav <-- 扁平结构
  18. │ └── ...
  19. └── LT-5/
  20. ├── 20260301/ <-- 或日期嵌套结构
  21. │ ├── yyy.wav
  22. │ └── ...
  23. └── 20260302/
  24. └── ...
  25. 产出目录结构:
  26. models/
  27. ├── LT-2/
  28. │ ├── ae_model.pth
  29. │ ├── global_scale.npy
  30. │ └── thresholds/
  31. │ └── threshold_LT-2.npy
  32. └── LT-5/
  33. ├── ae_model.pth
  34. ├── global_scale.npy
  35. └── thresholds/
  36. └── threshold_LT-5.npy
  37. """
  38. import sys
  39. import argparse
  40. import logging
  41. from pathlib import Path
  42. # 添加项目根目录到路径
  43. sys.path.insert(0, str(Path(__file__).parent.parent))
  44. def main():
  45. parser = argparse.ArgumentParser(
  46. description="泵异响检测模型 - 全量训练工具",
  47. formatter_class=argparse.RawDescriptionHelpFormatter,
  48. epilog="""
  49. 示例:
  50. python standalone_train.py --data-dir /data/xishan_audio
  51. python standalone_train.py --data-dir /data/xishan_audio --devices LT-2 LT-5
  52. python standalone_train.py --data-dir /data/xishan_audio --epochs 100 --lr 0.00005
  53. """
  54. )
  55. parser.add_argument(
  56. "--data-dir", type=str, required=True,
  57. help="训练数据目录路径(子文件夹名=设备编码)"
  58. )
  59. parser.add_argument(
  60. "--epochs", type=int, default=None,
  61. help="训练轮数(默认使用 auto_training.yaml 中的配置)"
  62. )
  63. parser.add_argument(
  64. "--lr", type=float, default=None,
  65. help="学习率(默认使用 auto_training.yaml 中的配置)"
  66. )
  67. parser.add_argument(
  68. "--batch-size", type=int, default=None,
  69. help="批大小(默认使用 auto_training.yaml 中的配置)"
  70. )
  71. parser.add_argument(
  72. "--devices", nargs="+", default=None,
  73. help="只训练指定设备(空格分隔,如 --devices LT-2 LT-5)"
  74. )
  75. parser.add_argument(
  76. "--log-level", type=str, default="INFO",
  77. choices=["DEBUG", "INFO", "WARNING", "ERROR"],
  78. help="日志级别(默认 INFO)"
  79. )
  80. args = parser.parse_args()
  81. # 配置日志
  82. logging.basicConfig(
  83. level=getattr(logging, args.log_level),
  84. format='%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
  85. datefmt='%Y-%m-%d %H:%M:%S'
  86. )
  87. logger = logging.getLogger('StandaloneTrain')
  88. # 验证数据目录
  89. data_dir = Path(args.data_dir)
  90. if not data_dir.exists():
  91. logger.error(f"数据目录不存在: {data_dir}")
  92. sys.exit(1)
  93. # 加载配置
  94. config_file = Path(__file__).parent.parent / "config" / "auto_training.yaml"
  95. if not config_file.exists():
  96. logger.error(f"配置文件不存在: {config_file}")
  97. sys.exit(1)
  98. # 初始化训练器
  99. from auto_training.incremental_trainer import IncrementalTrainer
  100. trainer = IncrementalTrainer(config_file)
  101. # 覆盖 batch_size(如果指定)
  102. if args.batch_size:
  103. trainer.config['auto_training']['incremental']['batch_size'] = args.batch_size
  104. # 执行全量训练
  105. logger.info("=" * 60)
  106. logger.info("泵异响检测模型 - 全量训练")
  107. logger.info(f"数据目录: {data_dir}")
  108. logger.info(f"设备过滤: {args.devices or '全部'}")
  109. logger.info(f"训练轮数: {args.epochs or '配置默认'}")
  110. logger.info(f"学习率: {args.lr or '配置默认'}")
  111. logger.info(f"批大小: {args.batch_size or '配置默认'}")
  112. logger.info("=" * 60)
  113. success = trainer.run_full_training(
  114. data_dir=data_dir,
  115. epochs=args.epochs,
  116. lr=args.lr,
  117. device_filter=args.devices
  118. )
  119. if success:
  120. logger.info("训练全部完成,模型已部署到 models/ 目录")
  121. # 列出产出文件
  122. model_root = trainer.model_root
  123. for device_dir in sorted(model_root.iterdir()):
  124. if device_dir.is_dir() and (device_dir / "ae_model.pth").exists():
  125. threshold_dir = device_dir / "thresholds"
  126. threshold_files = list(threshold_dir.glob("*.npy")) if threshold_dir.exists() else []
  127. logger.info(f" {device_dir.name}/: "
  128. f"model=OK, scale=OK, thresholds={len(threshold_files)}")
  129. else:
  130. logger.error("训练存在失败,请检查日志")
  131. sys.exit(0 if success else 1)
  132. if __name__ == "__main__":
  133. main()