فهرست منبع

有向图模型

zhanghao 4 ماه پیش
والد
کامیت
b9f7a5c4db
1 فایل‌های تغییر یافته به همراه 89 افزوده شده و 89 حذف شده
  1. 89 89
      models/causal-inference/gat.py

+ 89 - 89
models/causal-inference/gat.py

@@ -1,90 +1,90 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class GraphAttentionLayer(nn.Module):
-    """有向图注意力层(单独处理源节点和目标节点)"""
-    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
-        super(GraphAttentionLayer, self).__init__()
-        self.dropout = dropout
-        self.in_features = in_features
-        self.out_features = out_features
-        self.alpha = alpha
-        self.concat = concat
-        
-        # 权重参数
-        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
-        nn.init.xavier_uniform_(self.W.data, gain=1.414)
-        
-        # 有向图注意力参数(源节点和目标节点分开)
-        self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
-        self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
-        nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
-        nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)
-        
-        self.leakyrelu = nn.LeakyReLU(self.alpha)
-        
-    def forward(self, h, adj):
-        """
-        h: 输入特征 (batch_size, num_nodes, in_features)
-        adj: 邻接矩阵 (num_nodes, num_nodes)
-        """
-        batch_size = h.size(0)
-        num_nodes = h.size(1)
-        
-        # 线性变换
-        Wh = torch.matmul(h, self.W)  # (batch_size, num_nodes, out_features)
-        
-        # 计算有向注意力分数
-        a_input_src = torch.matmul(Wh, self.a_src)  # (batch_size, num_nodes, 1)
-        a_input_dst = torch.matmul(Wh, self.a_dst)  # (batch_size, num_nodes, 1)
-        
-        # 有向图注意力分数 = 源节点分数 + 目标节点分数(转置后)
-        e = a_input_src + a_input_dst.transpose(1, 2)  # (batch_size, num_nodes, num_nodes)
-        e = self.leakyrelu(e)
-        
-        # 应用邻接矩阵掩码(只保留存在的边)
-        zero_vec = -9e15 * torch.ones_like(e)
-        attention = torch.where(adj > 0, e, zero_vec)
-        
-        # 计算注意力权重
-        attention = F.softmax(attention, dim=2)
-        attention = F.dropout(attention, self.dropout, training=self.training)
-        
-        # 应用注意力权重
-        h_prime = torch.matmul(attention, Wh)  # (batch_size, num_nodes, out_features)
-        
-        if self.concat:
-            return F.elu(h_prime)
-        else:
-            return h_prime
-        
-    def __repr__(self):
-        return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
-
-class GAT(nn.Module):
-    def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
-        super(GAT, self).__init__()
-        self.dropout = dropout
-        
-        # 多头注意力层(有向图适配)
-        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) 
-                           for _ in range(nheads)]
-        for i, attention in enumerate(self.attentions):
-            self.add_module(f'attention_{i}', attention)
-        
-        # 输出层
-        self.out_att = GraphAttentionLayer(nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False)
-        
-    def forward(self, x, adj):
-        """
-        x: 输入特征 (batch_size, num_nodes, nfeat)
-        adj: 邻接矩阵 (num_nodes, num_nodes)
-        """
-        x = F.dropout(x, self.dropout, training=self.training)
-        # 拼接多头注意力输出
-        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
-        x = F.dropout(x, self.dropout, training=self.training)
-        x = F.elu(self.out_att(x, adj))
-        
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
class GraphAttentionLayer(nn.Module):
    """Directed-graph attention layer (separate scores for source and target nodes).

    Unlike the classic GAT layer, which concatenates [Wh_i || Wh_j] and applies a
    single attention vector, this layer keeps two attention vectors — `a_src` for
    the source end and `a_dst` for the target end of an edge — so the raw score
    e[i, j] = src_score(i) + dst_score(j) is generally asymmetric (e_ij != e_ji).
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        """
        Args:
            in_features: size of each input node feature vector.
            out_features: size of each output node feature vector.
            dropout: dropout probability applied to the attention weights.
            alpha: negative slope of the LeakyReLU on the raw attention scores.
            concat: if True (hidden layer), apply ELU to the output; if False
                (output layer), return the raw aggregation.
        """
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Shared linear transform applied to every node.
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Separate attention vectors for the source / target ends of a directed edge.
        self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
        self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
        nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
        nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """
        Args:
            h: node features, shape (batch_size, num_nodes, in_features).
            adj: adjacency matrix, shape (num_nodes, num_nodes); entry (i, j) > 0
                means node i attends to node j.

        Returns:
            Updated node features, shape (batch_size, num_nodes, out_features).
        """
        # Linear transform: (batch_size, num_nodes, out_features).
        Wh = torch.matmul(h, self.W)

        # Per-node scalar scores for the two ends of a directed edge.
        score_src = torch.matmul(Wh, self.a_src)  # (batch_size, num_nodes, 1)
        score_dst = torch.matmul(Wh, self.a_dst)  # (batch_size, num_nodes, 1)

        # Directed raw attention: e[b, i, j] = score_src[b, i] + score_dst[b, j].
        e = self.leakyrelu(score_src + score_dst.transpose(1, 2))

        # Mask out non-edges with a large negative value so softmax sends them to ~0.
        # NOTE(review): a node with no outgoing edge gets an all-masked row, which
        # softmax turns into a *uniform* attention row — confirm the graphs used
        # here cannot contain isolated nodes (self-loops avoid this).
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        # Normalize over neighbors (dim=2) and regularize the attention weights.
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # Aggregate neighbor features with the attention weights.
        h_prime = torch.matmul(attention, Wh)  # (batch_size, num_nodes, out_features)

        # ELU non-linearity for hidden layers only.
        return F.elu(h_prime) if self.concat else h_prime

    def __repr__(self):
        return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
+
class GAT(nn.Module):
    """Multi-head graph attention network (directed-graph variant).

    The hidden layer runs `nheads` independent attention heads on the same input
    and concatenates their outputs; a final attention layer maps the concatenated
    features down to `noutput` dimensions.
    """

    def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
        """
        Args:
            nfeat: input feature size per node.
            nhid: hidden feature size produced by each attention head.
            noutput: output feature size per node.
            dropout: dropout probability for inputs and attention weights.
            alpha: LeakyReLU negative slope inside the attention layers.
            nheads: number of hidden-layer attention heads.
        """
        super(GAT, self).__init__()
        self.dropout = dropout

        # Build and register one hidden head at a time; registration under a
        # stable name exposes each head's parameters to state_dict()/parameters().
        self.attentions = []
        for head_idx in range(nheads):
            head = GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            self.add_module(f'attention_{head_idx}', head)
            self.attentions.append(head)

        # Output attention layer over the concatenated head outputs.
        self.out_att = GraphAttentionLayer(nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        """
        Args:
            x: node features, shape (batch_size, num_nodes, nfeat).
            adj: adjacency matrix, shape (num_nodes, num_nodes).

        Returns:
            Node embeddings of shape (batch_size, num_nodes, noutput).
        """
        x = F.dropout(x, self.dropout, training=self.training)

        # Every head sees the same input; outputs are concatenated feature-wise.
        head_outputs = [head(x, adj) for head in self.attentions]
        x = torch.cat(head_outputs, dim=2)

        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))

        return x