|
|
@@ -1,60 +1,60 @@
|
|
|
-import torch
|
|
|
-import argparse
|
|
|
-
|
|
|
-def get_args():
|
|
|
- parser = argparse.ArgumentParser(description='RL-Optimized GAT for time series prediction')
|
|
|
-
|
|
|
- # 数据参数
|
|
|
- parser.add_argument('--data_dir', type=str, default='../datasets_xishan',
|
|
|
- help='Directory for data files')
|
|
|
- parser.add_argument('--num_files', type=int, default=50,
|
|
|
- help='Number of data files (1 to num_files)')
|
|
|
- parser.add_argument('--test_ratio', type=float, default=0.2,
|
|
|
- help='Ratio of test data')
|
|
|
- parser.add_argument('--val_ratio', type=float, default=0.1,
|
|
|
- help='Ratio of validation data')
|
|
|
-
|
|
|
- # 模型参数
|
|
|
- parser.add_argument('--num_features', type=int, default=145,
|
|
|
- help='Number of feature variables')
|
|
|
- parser.add_argument('--num_targets', type=int, default=47,
|
|
|
- help='Number of target variables')
|
|
|
- parser.add_argument('--hidden_dim', type=int, default=64,
|
|
|
- help='Default hidden dimension of GAT')
|
|
|
- parser.add_argument('--num_heads', type=int, default=4,
|
|
|
- help='Default number of attention heads')
|
|
|
- parser.add_argument('--dropout', type=float, default=0.3,
|
|
|
- help='Default dropout rate')
|
|
|
-
|
|
|
- # 训练参数
|
|
|
- parser.add_argument('--batch_size', type=int, default=128,
|
|
|
- help='Batch size')
|
|
|
- parser.add_argument('--lr', type=float, default=0.001,
|
|
|
- help='Default learning rate')
|
|
|
- parser.add_argument('--epochs', type=int, default=100,
|
|
|
- help='Number of epochs for final training')
|
|
|
- parser.add_argument('--weight_decay', type=float, default=1e-4,
|
|
|
- help='Weight decay')
|
|
|
- parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
|
|
- help='Device to use for training')
|
|
|
- parser.add_argument('--grad_clip', type=float, default=1.0,
|
|
|
- help='Gradient clipping threshold')
|
|
|
- parser.add_argument('--patience', type=int, default=20,
|
|
|
- help='Patience for early stopping')
|
|
|
-
|
|
|
- # 强化学习参数
|
|
|
- parser.add_argument('--rl_timesteps', type=int, default=5000,
|
|
|
- help='Total timesteps for RL training')
|
|
|
- parser.add_argument('--rl_max_steps', type=int, default=20,
|
|
|
- help='Max steps per RL episode')
|
|
|
- parser.add_argument('--rl_eval_episodes', type=int, default=10,
|
|
|
- help='Number of episodes for RL evaluation')
|
|
|
-
|
|
|
- # 小波去噪参数
|
|
|
- parser.add_argument('--wavelet', type=str, default='db4',
|
|
|
- help='Wavelet type for denoising')
|
|
|
- parser.add_argument('--wavelet_level', type=int, default=1,
|
|
|
- help='Wavelet decomposition level')
|
|
|
-
|
|
|
- args = parser.parse_args()
|
|
|
+import torch
|
|
|
+import argparse
|
|
|
+
|
|
|
+def get_args():
|
|
|
+ parser = argparse.ArgumentParser(description='RL-Optimized GAT for time series prediction')
|
|
|
+
|
|
|
+ # 数据参数
|
|
|
+ parser.add_argument('--data_dir', type=str, default='../datasets_xishan',
|
|
|
+ help='Directory for data files')
|
|
|
+ parser.add_argument('--num_files', type=int, default=50,
|
|
|
+ help='Number of data files (1 to num_files)')
|
|
|
+ parser.add_argument('--test_ratio', type=float, default=0.2,
|
|
|
+ help='Ratio of test data')
|
|
|
+ parser.add_argument('--val_ratio', type=float, default=0.1,
|
|
|
+ help='Ratio of validation data')
|
|
|
+
|
|
|
+ # 模型参数
|
|
|
+ parser.add_argument('--num_features', type=int, default=145,
|
|
|
+ help='Number of feature variables')
|
|
|
+ parser.add_argument('--num_targets', type=int, default=47,
|
|
|
+ help='Number of target variables')
|
|
|
+ parser.add_argument('--hidden_dim', type=int, default=64,
|
|
|
+ help='Default hidden dimension of GAT')
|
|
|
+ parser.add_argument('--num_heads', type=int, default=4,
|
|
|
+ help='Default number of attention heads')
|
|
|
+ parser.add_argument('--dropout', type=float, default=0.3,
|
|
|
+ help='Default dropout rate')
|
|
|
+
|
|
|
+ # 训练参数
|
|
|
+ parser.add_argument('--batch_size', type=int, default=128,
|
|
|
+ help='Batch size')
|
|
|
+ parser.add_argument('--lr', type=float, default=0.001,
|
|
|
+ help='Default learning rate')
|
|
|
+ parser.add_argument('--epochs', type=int, default=100,
|
|
|
+ help='Number of epochs for final training')
|
|
|
+ parser.add_argument('--weight_decay', type=float, default=1e-4,
|
|
|
+ help='Weight decay')
|
|
|
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
|
|
+ help='Device to use for training')
|
|
|
+ parser.add_argument('--grad_clip', type=float, default=1.0,
|
|
|
+ help='Gradient clipping threshold')
|
|
|
+ parser.add_argument('--patience', type=int, default=20,
|
|
|
+ help='Patience for early stopping')
|
|
|
+
|
|
|
+ # 强化学习参数
|
|
|
+ parser.add_argument('--rl_timesteps', type=int, default=5000,
|
|
|
+ help='Total timesteps for RL training')
|
|
|
+ parser.add_argument('--rl_max_steps', type=int, default=20,
|
|
|
+ help='Max steps per RL episode')
|
|
|
+ parser.add_argument('--rl_eval_episodes', type=int, default=10,
|
|
|
+ help='Number of episodes for RL evaluation')
|
|
|
+
|
|
|
+ # 小波去噪参数
|
|
|
+ parser.add_argument('--wavelet', type=str, default='db4',
|
|
|
+ help='Wavelet type for denoising')
|
|
|
+ parser.add_argument('--wavelet_level', type=int, default=1,
|
|
|
+ help='Wavelet decomposition level')
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
return args
|