gat_lstm.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # gat_lstm.py
  2. import torch
  3. import torch.nn as nn
  4. # 单个独立模型(对应1个因变量)
  5. class SingleGATLSTM(nn.Module):
  6. def __init__(self, args):
  7. super(SingleGATLSTM, self).__init__()
  8. self.args = args
  9. # 独立的LSTM层
  10. self.lstm = nn.LSTM(
  11. input_size=args.feature_num,
  12. hidden_size=args.hidden_size,
  13. num_layers=args.num_layers,
  14. batch_first=True
  15. )
  16. # 独立的输出层
  17. self.final_linear = nn.Sequential(
  18. nn.Linear(args.hidden_size, args.hidden_size),
  19. nn.LeakyReLU(0.01),
  20. nn.Dropout(args.dropout * 0.4),
  21. nn.Linear(args.hidden_size, args.output_size)
  22. )
  23. self._init_weights()
  24. def _init_weights(self):
  25. # 初始化线性层权重
  26. for m in self.modules():
  27. if isinstance(m, nn.Linear):
  28. nn.init.xavier_uniform_(m.weight)
  29. if m.bias is not None:
  30. nn.init.zeros_(m.bias)
  31. elif isinstance(m, nn.BatchNorm1d):
  32. nn.init.constant_(m.weight, 1)
  33. nn.init.constant_(m.bias, 0)
  34. # 初始化LSTM权重
  35. for name, param in self.lstm.named_parameters():
  36. if 'weight_ih' in name:
  37. nn.init.xavier_uniform_(param.data)
  38. elif 'weight_hh' in name:
  39. nn.init.orthogonal_(param.data)
  40. elif 'bias' in name:
  41. param.data.fill_(0)
  42. n = param.size(0)
  43. start, end = n // 4, n // 2
  44. param.data[start:end].fill_(1)
  45. def forward(self, x):
  46. # LSTM处理输入序列
  47. batch_size, seq_len, feature_num = x.size()
  48. lstm_out, _ = self.lstm(x)
  49. # 取最后一个时间步的输出
  50. last_out = lstm_out[:, -1, :]
  51. # 输出层预测
  52. output = self.final_linear(last_out)
  53. return output # [batch_size, output_size]
  54. # 16个独立模型的容器(总模型)
  55. class GAT_LSTM(nn.Module):
  56. def __init__(self, args):
  57. super(GAT_LSTM, self).__init__()
  58. self.args = args
  59. # 创建16个独立模型(数量由labels_num指定)
  60. self.models = nn.ModuleList([SingleGATLSTM(args) for _ in range(args.labels_num)])
  61. def forward(self, x):
  62. # 收集所有模型的输出并拼接
  63. outputs = []
  64. for model in self.models:
  65. outputs.append(model(x)) # 每个输出为[batch, output_size]
  66. return torch.cat(outputs, dim=1) # 拼接后[batch, output_size * labels_num]