gat_lstm.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. # gat_lstm.py
  2. import torch
  3. import torch.nn as nn
  4. from config import config
  5. class SingleGATLSTM(nn.Module):
  6. def __init__(self):
  7. super(SingleGATLSTM, self).__init__()
  8. self.lstm = nn.LSTM(
  9. input_size=config.FEATURE_NUM,
  10. hidden_size=config.HIDDEN_SIZE,
  11. num_layers=config.NUM_LAYERS,
  12. batch_first=True
  13. )
  14. self.final_linear = nn.Sequential(
  15. nn.Linear(config.HIDDEN_SIZE, config.HIDDEN_SIZE),
  16. nn.LeakyReLU(0.01),
  17. nn.Dropout(config.DROPOUT * 0.4),
  18. nn.Linear(config.HIDDEN_SIZE, config.OUTPUT_SIZE)
  19. )
  20. self._init_weights()
  21. def _init_weights(self):
  22. for m in self.modules():
  23. if isinstance(m, nn.Linear):
  24. nn.init.xavier_uniform_(m.weight)
  25. if m.bias is not None: nn.init.zeros_(m.bias)
  26. def forward(self, x):
  27. lstm_out, _ = self.lstm(x)
  28. last_out = lstm_out[:, -1, :]
  29. return self.final_linear(last_out)
  30. class GAT_LSTM(nn.Module):
  31. def __init__(self):
  32. super(GAT_LSTM, self).__init__()
  33. self.models = nn.ModuleList([SingleGATLSTM() for _ in range(config.LABELS_NUM)])
  34. def set_edge_index(self, edge_index):
  35. self.edge_index = edge_index
  36. def forward(self, x):
  37. outputs = [model(x) for model in self.models]
  38. return torch.cat(outputs, dim=1)