# gat.py — directed Graph Attention Network (GAT) layers in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
  4. class GraphAttentionLayer(nn.Module):
  5. """有向图注意力层(单独处理源节点和目标节点)"""
  6. def __init__(self, in_features, out_features, dropout, alpha, concat=True):
  7. super(GraphAttentionLayer, self).__init__()
  8. self.dropout = dropout
  9. self.in_features = in_features
  10. self.out_features = out_features
  11. self.alpha = alpha
  12. self.concat = concat
  13. # 权重参数
  14. self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
  15. nn.init.xavier_uniform_(self.W.data, gain=1.414)
  16. # 有向图注意力参数(源节点和目标节点分开)
  17. self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
  18. self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
  19. nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
  20. nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)
  21. self.leakyrelu = nn.LeakyReLU(self.alpha)
  22. def forward(self, h, adj):
  23. """
  24. h: 输入特征 (batch_size, num_nodes, in_features)
  25. adj: 邻接矩阵 (num_nodes, num_nodes)
  26. """
  27. batch_size = h.size(0)
  28. num_nodes = h.size(1)
  29. # 线性变换
  30. Wh = torch.matmul(h, self.W) # (batch_size, num_nodes, out_features)
  31. # 计算有向注意力分数
  32. a_input_src = torch.matmul(Wh, self.a_src) # (batch_size, num_nodes, 1)
  33. a_input_dst = torch.matmul(Wh, self.a_dst) # (batch_size, num_nodes, 1)
  34. # 有向图注意力分数 = 源节点分数 + 目标节点分数(转置后)
  35. e = a_input_src + a_input_dst.transpose(1, 2) # (batch_size, num_nodes, num_nodes)
  36. e = self.leakyrelu(e)
  37. # 应用邻接矩阵掩码(只保留存在的边)
  38. zero_vec = -9e15 * torch.ones_like(e)
  39. attention = torch.where(adj > 0, e, zero_vec)
  40. # 计算注意力权重
  41. attention = F.softmax(attention, dim=2)
  42. attention = F.dropout(attention, self.dropout, training=self.training)
  43. # 应用注意力权重
  44. h_prime = torch.matmul(attention, Wh) # (batch_size, num_nodes, out_features)
  45. if self.concat:
  46. return F.elu(h_prime)
  47. else:
  48. return h_prime
  49. def __repr__(self):
  50. return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
  51. class GAT(nn.Module):
  52. def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
  53. super(GAT, self).__init__()
  54. self.dropout = dropout
  55. # 多头注意力层(有向图适配)
  56. self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
  57. for _ in range(nheads)]
  58. for i, attention in enumerate(self.attentions):
  59. self.add_module(f'attention_{i}', attention)
  60. # 输出层
  61. self.out_att = GraphAttentionLayer(nhid * nheads, noutput, dropout=dropout, alpha=alpha, concat=False)
  62. def forward(self, x, adj):
  63. """
  64. x: 输入特征 (batch_size, num_nodes, nfeat)
  65. adj: 邻接矩阵 (num_nodes, num_nodes)
  66. """
  67. x = F.dropout(x, self.dropout, training=self.training)
  68. # 拼接多头注意力输出
  69. x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
  70. x = F.dropout(x, self.dropout, training=self.training)
  71. x = F.elu(self.out_att(x, adj))
  72. return x