# gat_lstm.py
import torch
import torch.nn as nn
from config import config
class SingleGATLSTM(nn.Module):
    """Per-label forecaster: an LSTM encoder followed by a two-layer MLP head.

    Consumes a batch-first sequence shaped (batch, seq_len, FEATURE_NUM)
    and maps the hidden state of the final time step to OUTPUT_SIZE values.
    NOTE(review): despite the name, this module contains no graph-attention
    layer — it is a plain LSTM + MLP; confirm against the wider project.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder; all sizes come from the shared project config.
        self.lstm = nn.LSTM(
            input_size=config.FEATURE_NUM,
            hidden_size=config.HIDDEN_SIZE,
            num_layers=config.NUM_LAYERS,
            batch_first=True,
        )
        # Projection head applied to the last time step's hidden state.
        self.final_linear = nn.Sequential(
            nn.Linear(config.HIDDEN_SIZE, config.HIDDEN_SIZE),
            nn.LeakyReLU(0.01),
            nn.Dropout(config.DROPOUT * 0.4),
            nn.Linear(config.HIDDEN_SIZE, config.OUTPUT_SIZE),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every Linear layer and zero its bias if present."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Encode the sequence and project the final step's hidden state."""
        sequence_out, _state = self.lstm(x)
        final_step = sequence_out[:, -1, :]
        return self.final_linear(final_step)
class GAT_LSTM(nn.Module):
    """Ensemble of independent SingleGATLSTM sub-models, one per target label.

    The forward pass feeds the same input to every sub-model and concatenates
    their outputs along the feature dimension, yielding a tensor shaped
    (batch, LABELS_NUM * OUTPUT_SIZE).
    """

    def __init__(self):
        super().__init__()
        self.models = nn.ModuleList(
            SingleGATLSTM() for _ in range(config.LABELS_NUM)
        )

    def set_edge_index(self, edge_index):
        # NOTE(review): edge_index is stored but never read by forward() —
        # presumably a leftover hook for a graph-attention stage; confirm.
        self.edge_index = edge_index

    def forward(self, x):
        # Every sub-model receives the identical input batch.
        per_label = [submodel(x) for submodel in self.models]
        return torch.cat(per_label, dim=1)