Causal graph model

zhanghao, 4 months ago
parent
commit
5e0783fef3

+ 60 - 0
models/causal-inference/args.py

@@ -0,0 +1,60 @@
+import torch
+import argparse
+
+def get_args():
+    parser = argparse.ArgumentParser(description='RL-Optimized GAT for time series prediction')
+    
+    # Data parameters
+    parser.add_argument('--data_dir', type=str, default='../datasets_xishan', 
+                       help='Directory for data files')
+    parser.add_argument('--num_files', type=int, default=50, 
+                       help='Number of data files (1 to num_files)')
+    parser.add_argument('--test_ratio', type=float, default=0.2, 
+                       help='Ratio of test data')
+    parser.add_argument('--val_ratio', type=float, default=0.1, 
+                       help='Ratio of validation data')
+    
+    # Model parameters
+    parser.add_argument('--num_features', type=int, default=145, 
+                       help='Number of feature variables')
+    parser.add_argument('--num_targets', type=int, default=47, 
+                       help='Number of target variables')
+    parser.add_argument('--hidden_dim', type=int, default=64, 
+                       help='Default hidden dimension of GAT')
+    parser.add_argument('--num_heads', type=int, default=4, 
+                       help='Default number of attention heads')
+    parser.add_argument('--dropout', type=float, default=0.3, 
+                       help='Default dropout rate')
+    
+    # Training parameters
+    parser.add_argument('--batch_size', type=int, default=128, 
+                       help='Batch size')
+    parser.add_argument('--lr', type=float, default=0.001, 
+                       help='Default learning rate')
+    parser.add_argument('--epochs', type=int, default=100, 
+                       help='Number of epochs for final training')
+    parser.add_argument('--weight_decay', type=float, default=1e-4, 
+                       help='Weight decay')
+    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
+                       help='Device to use for training')
+    parser.add_argument('--grad_clip', type=float, default=1.0,
+                       help='Gradient clipping threshold')
+    parser.add_argument('--patience', type=int, default=20,
+                       help='Patience for early stopping')
+    
+    # Reinforcement learning parameters
+    parser.add_argument('--rl_timesteps', type=int, default=5000, 
+                       help='Total timesteps for RL training')
+    parser.add_argument('--rl_max_steps', type=int, default=20, 
+                       help='Max steps per RL episode')
+    parser.add_argument('--rl_eval_episodes', type=int, default=10, 
+                       help='Number of episodes for RL evaluation')
+    
+    # Wavelet denoising parameters
+    parser.add_argument('--wavelet', type=str, default='db4',
+                       help='Wavelet type for denoising')
+    parser.add_argument('--wavelet_level', type=int, default=1,
+                       help='Wavelet decomposition level')
+    
+    args = parser.parse_args()
+    return args

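Note: a minimal smoke test for the parser (not part of this commit; the sys.argv override and the asserted values are purely illustrative). With an empty command line, argparse falls back to the defaults above:

    import sys
    sys.argv = ['main.py']          # simulate running with no flags
    from args import get_args
    args = get_args()
    assert args.hidden_dim == 64 and args.device in ('cuda', 'cpu')
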
+ 227 - 0
models/causal-inference/data_preprocessor.py

@@ -0,0 +1,227 @@
+import os
+import pandas as pd
+import numpy as np
+import pywt
+import logging
+from sklearn.preprocessing import StandardScaler
+from sklearn.model_selection import train_test_split
+import torch
+import joblib
+from torch.utils.data import TensorDataset, DataLoader
+
+class DataPreprocessor:
+    def __init__(self, args, logger=None):
+        self.args = args
+        self.data_dir = args.data_dir
+        self.num_files = args.num_files
+        self.scaler_features = StandardScaler()
+        self.scaler_targets = StandardScaler()
+        self.logger = logger if logger is not None else self._default_logger()
+        self.features = None  # cached feature data for building the adjacency matrix
+        self.scaler_dir = 'scalers'
+        os.makedirs(self.scaler_dir, exist_ok=True)
+        self.features_scaler_path = os.path.join(self.scaler_dir, 'features_scaler.joblib')
+        self.targets_scaler_path = os.path.join(self.scaler_dir, 'targets_scaler.joblib')
+        
+    def _default_logger(self):
+        """默认日志记录器"""
+        logger = logging.getLogger('DataPreprocessor')
+        logger.setLevel(logging.INFO)
+        console_handler = logging.StreamHandler()
+        console_handler.setLevel(logging.INFO)
+        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        console_handler.setFormatter(formatter)
+        if not logger.handlers:
+            logger.addHandler(console_handler)
+        return logger
+    
+    def load_data(self):
+        """加载所有数据文件并合并"""
+        all_data = []
+        
+        for i in range(1, self.num_files + 1):
+            file_path = os.path.join(self.data_dir, f'data_process_{i}.csv')
+            try:
+                df = pd.read_csv(file_path, index_col=0)
+                df = df.reset_index()  # move the original index into the first column
+                all_data.append(df)
+                self.logger.info(f"Loaded file {i}/{self.num_files}")
+            except Exception as e:
+                self.logger.error(f"Error loading file {i}: {e}")
+        
+        if not all_data:
+            raise FileNotFoundError(f"No data files could be loaded from {self.data_dir}")
+        combined_df = pd.concat(all_data, ignore_index=True)
+        return combined_df
+    
+    def decompose_time(self, df):
+        """将时间列分解为年、月、日、时、分、秒"""
+        time_col = df.columns[0]
+        df[time_col] = pd.to_datetime(df[time_col])
+        
+        df['year'] = df[time_col].dt.year
+        df['month'] = df[time_col].dt.month
+        df['day'] = df[time_col].dt.day
+        df['hour'] = df[time_col].dt.hour
+        df['minute'] = df[time_col].dt.minute
+        df['second'] = df[time_col].dt.second
+        
+        df = df.drop(columns=[time_col])
+        
+        # reorder columns so the time features come first
+        time_features = ['year', 'month', 'day', 'hour', 'minute', 'second']
+        other_features = [col for col in df.columns if col not in time_features]
+        df = df[time_features + other_features]
+        
+        return df
+    
+    def wavelet_denoising(self, data, wavelet='db4', level=1):
+        """对数据进行小波降噪,避免除以零警告"""
+        denoised_data = np.zeros_like(data)
+        epsilon = 1e-10  # tiny constant to avoid division by zero
+        
+        for i in range(data.shape[1]):
+            # wavelet decomposition
+            coeffs = pywt.wavedec(data[:, i], wavelet, level=level)
+            
+            # estimate the noise level via the median absolute deviation of the
+            # detail coefficients; epsilon guards against an all-zero band
+            sigma = np.median(np.abs(coeffs[-level])) / 0.6745 + epsilon
+            original_length = len(data[:, i])
+            threshold = sigma * np.sqrt(2 * np.log(original_length))
+            
+            # soft-threshold the detail coefficients manually to avoid library warnings
+            processed_coeffs = []
+            for c in coeffs[1:]:
+                # soft threshold: y = sign(x) * max(|x| - threshold, 0)
+                magnitude = np.abs(c)
+                thresholded = np.where(
+                    magnitude > threshold,
+                    np.sign(c) * (magnitude - threshold),
+                    0
+                )
+                processed_coeffs.append(thresholded)
+            
+            coeffs[1:] = processed_coeffs
+            
+            # wavelet reconstruction, then realign the length with the original signal
+            reconstructed = pywt.waverec(coeffs, wavelet)
+            if len(reconstructed) > original_length:
+                reconstructed = reconstructed[:original_length]
+            elif len(reconstructed) < original_length:
+                reconstructed = np.pad(reconstructed, (0, original_length - len(reconstructed)), mode='edge')
+            
+            denoised_data[:, i] = reconstructed
+        
+        return denoised_data
+    
+    def normalize_data(self, features, targets):
+        """归一化数据并保存scaler"""
+        features_scaled = self.scaler_features.fit_transform(features)
+        targets_scaled = self.scaler_targets.fit_transform(targets)
+        
+        # persist the fitted scalers
+        joblib.dump(self.scaler_features, self.features_scaler_path)
+        joblib.dump(self.scaler_targets, self.targets_scaler_path)
+        self.logger.info(f"Saved feature scaler to {self.features_scaler_path}")
+        self.logger.info(f"Saved target scaler to {self.targets_scaler_path}")
+        
+        return features_scaled, targets_scaled
+    
+    def inverse_transform_targets(self, targets_scaled):
+        """将归一化的目标变量反变换回原始尺度"""
+        return self.scaler_targets.inverse_transform(targets_scaled)
+    
+    def split_data(self, features, targets):
+        """划分训练集、验证集和测试集"""
+        X_train, X_temp, y_train, y_temp = train_test_split(
+            features, targets, test_size=self.args.test_ratio + self.args.val_ratio, 
+            shuffle=False
+        )
+        
+        test_size = self.args.test_ratio / (self.args.test_ratio + self.args.val_ratio)
+        X_val, X_test, y_val, y_test = train_test_split(
+            X_temp, y_temp, test_size=test_size, shuffle=False
+        )
+        
+        return X_train, X_val, X_test, y_train, y_val, y_test
+    
+    def create_dataloaders(self, X_train, X_val, X_test, y_train, y_val, y_test):
+        """创建DataLoader"""
+        X_train = torch.FloatTensor(X_train)
+        y_train = torch.FloatTensor(y_train)
+        X_val = torch.FloatTensor(X_val)
+        y_val = torch.FloatTensor(y_val)
+        X_test = torch.FloatTensor(X_test)
+        y_test = torch.FloatTensor(y_test)
+        
+        train_dataset = TensorDataset(X_train, y_train)
+        val_dataset = TensorDataset(X_val, y_val)
+        test_dataset = TensorDataset(X_test, y_test)
+        
+        train_loader = DataLoader(
+            train_dataset, batch_size=self.args.batch_size, shuffle=True
+        )
+        val_loader = DataLoader(
+            val_dataset, batch_size=self.args.batch_size, shuffle=False
+        )
+        test_loader = DataLoader(
+            test_dataset, batch_size=self.args.batch_size, shuffle=False
+        )
+        
+        return train_loader, val_loader, test_loader
+    
+    def preprocess(self):
+        """完整的预处理流程"""
+        df = self.load_data()
+        self.logger.info(f"Original data shape: {df.shape}")
+        
+        df = self.decompose_time(df)
+        self.logger.info(f"Data shape after time decomposition: {df.shape}")
+        
+        data = df.values
+        data_denoised = self.wavelet_denoising(data)
+        self.logger.info(f"Data shape after wavelet denoising: {data_denoised.shape}")
+        
+        # cache the feature block for adjacency matrix construction
+        self.features = data_denoised[:, :self.args.num_features]
+        targets = data_denoised[:, self.args.num_features:self.args.num_features+self.args.num_targets]
+        self.logger.info(f"Features shape: {self.features.shape}, Targets shape: {targets.shape}")
+        
+        features_scaled, targets_scaled = self.normalize_data(self.features, targets)
+        
+        X_train, X_val, X_test, y_train, y_val, y_test = self.split_data(
+            features_scaled, targets_scaled
+        )
+        self.logger.info(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
+        
+        train_loader, val_loader, test_loader = self.create_dataloaders(
+            X_train, X_val, X_test, y_train, y_val, y_test
+        )
+        
+        return train_loader, val_loader, test_loader, self
+    
+    def create_adjacency_matrix(self):
+        """创建有向图的邻接矩阵(基于特征相关性)"""
+        num_nodes = self.args.num_features
+        adj = torch.zeros((num_nodes, num_nodes))
+        
+        if self.features is None:
+            self.logger.warning("Feature data not initialized; falling back to a self-loop-only adjacency matrix")
+            # default: self-loops only
+            for i in range(num_nodes):
+                adj[i, i] = 1
+            return adj
+        
+        # pairwise correlations between features
+        corr_matrix = np.corrcoef(self.features.T)
+        corr_threshold = 0.3  # correlation threshold for adding an edge
+        
+        # add a directed edge wherever |correlation| exceeds the threshold
+        for i in range(num_nodes):
+            adj[i, i] = 1  # self-loop
+            for j in range(num_nodes):
+                if i != j and abs(corr_matrix[i, j]) > corr_threshold:
+                    adj[i, j] = 1
+        
+        self.logger.info(f"邻接矩阵中边的数量: {int(torch.sum(adj))}")
+        return adj

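Note: a usage sketch for the preprocessing pipeline (not part of this commit; assumes data_process_1.csv through data_process_50.csv exist under --data_dir):

    from args import get_args
    from data_preprocessor import DataPreprocessor

    args = get_args()
    preprocessor = DataPreprocessor(args)
    train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()
    adj = preprocessor.create_adjacency_matrix()   # (145, 145) 0/1 tensor with self-loops

    # at inference time, reload the persisted scalers instead of refitting
    import joblib
    target_scaler = joblib.load('scalers/targets_scaler.joblib')
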
+ 262 - 0
models/causal-inference/data_trainer.py

@@ -0,0 +1,262 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import os
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+import logging
+
+class DataTrainer:
+    def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
+        self.model = model
+        self.args = args
+        self.preprocessor = preprocessor
+        self.device = args.device
+        self.logger = logger if logger is not None else self._default_logger()
+        
+        self.criterion = nn.MSELoss()
+        self.optimizer = optimizer if optimizer is not None else optim.Adam(
+            model.parameters(),
+            lr=args.lr,
+            weight_decay=args.weight_decay
+        )
+        self.scheduler = scheduler
+        
+        self.train_losses = []
+        self.val_losses = []
+        self.train_mae = []
+        self.val_mae = []
+        self.best_val_loss = float('inf')
+        self.early_stop_counter = 0
+        self.model_save_path = os.path.join('models', 'best_model.pth')
+        self.final_model_path = os.path.join('models', 'final_model.pth')
+        
+        if not os.path.exists('models'):
+            os.makedirs('models')
+        if not os.path.exists('plots'):
+            os.makedirs('plots')
+        
+    def _default_logger(self):
+        logger = logging.getLogger('DefaultLogger')
+        logger.setLevel(logging.INFO)
+        console_handler = logging.StreamHandler()
+        console_handler.setLevel(logging.INFO)
+        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        console_handler.setFormatter(formatter)
+        if not logger.handlers:
+            logger.addHandler(console_handler)
+        return logger
+    
+    @torch.enable_grad()
+    def train_epoch(self, train_loader, adj, max_batches=None):
+        """
+        Train for one epoch. max_batches optionally caps the batches used (for fast RL evaluation).
+        """
+        self.model.train()
+        total_loss = 0.0
+        all_outputs = []
+        all_targets = []
+        adj = adj.to(self.device)
+
+        for batch_idx, (data, target) in enumerate(train_loader):
+            if max_batches is not None and batch_idx >= max_batches:
+                break
+            data, target = data.to(self.device), target.to(self.device)
+            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
+            
+            self.optimizer.zero_grad()
+            output = self.model(data, adj)  # (batch_size, 145, 47)
+            output = output.mean(dim=1)     # (batch_size, 47)
+            
+            loss = self.criterion(output, target)
+            loss.backward()
+            if self.args.grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
+            self.optimizer.step()
+            
+            total_loss += loss.item()
+            all_outputs.append(output.detach().cpu().numpy())
+            all_targets.append(target.detach().cpu().numpy())
+        
+        if len(all_outputs) == 0:
+            # guard for tiny datasets / very small batch counts
+            return float('inf'), float('inf')
+        
+        all_outputs = np.vstack(all_outputs)
+        all_targets = np.vstack(all_targets)
+        mae = mean_absolute_error(all_targets, all_outputs)
+        
+        avg_loss = total_loss / (min(len(train_loader), max_batches) if max_batches else len(train_loader))
+        self.train_losses.append(avg_loss)
+        self.train_mae.append(mae)
+        return avg_loss, mae
+    
+    @torch.no_grad()
+    def validate(self, val_loader, adj, max_batches=None):
+        """
+        Validate the model. max_batches optionally caps the batches used (for fast RL evaluation).
+        """
+        self.model.eval()
+        total_loss = 0.0
+        all_outputs = []
+        all_targets = []
+        adj = adj.to(self.device)
+        
+        for batch_idx, (data, target) in enumerate(val_loader):
+            if max_batches is not None and batch_idx >= max_batches:
+                break
+            data, target = data.to(self.device), target.to(self.device)
+            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
+            output = self.model(data, adj)  # (batch_size, 145, 47)
+            output = output.mean(dim=1)     # (batch_size, 47)
+            loss = self.criterion(output, target)
+            total_loss += loss.item()
+            all_outputs.append(output.cpu().numpy())
+            all_targets.append(target.cpu().numpy())
+        
+        if len(all_outputs) == 0:
+            return float('inf'), float('inf')
+        
+        all_outputs = np.vstack(all_outputs)
+        all_targets = np.vstack(all_targets)
+        mae = mean_absolute_error(all_targets, all_outputs)
+        avg_loss = total_loss / (min(len(val_loader), max_batches) if max_batches else len(val_loader))
+        self.val_losses.append(avg_loss)
+        self.val_mae.append(mae)
+        return avg_loss, mae
+    
+    def train(self, train_loader, val_loader, adj, epochs=None):
+        adj = adj.to(self.device)
+        epochs = epochs if epochs is not None else self.args.epochs
+        self.logger.info(f"开始训练,共 {epochs} 个epoch")
+        
+        for epoch in tqdm(range(epochs)):
+            train_loss, train_mae = self.train_epoch(train_loader, adj)
+            val_loss, val_mae = self.validate(val_loader, adj)
+            if self.scheduler is not None:
+                self.scheduler.step(val_loss)
+            
+            # checkpoint the best model so far
+            if val_loss < self.best_val_loss:
+                self.best_val_loss = val_loss
+                torch.save({
+                    'epoch': epoch,
+                    'model_state_dict': self.model.state_dict(),
+                    'optimizer_state_dict': self.optimizer.state_dict(),
+                    'loss': val_loss,
+                    'args': self.args  # persist the training arguments
+                }, self.model_save_path)
+                self.early_stop_counter = 0
+                self.logger.info(f"已更新最佳模型到 {self.model_save_path}")  # 日志
+            else:
+                self.early_stop_counter += 1
+                if self.early_stop_counter >= self.args.patience:
+                    self.logger.info(f"早停机制触发,在第 {epoch+1} 轮停止训练")
+                    break
+            
+            if (epoch + 1) % 10 == 0:
+                self.logger.info(f'Epoch {epoch+1}/{epochs}, '
+                                 f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
+                                 f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')
+        
+        self.plot_losses()
+        self.plot_mae()
+        
+        # reload the best checkpoint
+        checkpoint = torch.load(self.model_save_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.logger.info(f"加载最佳模型(第 {checkpoint['epoch']+1} 轮,验证损失: {checkpoint['loss']:.6f})")
+        
+        # save the final model (after restoring the best checkpoint)
+        torch.save({
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'best_val_loss': self.best_val_loss,
+            'args': self.args
+        }, self.final_model_path)
+        self.logger.info(f"已保存最终模型到 {self.final_model_path}")
+        
+        return self.model
+    
+    @torch.no_grad()
+    def test(self, test_loader, adj):
+        self.model.eval()
+        total_loss = 0.0
+        all_outputs = []
+        all_targets = []
+        adj = adj.to(self.device)
+        
+        for data, target in test_loader:
+            data, target = data.to(self.device), target.to(self.device)
+            data = data.unsqueeze(-1)
+            output = self.model(data, adj)
+            output = output.mean(dim=1)
+            loss = self.criterion(output, target)
+            total_loss += loss.item()
+            all_outputs.append(output.cpu().numpy())
+            all_targets.append(target.cpu().numpy())
+        
+        all_outputs = np.vstack(all_outputs)
+        all_targets = np.vstack(all_targets)
+        mse = total_loss / len(test_loader)
+        mae = mean_absolute_error(all_targets, all_outputs)
+        rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))
+        
+        all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
+        all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
+        original_mse = mean_squared_error(all_targets_original, all_outputs_original)
+        original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
+        original_rmse = np.sqrt(original_mse)
+        
+        self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
+        self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
+        self.plot_predictions(all_outputs_original, all_targets_original)
+        return {
+            'normalized_mse': mse,
+            'normalized_mae': mae,
+            'normalized_rmse': rmse,
+            'original_mse': original_mse,
+            'original_mae': original_mae,
+            'original_rmse': original_rmse,
+            'predictions': all_outputs_original,
+            'targets': all_targets_original
+        }
+    
+    def plot_losses(self):
+        plt.figure(figsize=(10, 6))
+        plt.plot(self.train_losses, label='Train Loss')
+        plt.plot(self.val_losses, label='Validation Loss')
+        plt.xlabel('Epoch')
+        plt.ylabel('MSE Loss')
+        plt.title('Training and Validation Loss')
+        plt.legend()
+        plt.savefig('plots/loss_curve.png')
+        plt.close()
+    
+    def plot_mae(self):
+        plt.figure(figsize=(10, 6))
+        plt.plot(self.train_mae, label='Train MAE')
+        plt.plot(self.val_mae, label='Validation MAE')
+        plt.xlabel('Epoch')
+        plt.ylabel('MAE')
+        plt.title('Training and Validation MAE')
+        plt.legend()
+        plt.savefig('plots/mae_curve.png')
+        plt.close()
+    
+    def plot_predictions(self, predictions, targets):
+        num_plots = min(3, self.args.num_targets)
+        plt.figure(figsize=(15, 5*num_plots))
+        for i in range(num_plots):
+            plt.subplot(num_plots, 1, i+1)
+            plt.plot(targets[:100, i], label='True Value')
+            plt.plot(predictions[:100, i], label='Predicted Value')
+            plt.xlabel('Time Step')
+            plt.ylabel(f'Target {i+1}')
+            plt.title(f'Prediction vs True Value for Target {i+1}')
+            plt.legend()
+        plt.tight_layout()
+        plt.savefig('plots/prediction_examples.png')
+        plt.close()

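Note: a minimal wiring sketch for DataTrainer (not part of this commit; it mirrors what main.py below does, reusing args, preprocessor, the loaders, and adj from the sketch above — optimizer and scheduler default to Adam and None):

    from gat import GAT
    from data_trainer import DataTrainer

    model = GAT(nfeat=1, nhid=args.hidden_dim, noutput=args.num_targets,
                dropout=args.dropout, alpha=0.2, nheads=args.num_heads).to(args.device)
    trainer = DataTrainer(model, args, preprocessor)
    trainer.train(train_loader, val_loader, adj)   # checkpoints best model to models/best_model.pth
    results = trainer.test(test_loader, adj)       # dict of normalized and original-scale metrics
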
+ 90 - 0
models/causal-inference/gat.py

@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class GraphAttentionLayer(nn.Module):
+    """有向图注意力层(单独处理源节点和目标节点)"""
+    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
+        super(GraphAttentionLayer, self).__init__()
+        self.dropout = dropout
+        self.in_features = in_features
+        self.out_features = out_features
+        self.alpha = alpha
+        self.concat = concat
+        
+        # weight matrix
+        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
+        nn.init.xavier_uniform_(self.W.data, gain=1.414)
+        
+        # directed attention parameters (separate vectors for source and destination)
+        self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
+        self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
+        nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
+        nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)
+        
+        self.leakyrelu = nn.LeakyReLU(self.alpha)
+        
+    def forward(self, h, adj):
+        """
+        h: input features (batch_size, num_nodes, in_features)
+        adj: adjacency matrix (num_nodes, num_nodes)
+        """
+        
+        # linear projection
+        Wh = torch.matmul(h, self.W)  # (batch_size, num_nodes, out_features)
+        
+        # per-node source and destination attention scores
+        a_input_src = torch.matmul(Wh, self.a_src)  # (batch_size, num_nodes, 1)
+        a_input_dst = torch.matmul(Wh, self.a_dst)  # (batch_size, num_nodes, 1)
+        
+        # directed attention score = source score + destination score (transposed)
+        e = a_input_src + a_input_dst.transpose(1, 2)  # (batch_size, num_nodes, num_nodes)
+        e = self.leakyrelu(e)
+        
+        # mask with the adjacency matrix (keep only existing edges)
+        zero_vec = -9e15 * torch.ones_like(e)
+        attention = torch.where(adj > 0, e, zero_vec)
+        
+        # normalize into attention weights
+        attention = F.softmax(attention, dim=2)
+        attention = F.dropout(attention, self.dropout, training=self.training)
+        
+        # aggregate neighbor features with the attention weights
+        h_prime = torch.matmul(attention, Wh)  # (batch_size, num_nodes, out_features)
+        
+        if self.concat:
+            return F.elu(h_prime)
+        else:
+            return h_prime
+        
+    def __repr__(self):
+        return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
+
+class GAT(nn.Module):
+    def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
+        super(GAT, self).__init__()
+        self.dropout = dropout
+        
+        # multi-head attention layers (directed-graph variant); nn.ModuleList
+        # registers each head, replacing the manual add_module loop
+        self.attentions = nn.ModuleList([
+            GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
+            for _ in range(nheads)
+        ])
+        
+        # output layer
+        self.out_att = GraphAttentionLayer(nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False)
+        
+    def forward(self, x, adj):
+        """
+        x: input features (batch_size, num_nodes, nfeat)
+        adj: adjacency matrix (num_nodes, num_nodes)
+        """
+        x = F.dropout(x, self.dropout, training=self.training)
+        # concatenate the multi-head outputs
+        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = F.elu(self.out_att(x, adj))
+        
+        return x

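Note: a quick shape check for the GAT forward pass (not part of this commit; each node carries one scalar feature, matching the unsqueeze(-1) in the trainer):

    import torch
    from gat import GAT

    model = GAT(nfeat=1, nhid=64, noutput=47, dropout=0.3, alpha=0.2, nheads=4)
    x = torch.randn(8, 145, 1)      # (batch, num_nodes, nfeat)
    adj = torch.eye(145)            # self-loops only
    out = model(x, adj)             # (8, 145, 47); the trainer then averages over dim=1
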
+ 88 - 0
models/causal-inference/main.py

@@ -0,0 +1,88 @@
+import torch.optim as optim
+from args import get_args
+from data_preprocessor import DataPreprocessor
+from gat import GAT
+from data_trainer import DataTrainer
+from rl_optimizer import RLOptimizer
+import logging
+import os
+
+def setup_logger(args):
+    """设置日志记录"""
+    if not os.path.exists('logs'):
+        os.makedirs('logs')
+    
+    logger = logging.getLogger('GAT-Training')
+    logger.setLevel(logging.INFO)
+    
+    # file handler
+    file_handler = logging.FileHandler(f'logs/training_{args.num_files}.log')
+    file_handler.setLevel(logging.INFO)
+    
+    # console handler
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(logging.INFO)
+    
+    # formatter
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    file_handler.setFormatter(formatter)
+    console_handler.setFormatter(formatter)
+    
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+    
+    return logger
+
+def main():
+    # parse arguments
+    args = get_args()
+    logger = setup_logger(args)
+    logger.info(f"使用设备: {args.device}")
+    
+    # data preprocessing
+    preprocessor = DataPreprocessor(args, logger)
+    train_loader, val_loader, test_loader, preprocessor = preprocessor.preprocess()
+    
+    # build the directed adjacency matrix
+    adj = preprocessor.create_adjacency_matrix()
+    logger.info(f"邻接矩阵形状: {adj.shape}")
+    
+    # Step 1: optimize hyperparameters with reinforcement learning
+    rl_optimizer = RLOptimizer(args, preprocessor, train_loader, val_loader, adj, logger)
+    best_hparams = rl_optimizer.optimize()
+    
+    # Step 2: train the final model with the best hyperparameters
+    logger.info("\nTraining the final model with the best hyperparameters...")
+    final_model = GAT(
+        nfeat=1,
+        nhid=best_hparams['hidden_dim'],
+        noutput=args.num_targets,
+        dropout=best_hparams['dropout'],
+        nheads=best_hparams['num_heads'],
+        alpha=0.2
+    ).to(args.device)
+    
+    # configure the optimizer
+    optimizer = optim.Adam(
+        final_model.parameters(),
+        lr=best_hparams['lr'],
+        weight_decay=args.weight_decay
+    )
+    
+    # learning-rate scheduler
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode='min', factor=0.5, patience=10, verbose=True
+    )
+    
+    # train the final model
+    trainer = DataTrainer(final_model, args, preprocessor, optimizer, scheduler, logger)
+    trained_model = trainer.train(train_loader, val_loader, adj)
+    
+    # Step 3: evaluate on the test set
+    logger.info("\nEvaluating the final model on the test set...")
+    test_results = trainer.test(test_loader, adj)
+    
+    logger.info("所有任务完成!")
+
+if __name__ == "__main__":
+    main()
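
Note: the whole pipeline is driven from the command line; every flag is optional and falls back to the defaults in args.py, e.g.:

    python main.py --data_dir ../datasets_xishan --num_files 50 --epochs 100 --lr 0.001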