# gat_lstm.py
"""LSTM-based multi-output model: one independent LSTM + MLP head per label.

NOTE(review): despite the "GAT" in the class names, no graph-attention layer is
defined in this module. ``GAT_LSTM.set_edge_index`` stores a connectivity
object, but nothing in this file reads it during ``forward`` — confirm whether
graph handling is expected to live elsewhere.
"""
from typing import Optional

import torch
import torch.nn as nn

from config import config


class SingleGATLSTM(nn.Module):
    """Stacked LSTM followed by a two-layer MLP head, producing one label's output.

    Hyper-parameters are read from ``config``: ``FEATURE_NUM`` (LSTM input
    size), ``HIDDEN_SIZE``, ``NUM_LAYERS``, ``DROPOUT`` and ``OUTPUT_SIZE``.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=config.FEATURE_NUM,
            hidden_size=config.HIDDEN_SIZE,
            num_layers=config.NUM_LAYERS,
            batch_first=True,
        )
        self.final_linear = nn.Sequential(
            nn.Linear(config.HIDDEN_SIZE, config.HIDDEN_SIZE),
            nn.LeakyReLU(0.01),
            # The head uses 40% of the configured dropout rate.
            nn.Dropout(config.DROPOUT * 0.4),
            nn.Linear(config.HIDDEN_SIZE, config.OUTPUT_SIZE),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-uniform init for every ``nn.Linear`` weight; zero its bias.

        Only ``nn.Linear`` submodules are touched, so the LSTM keeps PyTorch's
        default initialization.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the LSTM over ``x`` and map the last time step through the head.

        Args:
            x: batch-first sequence tensor, presumably shaped
                ``(batch, seq_len, FEATURE_NUM)`` (the LSTM is built with
                ``batch_first=True`` and the output is indexed as rank 3).

        Returns:
            Tensor of shape ``(batch, OUTPUT_SIZE)``.
        """
        lstm_out, _ = self.lstm(x)
        # Only the final time step's hidden output feeds the prediction head.
        last_step = lstm_out[:, -1, :]
        return self.final_linear(last_step)


class GAT_LSTM(nn.Module):
    """Ensemble of ``config.LABELS_NUM`` independent ``SingleGATLSTM`` models.

    Every sub-model sees the same input; their outputs are concatenated along
    the feature dimension.
    """

    def __init__(self):
        super().__init__()
        self.models = nn.ModuleList(
            [SingleGATLSTM() for _ in range(config.LABELS_NUM)]
        )
        # Initialized so the attribute always exists, even before
        # set_edge_index() is called. Not used by forward() in this file.
        self.edge_index: Optional[torch.Tensor] = None

    def set_edge_index(self, edge_index) -> None:
        """Store the graph connectivity object on the model.

        NOTE(review): stored as a plain attribute (not a buffer), so
        ``model.to(device)`` will not move it; ``forward`` does not read it.
        """
        self.edge_index = edge_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every per-label sub-model to ``x`` and concatenate the results.

        Returns:
            Tensor of shape ``(batch, LABELS_NUM * OUTPUT_SIZE)``; columns are
            grouped per sub-model, in ``self.models`` order.
        """
        outputs = [model(x) for model in self.models]
        return torch.cat(outputs, dim=1)