args.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # args.py
  2. import argparse
  3. def lstm_args_parser():
  4. parser = argparse.ArgumentParser(description="LSTM模型训练参数")
  5. # 数据集划分
  6. parser.add_argument('--train_start_date', type=str, default='2024-02-23', help='训练集开始日期')
  7. parser.add_argument('--train_end_date', type=str, default='2025-09-10', help='训练集结束日期')
  8. parser.add_argument('--val_start_date', type=str, default='2025-01-01', help='验证集开始日期')
  9. parser.add_argument('--val_end_date', type=str, default='2025-09-10', help='验证集结束日期')
  10. parser.add_argument('--test_start_date', type=str, default='2025-01-01', help='测试集开始日期')
  11. parser.add_argument('--test_end_date', type=str, default='2025-09-10', help='测试集结束日期')
  12. # 模型相关参数
  13. parser.add_argument('--seq_len', type=int, default=4320, help='输入序列的长度(输入步长)')
  14. parser.add_argument('--output_size', type=int, default=2160, help='输出数据的维度(预测步长)')
  15. parser.add_argument('--step_size', type=int, default=2160, help='输入数据间隔')
  16. parser.add_argument('--resolution', type=int, default=60, help='输入数据分辨率(每多少个数据取一次)')
  17. parser.add_argument('--epochs', type=int, default=1000, help='训练轮数')
  18. parser.add_argument('--feature_num', type=int, default=16, help='特征维度')
  19. parser.add_argument('--labels_num', type=int, default=8, help='标签维度(子模型数量)')
  20. parser.add_argument('--hidden_size', type=int, default=64, help='隐藏层大小')
  21. parser.add_argument('--num_layers', type=int, default=1, help='LSTM层数')
  22. parser.add_argument('--dropout', type=float, default=0, help='dropout的概率')
  23. parser.add_argument('--lr', type=float, default=0.01, help='学习率')
  24. parser.add_argument('--batch_size', type=int, default=128, help='批次大小')
  25. # 学习率调度器
  26. parser.add_argument('--scheduler_step_size', type=int, default=100, help='学习率调整步长')
  27. parser.add_argument('--scheduler_gamma', type=float, default=0.9, help='学习率衰减率')
  28. # 早停
  29. parser.add_argument('--patience', type=int, default=500, help='早停耐心值')
  30. parser.add_argument('--min_delta', type=float, default=1e-10, help='最小改善阈值')
  31. # 设备选择
  32. parser.add_argument('--device', type=int, default=0, help='选择使用的GPU设备')
  33. # 数据处理相关参数
  34. parser.add_argument('--start_files', type=int, default=1, help='开始文件索引')
  35. parser.add_argument('--end_files', type=int, default=9, help='结束文件索引')
  36. parser.add_argument('--data_dir', type=str, default='datasets_xishan', help='数据文件夹路径')
  37. parser.add_argument('--file_pattern', type=str, default='data_process_{}.csv', help='数据文件命名模式')
  38. # 模型保存路径
  39. parser.add_argument('--model_path', type=str, default='model.pth', help='模型保存路径')
  40. parser.add_argument('--output_csv_path', type=str, default='predictions.csv', help='预测文件保存路径')
  41. # 随机种子
  42. parser.add_argument('--random_seed', type=int, default=1314, help='随机种子')
  43. args = parser.parse_args()
  44. return args