import torch
import argparse


def get_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser for the RL-optimized GAT and parse arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), ``sys.argv[1:]`` is parsed, exactly as before; passing
            an explicit list makes the function usable from tests or other code.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: via ``parser.error`` on invalid arguments, including
            ``--test_ratio``/``--val_ratio`` outside ``[0, 1)`` or summing to
            ``>= 1`` (which would leave no training data).
    """
    parser = argparse.ArgumentParser(description='RL-Optimized GAT for time series prediction')

    # Data parameters
    parser.add_argument('--data_dir', type=str, default='../datasets_xishan', help='Directory for data files')
    parser.add_argument('--num_files', type=int, default=50, help='Number of data files (1 to num_files)')
    parser.add_argument('--test_ratio', type=float, default=0.2, help='Ratio of test data')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='Ratio of validation data')

    # Model parameters
    parser.add_argument('--num_features', type=int, default=145, help='Number of feature variables')
    parser.add_argument('--num_targets', type=int, default=47, help='Number of target variables')
    parser.add_argument('--hidden_dim', type=int, default=64, help='Default hidden dimension of GAT')
    parser.add_argument('--num_heads', type=int, default=4, help='Default number of attention heads')
    parser.add_argument('--dropout', type=float, default=0.3, help='Default dropout rate')

    # Training parameters
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Default learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for final training')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
    # The default device is chosen at call time: CUDA if torch reports it available, else CPU.
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training')
    parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping threshold')
    parser.add_argument('--patience', type=int, default=20, help='Patience for early stopping')

    # Reinforcement learning parameters
    parser.add_argument('--rl_timesteps', type=int, default=5000, help='Total timesteps for RL training')
    parser.add_argument('--rl_max_steps', type=int, default=20, help='Max steps per RL episode')
    parser.add_argument('--rl_eval_episodes', type=int, default=10, help='Number of episodes for RL evaluation')

    # Wavelet denoising parameters
    parser.add_argument('--wavelet', type=str, default='db4', help='Wavelet type for denoising')
    parser.add_argument('--wavelet_level', type=int, default=1, help='Wavelet decomposition level')

    args = parser.parse_args(argv)

    # Reject split ratios that cannot produce a non-empty training split.
    for name in ('test_ratio', 'val_ratio'):
        value = getattr(args, name)
        if not 0.0 <= value < 1.0:
            parser.error(f'--{name} must be in [0, 1), got {value}')
    if args.test_ratio + args.val_ratio >= 1.0:
        parser.error(
            f'--test_ratio + --val_ratio must be < 1 to leave training data, '
            f'got {args.test_ratio} + {args.val_ratio}'
        )

    return args