|
|
@@ -1,90 +1,90 @@
|
|
|
-import torch
|
|
|
-import torch.nn as nn
|
|
|
-import torch.nn.functional as F
|
|
|
-
|
|
|
-class GraphAttentionLayer(nn.Module):
|
|
|
- """有向图注意力层(单独处理源节点和目标节点)"""
|
|
|
- def __init__(self, in_features, out_features, dropout, alpha, concat=True):
|
|
|
- super(GraphAttentionLayer, self).__init__()
|
|
|
- self.dropout = dropout
|
|
|
- self.in_features = in_features
|
|
|
- self.out_features = out_features
|
|
|
- self.alpha = alpha
|
|
|
- self.concat = concat
|
|
|
-
|
|
|
- # 权重参数
|
|
|
- self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
|
|
|
- nn.init.xavier_uniform_(self.W.data, gain=1.414)
|
|
|
-
|
|
|
- # 有向图注意力参数(源节点和目标节点分开)
|
|
|
- self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
|
|
|
- self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
|
|
|
- nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
|
|
|
- nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)
|
|
|
-
|
|
|
- self.leakyrelu = nn.LeakyReLU(self.alpha)
|
|
|
-
|
|
|
- def forward(self, h, adj):
|
|
|
- """
|
|
|
- h: 输入特征 (batch_size, num_nodes, in_features)
|
|
|
- adj: 邻接矩阵 (num_nodes, num_nodes)
|
|
|
- """
|
|
|
- batch_size = h.size(0)
|
|
|
- num_nodes = h.size(1)
|
|
|
-
|
|
|
- # 线性变换
|
|
|
- Wh = torch.matmul(h, self.W) # (batch_size, num_nodes, out_features)
|
|
|
-
|
|
|
- # 计算有向注意力分数
|
|
|
- a_input_src = torch.matmul(Wh, self.a_src) # (batch_size, num_nodes, 1)
|
|
|
- a_input_dst = torch.matmul(Wh, self.a_dst) # (batch_size, num_nodes, 1)
|
|
|
-
|
|
|
- # 有向图注意力分数 = 源节点分数 + 目标节点分数(转置后)
|
|
|
- e = a_input_src + a_input_dst.transpose(1, 2) # (batch_size, num_nodes, num_nodes)
|
|
|
- e = self.leakyrelu(e)
|
|
|
-
|
|
|
- # 应用邻接矩阵掩码(只保留存在的边)
|
|
|
- zero_vec = -9e15 * torch.ones_like(e)
|
|
|
- attention = torch.where(adj > 0, e, zero_vec)
|
|
|
-
|
|
|
- # 计算注意力权重
|
|
|
- attention = F.softmax(attention, dim=2)
|
|
|
- attention = F.dropout(attention, self.dropout, training=self.training)
|
|
|
-
|
|
|
- # 应用注意力权重
|
|
|
- h_prime = torch.matmul(attention, Wh) # (batch_size, num_nodes, out_features)
|
|
|
-
|
|
|
- if self.concat:
|
|
|
- return F.elu(h_prime)
|
|
|
- else:
|
|
|
- return h_prime
|
|
|
-
|
|
|
- def __repr__(self):
|
|
|
- return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
|
|
|
-
|
|
|
-class GAT(nn.Module):
|
|
|
- def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
|
|
|
- super(GAT, self).__init__()
|
|
|
- self.dropout = dropout
|
|
|
-
|
|
|
- # 多头注意力层(有向图适配)
|
|
|
- self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
|
|
|
- for _ in range(nheads)]
|
|
|
- for i, attention in enumerate(self.attentions):
|
|
|
- self.add_module(f'attention_{i}', attention)
|
|
|
-
|
|
|
- # 输出层
|
|
|
- self.out_att = GraphAttentionLayer(nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False)
|
|
|
-
|
|
|
- def forward(self, x, adj):
|
|
|
- """
|
|
|
- x: 输入特征 (batch_size, num_nodes, nfeat)
|
|
|
- adj: 邻接矩阵 (num_nodes, num_nodes)
|
|
|
- """
|
|
|
- x = F.dropout(x, self.dropout, training=self.training)
|
|
|
- # 拼接多头注意力输出
|
|
|
- x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
|
|
|
- x = F.dropout(x, self.dropout, training=self.training)
|
|
|
- x = F.elu(self.out_att(x, adj))
|
|
|
-
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch.nn.functional as F
|
|
|
+
|
|
|
+class GraphAttentionLayer(nn.Module):
|
|
|
+ """有向图注意力层(单独处理源节点和目标节点)"""
|
|
|
+ def __init__(self, in_features, out_features, dropout, alpha, concat=True):
|
|
|
+ super(GraphAttentionLayer, self).__init__()
|
|
|
+ self.dropout = dropout
|
|
|
+ self.in_features = in_features
|
|
|
+ self.out_features = out_features
|
|
|
+ self.alpha = alpha
|
|
|
+ self.concat = concat
|
|
|
+
|
|
|
+ # 权重参数
|
|
|
+ self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
|
|
|
+ nn.init.xavier_uniform_(self.W.data, gain=1.414)
|
|
|
+
|
|
|
+ # 有向图注意力参数(源节点和目标节点分开)
|
|
|
+ self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
|
|
|
+ self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
|
|
|
+ nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
|
|
|
+ nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)
|
|
|
+
|
|
|
+ self.leakyrelu = nn.LeakyReLU(self.alpha)
|
|
|
+
|
|
|
+ def forward(self, h, adj):
|
|
|
+ """
|
|
|
+ h: 输入特征 (batch_size, num_nodes, in_features)
|
|
|
+ adj: 邻接矩阵 (num_nodes, num_nodes)
|
|
|
+ """
|
|
|
+ batch_size = h.size(0)
|
|
|
+ num_nodes = h.size(1)
|
|
|
+
|
|
|
+ # 线性变换
|
|
|
+ Wh = torch.matmul(h, self.W) # (batch_size, num_nodes, out_features)
|
|
|
+
|
|
|
+ # 计算有向注意力分数
|
|
|
+ a_input_src = torch.matmul(Wh, self.a_src) # (batch_size, num_nodes, 1)
|
|
|
+ a_input_dst = torch.matmul(Wh, self.a_dst) # (batch_size, num_nodes, 1)
|
|
|
+
|
|
|
+ # 有向图注意力分数 = 源节点分数 + 目标节点分数(转置后)
|
|
|
+ e = a_input_src + a_input_dst.transpose(1, 2) # (batch_size, num_nodes, num_nodes)
|
|
|
+ e = self.leakyrelu(e)
|
|
|
+
|
|
|
+ # 应用邻接矩阵掩码(只保留存在的边)
|
|
|
+ zero_vec = -9e15 * torch.ones_like(e)
|
|
|
+ attention = torch.where(adj > 0, e, zero_vec)
|
|
|
+
|
|
|
+ # 计算注意力权重
|
|
|
+ attention = F.softmax(attention, dim=2)
|
|
|
+ attention = F.dropout(attention, self.dropout, training=self.training)
|
|
|
+
|
|
|
+ # 应用注意力权重
|
|
|
+ h_prime = torch.matmul(attention, Wh) # (batch_size, num_nodes, out_features)
|
|
|
+
|
|
|
+ if self.concat:
|
|
|
+ return F.elu(h_prime)
|
|
|
+ else:
|
|
|
+ return h_prime
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
|
|
|
+
|
|
|
+class GAT(nn.Module):
|
|
|
+ def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
|
|
|
+ super(GAT, self).__init__()
|
|
|
+ self.dropout = dropout
|
|
|
+
|
|
|
+ # 多头注意力层(有向图适配)
|
|
|
+ self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
|
|
|
+ for _ in range(nheads)]
|
|
|
+ for i, attention in enumerate(self.attentions):
|
|
|
+ self.add_module(f'attention_{i}', attention)
|
|
|
+
|
|
|
+ # 输出层
|
|
|
+ self.out_att = GraphAttentionLayer(nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False)
|
|
|
+
|
|
|
+ def forward(self, x, adj):
|
|
|
+ """
|
|
|
+ x: 输入特征 (batch_size, num_nodes, nfeat)
|
|
|
+ adj: 邻接矩阵 (num_nodes, num_nodes)
|
|
|
+ """
|
|
|
+ x = F.dropout(x, self.dropout, training=self.training)
|
|
|
+ # 拼接多头注意力输出
|
|
|
+ x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
|
|
|
+ x = F.dropout(x, self.dropout, training=self.training)
|
|
|
+ x = F.elu(self.out_att(x, adj))
|
|
|
+
|
|
|
return x
|