- """
- 有向图注意力网络 (Directed Graph Attention Network)
- 实现基于有向图的注意力机制,用于建模节点间的非对称因果关系。
- 与传统GAT不同,本实现分离源节点和目标节点的注意力参数,更适合因果推理任务。
- 核心特性:
- - 有向注意力: 源节点和目标节点使用独立的注意力参数
- - 多头机制: 并行学习多种关系模式
- - 邻接掩码: 仅在图中存在的边上计算注意力
- 技术实现:
- - 框架: PyTorch
- - 注意力: 加性注意力 (Additive Attention)
- - 激活函数: LeakyReLU (α=0.2)
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class GraphAttentionLayer(nn.Module):
    """
    Directed graph attention layer.

    Implements a single directed graph-attention layer, the basic building
    block of the GAT model. Separating the attention parameters for source
    and target nodes supports modeling asymmetric causal relationships.

    Core idea:
        A conventional GAT ties the attention weights of both directions
        (A -> B and B -> A share one score), whereas the directed GAT uses
        separate source and target parameters to learn directional causal
        influence.

    Attention computation:
        e_ij = LeakyReLU(a_src^T · Wh_i + a_dst^T · Wh_j)
        alpha_ij = softmax_j(e_ij)
        h_i' = sigma(sum_j alpha_ij · Wh_j)

    Args:
        in_features (int): input feature dimension
        out_features (int): output feature dimension
        dropout (float): dropout probability in [0, 1]
        alpha (float): negative slope of LeakyReLU, default 0.2
        concat (bool): apply ELU if True (intermediate layers); False for the output layer

    Example:
        >>> layer = GraphAttentionLayer(1, 64, dropout=0.3, alpha=0.2, concat=True)
        >>> h = torch.randn(32, 145, 1)  # (batch, nodes, features)
        >>> adj = torch.ones(145, 145)   # adjacency matrix
        >>> output = layer(h, adj)       # (32, 145, 64)
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Feature transform: input features -> output feature space.
        # Xavier initialization keeps forward/backward variance stable.
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)  # gain=1.414 suits LeakyReLU

        # Directed attention parameters (the core novelty here):
        #   a_src: source-node attention vector (weight for outgoing edges)
        #   a_dst: target-node attention vector (weight for incoming edges)
        # Separate parameters let the model learn asymmetric causal relations.
        self.a_src = nn.Parameter(torch.empty(size=(out_features, 1)))
        self.a_dst = nn.Parameter(torch.empty(size=(out_features, 1)))
        nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
        nn.init.xavier_uniform_(self.a_dst.data, gain=1.414)

        # LeakyReLU: small gradient for negative inputs, avoids dead neurons.
        self.leakyrelu = nn.LeakyReLU(self.alpha)
    def forward(self, h, adj):
        """
        Forward pass.

        Computation:
            1. Linear transform: Wh = h @ W
            2. Compute source/target attention scores
            3. Build the attention matrix: e_ij = LeakyReLU(src_i + dst_j)
            4. Apply the adjacency mask (absent edges set to -inf)
            5. Softmax-normalize to attention weights alpha_ij
            6. Aggregate neighbor features: h_i' = sum_j alpha_ij · Wh_j

        Args:
            h (Tensor): input features (batch_size, num_nodes, in_features),
                e.g. (32, 145, 1)
            adj (Tensor): adjacency matrix (num_nodes, num_nodes);
                adj[i, j] = 1 means there is a directed edge i -> j

        Returns:
            Tensor: output features (batch_size, num_nodes, out_features),
                the node features after graph-attention aggregation.
        """
        # Step 1: linear transform (batch, nodes, in_features) -> (batch, nodes, out_features)
        Wh = torch.matmul(h, self.W)

        # Step 2: source/target attention scores
        a_input_src = torch.matmul(Wh, self.a_src)  # source-node scores (senders)
        a_input_dst = torch.matmul(Wh, self.a_dst)  # target-node scores (receivers)

        # Step 3: attention matrix via broadcasting:
        # e[i, j] = a_src^T · Wh_i + a_dst^T · Wh_j
        e = a_input_src + a_input_dst.transpose(1, 2)
        e = self.leakyrelu(e)

        # Step 4: adjacency mask (absent edges set to -9e15, ~0 after softmax)
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        # Step 5: softmax normalization (dim=2: over each node's neighbors)
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # Step 6: weighted aggregation h_i' = sum_j alpha_ij · Wh_j
        h_prime = torch.matmul(attention, Wh)

        # ELU on intermediate layers; the output layer stays linear.
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + f'({self.in_features} -> {self.out_features})'
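# The scoring and masking steps above can be sketched in isolation. The
# following is a minimal, self-contained example (tensor names mirror the
# layer's, but the sizes and adjacency are illustrative): it shows that the
# broadcast e[i, j] = src_i + dst_j yields a generally asymmetric score
# matrix, and that the -9e15 mask drives the softmax weight of absent edges
# to ~0.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, D = 4, 8                # nodes, transformed feature dimension
Wh = torch.randn(1, N, D)  # stands in for Wh = h @ W (batch of 1)
a_src = torch.randn(D, 1)  # source-node attention vector
a_dst = torch.randn(D, 1)  # target-node attention vector

# Broadcast (1, N, 1) + (1, 1, N) -> (1, N, N): e[i, j] = src_i + dst_j
e = torch.matmul(Wh, a_src) + torch.matmul(Wh, a_dst).transpose(1, 2)
e = F.leaky_relu(e, negative_slope=0.2)

# Because a_src != a_dst, e is asymmetric: the score for i -> j can
# differ from the score for j -> i.
print(torch.allclose(e, e.transpose(1, 2)))

# Mask absent edges with a large negative value, then softmax per row:
# masked entries underflow to zero weight, present edges sum to 1.
adj = torch.tensor([[1., 1., 0., 0.],
                    [0., 1., 1., 0.],
                    [0., 0., 1., 1.],
                    [1., 0., 0., 1.]])
masked = torch.where(adj > 0, e, torch.full_like(e, -9e15))
attn = F.softmax(masked, dim=2)

print(attn[0].sum(dim=1))       # each row sums to 1
print(attn[0][adj == 0].sum())  # ~0 mass on absent edges
```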
class GAT(nn.Module):
    """
    Multi-layer graph attention network.

    Stacks graph-attention layers into a full GAT model, using multi-head
    attention to capture node relations from several perspectives.

    Architecture:
        input -> multi-head attention (nheads in parallel) -> concat
              -> output attention layer -> output

    Args:
        nfeat (int): input feature dimension, e.g. 1
        nhid (int): hidden dimension (output size of each attention head), typically 32-128
        noutput (int): output dimension (number of target variables), e.g. 47
        dropout (float): dropout probability in [0, 1], e.g. 0.3
        alpha (float): negative slope of LeakyReLU, e.g. 0.2
        nheads (int): number of attention heads, e.g. 4

    Multi-head attention:
        Several independent heads learn different relation patterns in
        parallel (direct causation, indirect influence, periodicity, ...).
        Their outputs are concatenated into a richer feature representation.

    Shape flow:
        (batch, 145, 1) -> [heads] -> (batch, 145, nhid * nheads)
                        -> [output layer] -> (batch, 145, noutput)

    Example:
        >>> model = GAT(nfeat=1, nhid=64, noutput=47, dropout=0.3, alpha=0.2, nheads=4)
        >>> x = torch.randn(32, 145, 1)
        >>> adj = torch.ones(145, 145)
        >>> output = model(x, adj)  # (32, 145, 47)
    """
    def __init__(self, nfeat, nhid, noutput, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout

        # Multi-head attention: nheads independent graph-attention layers.
        self.attentions = [
            GraphAttentionLayer(
                in_features=nfeat,
                out_features=nhid,
                dropout=dropout,
                alpha=alpha,
                concat=True  # intermediate layers use ELU
            )
            for _ in range(nheads)
        ]
        # Register as submodules so their parameters are tracked and optimized.
        for i, attention in enumerate(self.attentions):
            self.add_module(f'attention_{i}', attention)

        # Output attention layer: input dimension = nhid * nheads (after concat).
        self.out_att = GraphAttentionLayer(
            in_features=nhid * nheads,
            out_features=noutput,
            dropout=dropout,
            alpha=alpha,
            concat=False  # output layer stays linear
        )
    def forward(self, x, adj):
        """
        Forward pass.

        Computation:
            1. Input dropout
            2. Multi-head attention in parallel, then concatenation
            3. Intermediate dropout
            4. Output layer + ELU activation

        Args:
            x (Tensor): input features (batch_size, num_nodes, nfeat),
                e.g. (32, 145, 1)
            adj (Tensor): adjacency matrix (num_nodes, num_nodes);
                adj[i, j] = 1 means feature i causally influences feature j

        Returns:
            Tensor: output features (batch_size, num_nodes, noutput),
                e.g. (32, 145, 47)
        """
        # Input dropout (regularization)
        x = F.dropout(x, self.dropout, training=self.training)
        # Multi-head attention in parallel + concatenation:
        # (batch, nodes, nfeat) -> nheads x (batch, nodes, nhid) -> (batch, nodes, nhid * nheads)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        # Intermediate dropout
        x = F.dropout(x, self.dropout, training=self.training)
        # Output layer + ELU activation
        x = F.elu(self.out_att(x, adj))
        return x
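# A quick sanity check of the multi-head shape flow described in GAT's
# docstring. This self-contained sketch uses plain Linear layers as
# stand-ins for the attention heads (the sizes mirror the docstring
# examples; the Linear heads are placeholders, not the real layers):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, nodes = 32, 145
nfeat, nhid, nheads, noutput = 1, 64, 4, 47

x = torch.randn(batch, nodes, nfeat)

# Placeholder heads: each maps nfeat -> nhid, like one attention head.
heads = [nn.Linear(nfeat, nhid) for _ in range(nheads)]

# Concatenating along the feature dimension gives nhid * nheads features.
hidden = torch.cat([head(x) for head in heads], dim=2)
print(hidden.shape)  # torch.Size([32, 145, 256])

# The output stage maps the concatenated features to noutput targets.
out = nn.Linear(nhid * nheads, noutput)(hidden)
print(out.shape)     # torch.Size([32, 145, 47])
```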