| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- # args.py
- import argparse
- def lstm_args_parser():
- parser = argparse.ArgumentParser(description="LSTM模型训练参数")
-
- # 核心数据集参数
- parser.add_argument('--train_start_date', type=str, default='2024-10-08', help='训练集开始日期')
- parser.add_argument('--train_end_date', type=str, default='2026-02-13', help='训练集结束日期')
- parser.add_argument('--val_start_date', type=str, default='2024-10-08', help='验证集开始日期')
- parser.add_argument('--val_end_date', type=str, default='2026-02-13', help='验证集结束日期')
- parser.add_argument('--test_start_date', type=str, default='2024-10-08', help='测试集开始日期')
- parser.add_argument('--test_end_date', type=str, default='2026-02-13', help='测试集结束日期')
- # 模型架构参数
- parser.add_argument('--seq_len', type=int, default=10, help='输入序列长度')
- parser.add_argument('--output_size', type=int, default=5, help='预测步长')
- parser.add_argument('--step_size', type=int, default=5, help='采样步长')
- parser.add_argument('--resolution', type=int, default=60, help='数据分辨率(分钟)')
- parser.add_argument('--feature_num', type=int, default=32, help='输入特征维度')
- parser.add_argument('--labels_num', type=int, default=4, help='预测标签数量(子模型数量)')
-
- # 训练超参数
- parser.add_argument('--epochs', type=int, default=200, help='训练轮数')
- parser.add_argument('--hidden_size', type=int, default=64, help='隐藏层大小')
- parser.add_argument('--num_layers', type=int, default=1, help='LSTM层数')
- parser.add_argument('--dropout', type=float, default=0, help='dropout概率')
- parser.add_argument('--lr', type=float, default=0.01, help='学习率')
- parser.add_argument('--batch_size', type=int, default=512, help='批次大小')
-
- parser.add_argument('--scheduler_step_size', type=int, default=100, help='学习率调整步长')
- parser.add_argument('--scheduler_gamma', type=float, default=0.9, help='学习率衰减率')
-
- parser.add_argument('--patience', type=int, default=200, help='早停耐心值')
- parser.add_argument('--min_delta', type=float, default=1e-10, help='最小改善阈值')
- parser.add_argument('--device', type=int, default=1, help='GPU设备ID')
- # 文件路径配置
- parser.add_argument('--start_files', type=int, default=1, help='开始文件索引')
- parser.add_argument('--end_files', type=int, default=24, help='结束文件索引')
- parser.add_argument('--data_dir', type=str, default='datasets_jianding', help='数据文件夹路径')
- parser.add_argument('--file_pattern', type=str, default='data_process_{}.csv', help='数据文件命名模式')
-
- parser.add_argument('--model_path', type=str, default='model.pth', help='模型保存路径')
- parser.add_argument('--scaler_path', type=str, default='scaler.pkl', help='归一化器路径')
- parser.add_argument('--output_csv_path', type=str, default='predictions.csv', help='预测评估结果路径')
-
- parser.add_argument('--random_seed', type=int, default=1314, help='随机种子')
- args = parser.parse_args()
- return args
|