| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import torch
- import argparse
def get_args(argv=None):
    """Parse the command-line options for the RL-optimized GAT pipeline.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which is the
            behavior of the original zero-argument call. Passing an explicit
            list (e.g. ``[]``) lets tests, notebooks, or embedding code get a
            namespace without touching the process command line.

    Returns:
        argparse.Namespace: one attribute per option defined below.

    Raises:
        SystemExit: raised by argparse on ``--help`` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description='RL-Optimized GAT for time series prediction')

    # Data parameters
    parser.add_argument('--data_dir', type=str, default='../datasets_xishan',
                        help='Directory for data files')
    parser.add_argument('--num_files', type=int, default=50,
                        help='Number of data files (1 to num_files)')
    parser.add_argument('--test_ratio', type=float, default=0.2,
                        help='Ratio of test data')
    parser.add_argument('--val_ratio', type=float, default=0.1,
                        help='Ratio of validation data')

    # Model parameters
    parser.add_argument('--num_features', type=int, default=145,
                        help='Number of feature variables')
    parser.add_argument('--num_targets', type=int, default=47,
                        help='Number of target variables')
    parser.add_argument('--hidden_dim', type=int, default=64,
                        help='Default hidden dimension of GAT')
    parser.add_argument('--num_heads', type=int, default=4,
                        help='Default number of attention heads')
    parser.add_argument('--dropout', type=float, default=0.3,
                        help='Default dropout rate')

    # Training parameters
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Default learning rate')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for final training')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    # The default device is resolved when the parser is built: CUDA if torch
    # reports it as available, otherwise CPU.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training')
    parser.add_argument('--grad_clip', type=float, default=1.0,
                        help='Gradient clipping threshold')
    parser.add_argument('--patience', type=int, default=20,
                        help='Patience for early stopping')

    # Reinforcement-learning parameters
    parser.add_argument('--rl_timesteps', type=int, default=5000,
                        help='Total timesteps for RL training')
    parser.add_argument('--rl_max_steps', type=int, default=20,
                        help='Max steps per RL episode')
    parser.add_argument('--rl_eval_episodes', type=int, default=10,
                        help='Number of episodes for RL evaluation')

    # Wavelet-denoising parameters
    parser.add_argument('--wavelet', type=str, default='db4',
                        help='Wavelet type for denoising')
    parser.add_argument('--wavelet_level', type=int, default=1,
                        help='Wavelet decomposition level')

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|