gat.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. """
  2. 有向图注意力网络 (Directed Graph Attention Network)
  3. 实现基于有向图的注意力机制,用于建模节点间的非对称因果关系。
  4. 与传统GAT不同,本实现分离源节点和目标节点的注意力参数,更适合因果推理任务。
  5. 核心特性:
  6. - 有向注意力: 源节点和目标节点使用独立的注意力参数
  7. - 多头机制: 并行学习多种关系模式
  8. - 邻接掩码: 仅在图中存在的边上计算注意力
  9. 技术实现:
  10. - 框架: PyTorch
  11. - 注意力: 加性注意力 (Additive Attention)
  12. - 激活函数: LeakyReLU (α=0.2)
  13. """
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. class GraphAttentionLayer(nn.Module):
  18. """
  19. 有向图注意力层 (Directed Graph Attention Layer)
  20. 实现单层有向图注意力机制,是GAT模型的基本构建块。
  21. 通过分离源节点和目标节点的注意力参数,支持建模非对称的因果关系。
  22. 核心思想:
  23. 传统GAT使用对称注意力权重 (A→B 和 B→A 权重相同)
  24. 有向GAT分离源节点和目标节点参数,学习方向性的因果影响
  25. 注意力计算:
  26. e_ij = LeakyReLU(a_src^T·Wh_i + a_dst^T·Wh_j)
  27. α_ij = softmax_j(e_ij)
  28. h_i' = σ(Σ_j α_ij·Wh_j)
  29. Args:
  30. in_features (int): 输入特征维度
  31. out_features (int): 输出特征维度
  32. dropout (float): Dropout概率 [0,1]
  33. alpha (float): LeakyReLU负斜率,默认0.2
  34. concat (bool): 是否使用ELU激活 (True用于中间层,False用于输出层)
  35. Example:
  36. >>> layer = GraphAttentionLayer(1, 64, dropout=0.3, alpha=0.2, concat=True)
  37. >>> h = torch.randn(32, 145, 1) # (batch, nodes, features)
  38. >>> adj = torch.ones(145, 145) # 邻接矩阵
  39. >>> output = layer(h, adj) # (32, 145, 64)
  40. """
  41. def __init__(self, in_features, out_features, dropout, alpha, concat=True):
  42. super(GraphAttentionLayer, self).__init__()
  43. self.dropout = dropout
  44. self.in_features = in_features
  45. self.out_features = out_features
  46. self.alpha = alpha
  47. self.concat = concat
  48. # 特征变换矩阵: 输入特征 → 输出特征空间
  49. # Xavier初始化保证前向/反向传播时方差稳定
  50. self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
  51. nn.init.xavier_uniform_(self.W.data, gain=1.414) # gain=1.414 适配LeakyReLU
  52. # 有向注意力参数 (核心创新)
  53. # a_src: 源节点注意力向量 (发出边的权重)
  54. # a_dst: 目标节点注意力向量 (接收边的权重)
  55. # 分离参数使模型能够学习非对称因果关系
  56. self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
  57. self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
  58. nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
  59. nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)
  60. # LeakyReLU: 允许负值有小梯度,防止神经元死亡
  61. self.leakyrelu = nn.LeakyReLU(self.alpha)
  62. def forward(self, h, adj):
  63. """
  64. 前向传播
  65. 计算流程:
  66. 1. 线性变换: Wh = h @ W
  67. 2. 计算源/目标节点注意力分数
  68. 3. 构建注意力矩阵: e_ij = LeakyReLU(src_i + dst_j)
  69. 4. 应用邻接掩码 (不存在的边设为-∞)
  70. 5. Softmax归一化得到注意力权重 α_ij
  71. 6. 加权聚合邻居特征: h_i' = Σ_j α_ij·Wh_j
  72. Args:
  73. h (Tensor): 输入特征 (batch_size, num_nodes, in_features)
  74. 例: (32, 145, 1)
  75. adj (Tensor): 邻接矩阵 (num_nodes, num_nodes)
  76. adj[i,j]=1 表示节点i→j存在有向边
  77. Returns:
  78. Tensor: 输出特征 (batch_size, num_nodes, out_features)
  79. 经过图注意力聚合后的节点特征
  80. """
  81. batch_size = h.size(0)
  82. num_nodes = h.size(1)
  83. # Step 1: 线性变换 (batch, nodes, in_features) → (batch, nodes, out_features)
  84. Wh = torch.matmul(h, self.W)
  85. # Step 2-3: 计算源/目标节点注意力分数
  86. a_input_src = torch.matmul(Wh, self.a_src) # 源节点分数 (信息发送方)
  87. a_input_dst = torch.matmul(Wh, self.a_dst) # 目标节点分数 (信息接收方)
  88. # Step 4: 构建注意力矩阵 (广播: src[i] + dst[j]^T)
  89. # e[i,j] = a_src^T·Wh_i + a_dst^T·Wh_j
  90. e = a_input_src + a_input_dst.transpose(1, 2)
  91. e = self.leakyrelu(e)
  92. # Step 5: 应用邻接掩码 (不存在的边设为-9e15,softmax后≈0)
  93. zero_vec = -9e15 * torch.ones_like(e)
  94. attention = torch.where(adj > 0, e, zero_vec)
  95. # Step 6: Softmax归一化 (dim=2: 对每个节点的所有邻居归一化)
  96. attention = F.softmax(attention, dim=2)
  97. attention = F.dropout(attention, self.dropout, training=self.training)
  98. # Step 7: 加权聚合邻居特征 h_i' = Σ_j α_ij·Wh_j
  99. h_prime = torch.matmul(attention, Wh)
  100. # 中间层使用ELU激活,输出层保持线性
  101. if self.concat:
  102. return F.elu(h_prime)
  103. else:
  104. return h_prime
  105. def __repr__(self):
  106. return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
  107. class GAT(nn.Module):
  108. """
  109. 多层图注意力网络 (Multi-layer Graph Attention Network)
  110. 组合多个图注意力层构建完整的GAT模型,采用多头注意力机制从不同视角捕捉节点关系。
  111. 网络结构:
  112. 输入 → 多头注意力层 (nheads个并行) → 拼接 → 输出注意力层 → 输出
  113. Args:
  114. nfeat (int): 输入特征维度,例: 1
  115. nhid (int): 隐藏层维度 (每个注意力头的输出维度),推荐: 32-128
  116. noutput (int): 输出维度 (目标变量数量),例: 47
  117. dropout (float): Dropout概率 [0,1],例: 0.3
  118. alpha (float): LeakyReLU负斜率,例: 0.2
  119. nheads (int): 注意力头数量,例: 4
  120. 多头注意力机制:
  121. 多个独立注意力头并行学习不同关系模式 (直接因果、间接影响、周期性等)
  122. 最后拼接所有头的输出,形成丰富的特征表示
  123. 维度变化:
  124. (batch, 145, 1) → [多头] → (batch, 145, nhid×nheads)
  125. → [输出层] → (batch, 145, noutput)
  126. Example:
  127. >>> model = GAT(nfeat=1, nhid=64, noutput=47, dropout=0.3, alpha=0.2, nheads=4)
  128. >>> x = torch.randn(32, 145, 1)
  129. >>> adj = torch.ones(145, 145)
  130. >>> output = model(x, adj) # (32, 145, 47)
  131. """
  132. def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
  133. super(GAT, self).__init__()
  134. self.dropout = dropout
  135. # 多头注意力层: 创建 nheads 个独立的图注意力层
  136. self.attentions = [
  137. GraphAttentionLayer(
  138. in_features=nfeat,
  139. out_features=nhid,
  140. dropout=dropout,
  141. alpha=alpha,
  142. concat=True # 中间层使用ELU激活
  143. )
  144. for _ in range(nheads)
  145. ]
  146. # 注册为子模块,使参数可被自动追踪和优化
  147. for i, attention in enumerate(self.attentions):
  148. self.add_module(f'attention_{i}', attention)
  149. # 输出注意力层: 输入维度 = nhid×nheads (拼接后)
  150. self.out_att = GraphAttentionLayer(
  151. in_features=nhid * nheads,
  152. out_features=noutput,
  153. dropout=dropout,
  154. alpha=alpha,
  155. concat=False # 输出层保持线性
  156. )
  157. def forward(self, x, adj):
  158. """
  159. 前向传播
  160. 计算流程:
  161. 1. 输入dropout
  162. 2. 多头注意力并行计算并拼接
  163. 3. 中间dropout
  164. 4. 输出层 + ELU激活
  165. Args:
  166. x (Tensor): 输入特征 (batch_size, num_nodes, nfeat)
  167. 例: (32, 145, 1)
  168. adj (Tensor): 邻接矩阵 (num_nodes, num_nodes)
  169. adj[i,j]=1 表示特征i对特征j有因果影响
  170. Returns:
  171. Tensor: 输出特征 (batch_size, num_nodes, noutput)
  172. 例: (32, 145, 47)
  173. """
  174. # 输入dropout (防止过拟合)
  175. x = F.dropout(x, self.dropout, training=self.training)
  176. # 多头注意力并行计算 + 拼接
  177. # (batch, nodes, nfeat) → nheads × (batch, nodes, nhid) → (batch, nodes, nhid×nheads)
  178. x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
  179. # 中间dropout
  180. x = F.dropout(x, self.dropout, training=self.training)
  181. # 输出层 + ELU激活
  182. x = F.elu(self.out_att(x, adj))
  183. return x