# args.py
  1. import torch
  2. import argparse
  3. def get_args():
  4. parser = argparse.ArgumentParser(description='RL-Optimized GAT for time series prediction')
  5. # 数据参数
  6. parser.add_argument('--data_dir', type=str, default='../datasets_xishan',
  7. help='Directory for data files')
  8. parser.add_argument('--num_files', type=int, default=50,
  9. help='Number of data files (1 to num_files)')
  10. parser.add_argument('--test_ratio', type=float, default=0.2,
  11. help='Ratio of test data')
  12. parser.add_argument('--val_ratio', type=float, default=0.1,
  13. help='Ratio of validation data')
  14. # 模型参数
  15. parser.add_argument('--num_features', type=int, default=145,
  16. help='Number of feature variables')
  17. parser.add_argument('--num_targets', type=int, default=47,
  18. help='Number of target variables')
  19. parser.add_argument('--hidden_dim', type=int, default=64,
  20. help='Default hidden dimension of GAT')
  21. parser.add_argument('--num_heads', type=int, default=4,
  22. help='Default number of attention heads')
  23. parser.add_argument('--dropout', type=float, default=0.3,
  24. help='Default dropout rate')
  25. # 训练参数
  26. parser.add_argument('--batch_size', type=int, default=128,
  27. help='Batch size')
  28. parser.add_argument('--lr', type=float, default=0.001,
  29. help='Default learning rate')
  30. parser.add_argument('--epochs', type=int, default=100,
  31. help='Number of epochs for final training')
  32. parser.add_argument('--weight_decay', type=float, default=1e-4,
  33. help='Weight decay')
  34. parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
  35. help='Device to use for training')
  36. parser.add_argument('--grad_clip', type=float, default=1.0,
  37. help='Gradient clipping threshold')
  38. parser.add_argument('--patience', type=int, default=20,
  39. help='Patience for early stopping')
  40. # 强化学习参数
  41. parser.add_argument('--rl_timesteps', type=int, default=5000,
  42. help='Total timesteps for RL training')
  43. parser.add_argument('--rl_max_steps', type=int, default=20,
  44. help='Max steps per RL episode')
  45. parser.add_argument('--rl_eval_episodes', type=int, default=10,
  46. help='Number of episodes for RL evaluation')
  47. # 小波去噪参数
  48. parser.add_argument('--wavelet', type=str, default='db4',
  49. help='Wavelet type for denoising')
  50. parser.add_argument('--wavelet_level', type=int, default=1,
  51. help='Wavelet decomposition level')
  52. args = parser.parse_args()
  53. return args