Browse Source

参数文件

zhanghao 4 tháng trước cách đây
mục cha
commit
72b73486c0
1 tập tin đã thay đổi với 59 bổ sung59 xóa
  1. 59 59
      models/causal-inference/args.py

+ 59 - 59
models/causal-inference/args.py

@@ -1,60 +1,60 @@
-import torch
-import argparse
-
-def get_args():
-    parser = argparse.ArgumentParser(description='RL-Optimized GAT for time series prediction')
-    
-    # 数据参数
-    parser.add_argument('--data_dir', type=str, default='../datasets_xishan', 
-                       help='Directory for data files')
-    parser.add_argument('--num_files', type=int, default=50, 
-                       help='Number of data files (1 to num_files)')
-    parser.add_argument('--test_ratio', type=float, default=0.2, 
-                       help='Ratio of test data')
-    parser.add_argument('--val_ratio', type=float, default=0.1, 
-                       help='Ratio of validation data')
-    
-    # 模型参数
-    parser.add_argument('--num_features', type=int, default=145, 
-                       help='Number of feature variables')
-    parser.add_argument('--num_targets', type=int, default=47, 
-                       help='Number of target variables')
-    parser.add_argument('--hidden_dim', type=int, default=64, 
-                       help='Default hidden dimension of GAT')
-    parser.add_argument('--num_heads', type=int, default=4, 
-                       help='Default number of attention heads')
-    parser.add_argument('--dropout', type=float, default=0.3, 
-                       help='Default dropout rate')
-    
-    # 训练参数
-    parser.add_argument('--batch_size', type=int, default=128, 
-                       help='Batch size')
-    parser.add_argument('--lr', type=float, default=0.001, 
-                       help='Default learning rate')
-    parser.add_argument('--epochs', type=int, default=100, 
-                       help='Number of epochs for final training')
-    parser.add_argument('--weight_decay', type=float, default=1e-4, 
-                       help='Weight decay')
-    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
-                       help='Device to use for training')
-    parser.add_argument('--grad_clip', type=float, default=1.0,
-                       help='Gradient clipping threshold')
-    parser.add_argument('--patience', type=int, default=20,
-                       help='Patience for early stopping')
-    
-    # 强化学习参数
-    parser.add_argument('--rl_timesteps', type=int, default=5000, 
-                       help='Total timesteps for RL training')
-    parser.add_argument('--rl_max_steps', type=int, default=20, 
-                       help='Max steps per RL episode')
-    parser.add_argument('--rl_eval_episodes', type=int, default=10, 
-                       help='Number of episodes for RL evaluation')
-    
-    # 小波去噪参数
-    parser.add_argument('--wavelet', type=str, default='db4',
-                       help='Wavelet type for denoising')
-    parser.add_argument('--wavelet_level', type=int, default=1,
-                       help='Wavelet decomposition level')
-    
-    args = parser.parse_args()
+import torch
+import argparse
+
+def get_args():
+    parser = argparse.ArgumentParser(description='RL-Optimized GAT for time series prediction')
+    
+    # 数据参数
+    parser.add_argument('--data_dir', type=str, default='../datasets_xishan', 
+                       help='Directory for data files')
+    parser.add_argument('--num_files', type=int, default=50, 
+                       help='Number of data files (1 to num_files)')
+    parser.add_argument('--test_ratio', type=float, default=0.2, 
+                       help='Ratio of test data')
+    parser.add_argument('--val_ratio', type=float, default=0.1, 
+                       help='Ratio of validation data')
+    
+    # 模型参数
+    parser.add_argument('--num_features', type=int, default=145, 
+                       help='Number of feature variables')
+    parser.add_argument('--num_targets', type=int, default=47, 
+                       help='Number of target variables')
+    parser.add_argument('--hidden_dim', type=int, default=64, 
+                       help='Default hidden dimension of GAT')
+    parser.add_argument('--num_heads', type=int, default=4, 
+                       help='Default number of attention heads')
+    parser.add_argument('--dropout', type=float, default=0.3, 
+                       help='Default dropout rate')
+    
+    # 训练参数
+    parser.add_argument('--batch_size', type=int, default=128, 
+                       help='Batch size')
+    parser.add_argument('--lr', type=float, default=0.001, 
+                       help='Default learning rate')
+    parser.add_argument('--epochs', type=int, default=100, 
+                       help='Number of epochs for final training')
+    parser.add_argument('--weight_decay', type=float, default=1e-4, 
+                       help='Weight decay')
+    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
+                       help='Device to use for training')
+    parser.add_argument('--grad_clip', type=float, default=1.0,
+                       help='Gradient clipping threshold')
+    parser.add_argument('--patience', type=int, default=20,
+                       help='Patience for early stopping')
+    
+    # 强化学习参数
+    parser.add_argument('--rl_timesteps', type=int, default=5000, 
+                       help='Total timesteps for RL training')
+    parser.add_argument('--rl_max_steps', type=int, default=20, 
+                       help='Max steps per RL episode')
+    parser.add_argument('--rl_eval_episodes', type=int, default=10, 
+                       help='Number of episodes for RL evaluation')
+    
+    # 小波去噪参数
+    parser.add_argument('--wavelet', type=str, default='db4',
+                       help='Wavelet type for denoising')
+    parser.add_argument('--wavelet_level', type=int, default=1,
+                       help='Wavelet decomposition level')
+    
+    args = parser.parse_args()
     return args