gat_lstm.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # gat_lstm.py
  2. import torch
  3. import torch.nn as nn
  4. class SingleGATLSTM(nn.Module):
  5. """单个子模型:预测1个目标指标"""
  6. def __init__(self, args):
  7. super(SingleGATLSTM, self).__init__()
  8. self.args = args
  9. self.lstm = nn.LSTM(
  10. input_size=args.feature_num,
  11. hidden_size=args.hidden_size,
  12. num_layers=args.num_layers,
  13. batch_first=True
  14. )
  15. self.final_linear = nn.Sequential(
  16. nn.Linear(args.hidden_size, args.hidden_size),
  17. nn.LeakyReLU(0.01),
  18. nn.Dropout(args.dropout * 0.4),
  19. nn.Linear(args.hidden_size, args.output_size)
  20. )
  21. self._init_weights()
  22. def _init_weights(self):
  23. for m in self.modules():
  24. if isinstance(m, nn.Linear):
  25. nn.init.xavier_uniform_(m.weight)
  26. if m.bias is not None: nn.init.zeros_(m.bias)
  27. def forward(self, x):
  28. batch_size, seq_len, feature_num = x.size()
  29. lstm_out, _ = self.lstm(x)
  30. last_out = lstm_out[:, -1, :]
  31. output = self.final_linear(last_out)
  32. return output
  33. class GAT_LSTM(nn.Module):
  34. """总模型:包含多个SingleGATLSTM子模型"""
  35. def __init__(self, args):
  36. super(GAT_LSTM, self).__init__()
  37. self.args = args
  38. # 创建4个独立模型(对应labels_num=4)
  39. self.models = nn.ModuleList([SingleGATLSTM(args) for _ in range(args.labels_num)])
  40. def set_edge_index(self, edge_index):
  41. self.edge_index = edge_index
  42. def forward(self, x):
  43. outputs = []
  44. for model in self.models:
  45. outputs.append(model(x))
  46. return torch.cat(outputs, dim=1)