import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """Directed-graph attention layer (scores source and target nodes separately).

    For node features ``h``, computes ``Wh = h @ W`` and the unnormalised
    attention logit for edge ``i -> j`` as
    ``LeakyReLU(a_src . Wh_i + a_dst . Wh_j)``. Entries with no edge in
    ``adj`` are masked out before a row-wise softmax over neighbours.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        """
        Args:
            in_features: size of each input node feature vector.
            out_features: size of each output node feature vector.
            dropout: dropout probability applied to the attention weights.
            alpha: negative slope of the LeakyReLU applied to the logits.
            concat: if True, apply an ELU non-linearity to the output
                (hidden layers); if False, return the raw aggregation
                (output layer).
        """
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Shared linear transform applied to every node.
        # NOTE: init acts on the Parameter directly; going through `.data`
        # is deprecated and bypasses autograd bookkeeping.
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.W, gain=1.414)

        # Directed attention: separate score vectors for the source and
        # target roles of a node (this is what makes the layer directed).
        self.a_src = nn.Parameter(torch.empty(out_features, 1))
        self.a_dst = nn.Parameter(torch.empty(out_features, 1))
        nn.init.xavier_uniform_(self.a_src, gain=1.414)
        nn.init.xavier_uniform_(self.a_dst, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """
        Args:
            h: input features, shape (batch_size, num_nodes, in_features).
            adj: adjacency matrix, shape (num_nodes, num_nodes);
                adj[i, j] > 0 means an edge exists. Broadcast over the batch.

        Returns:
            Tensor of shape (batch_size, num_nodes, out_features).
        """
        # Linear projection: (batch, nodes, out_features).
        Wh = torch.matmul(h, self.W)

        # Per-node scores in source / target roles: (batch, nodes, 1).
        score_src = torch.matmul(Wh, self.a_src)
        score_dst = torch.matmul(Wh, self.a_dst)

        # Directed logits: e[b, i, j] = src score of i + dst score of j.
        e = self.leakyrelu(score_src + score_dst.transpose(1, 2))

        # Mask non-edges with the most negative representable value so they
        # receive ~zero weight after softmax. Using finfo(dtype).min instead
        # of a hard-coded -9e15 keeps this dtype-safe (e.g. half precision).
        attention = e.masked_fill(adj <= 0, torch.finfo(e.dtype).min)

        # Normalise over the neighbour axis, then regularise the weights.
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # Weighted aggregation of neighbour features.
        h_prime = torch.matmul(attention, Wh)

        return F.elu(h_prime) if self.concat else h_prime

    def __repr__(self):
        return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'


class GAT(nn.Module):
    """Multi-head directed GAT: nheads parallel hidden layers + one output layer."""

    def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
        """
        Args:
            nfeat: input feature size per node.
            nhid: hidden feature size per attention head.
            noutput: output feature size per node.
            dropout: dropout probability for inputs and attention weights.
            alpha: LeakyReLU negative slope passed to every layer.
            nheads: number of parallel attention heads in the hidden layer.
        """
        super(GAT, self).__init__()
        self.dropout = dropout

        # Multi-head hidden attention layers. Registered via add_module under
        # the original 'attention_{i}' names so state-dict keys stay stable
        # for existing checkpoints (a ModuleList would rename them).
        self.attentions = [
            GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads)
        ]
        for i, attention in enumerate(self.attentions):
            self.add_module(f'attention_{i}', attention)

        # Output layer consumes the concatenation of all head outputs.
        self.out_att = GraphAttentionLayer(
            nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False
        )

    def forward(self, x, adj):
        """
        Args:
            x: input features, shape (batch_size, num_nodes, nfeat).
            adj: adjacency matrix, shape (num_nodes, num_nodes).

        Returns:
            Tensor of shape (batch_size, num_nodes, noutput).
        """
        x = F.dropout(x, self.dropout, training=self.training)
        # Concatenate the heads' outputs along the feature axis.
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        x = F.dropout(x, self.dropout, training=self.training)
        return F.elu(self.out_att(x, adj))