import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphAttentionLayer(nn.Module):
    """Directed-graph attention layer (scores source and destination nodes separately).

    Instead of the usual GAT formulation that scores the concatenation
    [Wh_i || Wh_j] with a single vector, this layer keeps two attention
    vectors — one applied to a node in its source role, one in its
    destination role. Their sum gives an asymmetric pre-softmax score
    e[i, j], which is what makes the attention suitable for directed
    adjacency matrices.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        """
        Args:
            in_features: size of each input node feature vector.
            out_features: size of each output node feature vector.
            dropout: dropout probability applied to the attention weights.
            alpha: negative slope of the LeakyReLU used on raw scores.
            concat: if True (hidden layer), apply ELU to the output;
                if False (output layer), return the raw aggregation.
        """
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Shared linear transform applied to every node: (in_features, out_features).
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Separate attention vectors for the source and destination roles;
        # this asymmetry is what makes the score direction-aware.
        self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
        self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
        nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
        nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """Apply directed attention over the graph.

        Args:
            h: input node features, shape (batch_size, num_nodes, in_features).
            adj: adjacency matrix, shape (num_nodes, num_nodes); entries > 0
                mark existing edges.

        Returns:
            Tensor of shape (batch_size, num_nodes, out_features).
        """
        # Linear projection: (batch_size, num_nodes, out_features).
        Wh = torch.matmul(h, self.W)

        # Per-node scalar score in each role: (batch_size, num_nodes, 1).
        a_input_src = torch.matmul(Wh, self.a_src)
        a_input_dst = torch.matmul(Wh, self.a_dst)

        # e[b, i, j] = src_score(i) + dst_score(j); broadcasting the
        # transposed destination scores yields (batch_size, num_nodes, num_nodes).
        e = self.leakyrelu(a_input_src + a_input_dst.transpose(1, 2))

        # Mask non-edges with a very large negative value so softmax drives
        # them to ~0. NOTE(review): a node with no neighbors gets a fully
        # masked row, which softmax turns into uniform attention.
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        # Normalize over destination nodes, then apply attention dropout.
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # Aggregate neighbor features: (batch_size, num_nodes, out_features).
        h_prime = torch.matmul(attention, Wh)

        # ELU non-linearity for hidden layers; raw output for the final layer.
        return F.elu(h_prime) if self.concat else h_prime

    def __repr__(self):
        return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
class GAT(nn.Module):
    """Multi-head graph attention network adapted to directed graphs.

    Runs several independent `GraphAttentionLayer` heads in parallel,
    concatenates their outputs, and maps the result to the output size
    with a final single-head attention layer.
    """

    def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
        """
        Args:
            nfeat: input feature size per node.
            nhid: hidden feature size per attention head.
            noutput: output feature size per node.
            dropout: dropout probability for inputs and attention weights.
            alpha: LeakyReLU negative slope passed to each head.
            nheads: number of parallel attention heads.
        """
        super(GAT, self).__init__()
        self.dropout = dropout

        # Build the heads one by one and register each explicitly under
        # the name `attention_{i}` so its parameters are tracked.
        heads = []
        for head_idx in range(nheads):
            head = GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            self.add_module(f'attention_{head_idx}', head)
            heads.append(head)
        self.attentions = heads

        # Output layer: maps the concatenated head outputs to `noutput`.
        self.out_att = GraphAttentionLayer(nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        """Forward pass.

        Args:
            x: input features, shape (batch_size, num_nodes, nfeat).
            adj: adjacency matrix, shape (num_nodes, num_nodes).

        Returns:
            Tensor of shape (batch_size, num_nodes, noutput).
        """
        x = F.dropout(x, self.dropout, training=self.training)
        # Run every head on the same input and concatenate along features.
        head_outputs = [head(x, adj) for head in self.attentions]
        x = torch.cat(head_outputs, dim=2)
        x = F.dropout(x, self.dropout, training=self.training)
        return F.elu(self.out_att(x, adj))