# gat_lstm.py
import torch
import torch.nn as nn  # PyTorch neural-network building blocks
  4. # 单个独立模型(对应1个因变量)
  5. class SingleGATLSTM(nn.Module):
  6. def __init__(self, args):
  7. """
  8. 单个子模型:包含GAT-LSTM层和输出层,用于预测1个目标指标
  9. 参数:
  10. args: 配置参数(含特征数、隐藏层大小等)
  11. """
  12. super(SingleGATLSTM, self).__init__()
  13. self.args = args
  14. # 独立的LSTM层
  15. self.lstm = nn.LSTM(
  16. input_size=args.feature_num,
  17. hidden_size=args.hidden_size,
  18. num_layers=args.num_layers,
  19. batch_first=True
  20. )
  21. # 独立的输出层
  22. self.final_linear = nn.Sequential(
  23. nn.Linear(args.hidden_size, args.hidden_size),
  24. nn.LeakyReLU(0.01),
  25. nn.Dropout(args.dropout * 0.4),
  26. nn.Linear(args.hidden_size, args.output_size)
  27. )
  28. self._init_weights()
  29. def _init_weights(self):
  30. """初始化网络权重,加速模型收敛"""
  31. for m in self.modules():
  32. if isinstance(m, nn.Linear):
  33. nn.init.xavier_uniform_(m.weight)
  34. if m.bias is not None:
  35. nn.init.zeros_(m.bias)
  36. elif isinstance(m, nn.BatchNorm1d):
  37. nn.init.constant_(m.weight, 1)
  38. nn.init.constant_(m.bias, 0)
  39. # 初始化LSTM权重
  40. for name, param in self.lstm.named_parameters():
  41. if 'weight_ih' in name:
  42. nn.init.xavier_uniform_(param.data)
  43. elif 'weight_hh' in name:
  44. nn.init.orthogonal_(param.data)
  45. elif 'bias' in name:
  46. param.data.fill_(0)
  47. n = param.size(0)
  48. start, end = n // 4, n // 2
  49. param.data[start:end].fill_(1)
  50. def forward(self, x):
  51. """
  52. 前向传播:输入序列经过LSTM和输出层,得到预测结果
  53. 参数:
  54. x: 输入序列,形状为[batch_size, seq_len, feature_num]
  55. 返回:
  56. output: 预测结果,形状为[batch_size, output_size]
  57. """
  58. batch_size, seq_len, feature_num = x.size()
  59. lstm_out, _ = self.lstm(x)
  60. # 取最后一个时间步的输出
  61. last_out = lstm_out[:, -1, :]
  62. # 输出层预测
  63. output = self.final_linear(last_out)
  64. return output # [batch_size, output_size]
  65. # 16个独立模型的容器(总模型)
  66. class GAT_LSTM(nn.Module):
  67. def __init__(self, args):
  68. """
  69. 总模型:包含多个SingleGATLSTM子模型,分别预测不同的目标
  70. 参数:
  71. args: 配置参数(含labels_num,即子模型数量)
  72. """
  73. super(GAT_LSTM, self).__init__()
  74. self.args = args
  75. # 创建16个独立模型(数量由labels_num指定)
  76. self.models = nn.ModuleList([SingleGATLSTM(args) for _ in range(args.labels_num)])
  77. def set_edge_index(self, edge_index):
  78. self.edge_index = edge_index # 将传入的edge_index保存到模型内部
  79. def forward(self, x):
  80. """
  81. 前向传播:所有子模型并行处理输入,拼接预测结果
  82. 参数:
  83. x: 输入序列,形状为[batch_size, seq_len, feature_num]
  84. 返回:
  85. 拼接后的预测结果,形状为[batch_size, output_size * labels_num]
  86. """
  87. outputs = []
  88. for model in self.models:
  89. outputs.append(model(x)) # 每个输出为[batch, output_size]
  90. return torch.cat(outputs, dim=1) # 拼接后[batch, output_size * labels_num]