# gat_lstm.py
import torch
import torch.nn as nn


class SingleGATLSTM(nn.Module):
    """Single sub-model that predicts one target indicator.

    An LSTM encoder (batch_first) followed by a small MLP head applied to the
    hidden state of the last time step.

    Args:
        args: Config object providing ``feature_num``, ``hidden_size``,
            ``num_layers``, ``dropout`` and ``output_size``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.lstm = nn.LSTM(
            input_size=args.feature_num,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            batch_first=True,
        )
        self.final_linear = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size),
            nn.LeakyReLU(0.01),
            # NOTE(review): head dropout is 40% of args.dropout; the reason
            # for the 0.4 factor is not documented here.
            nn.Dropout(args.dropout * 0.4),
            nn.Linear(args.hidden_size, args.output_size),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every Linear weight; zero every Linear bias.

        LSTM parameters are left at PyTorch's default initialization.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Encode the sequence and project the last time step.

        Args:
            x: Tensor of shape (batch, seq_len, feature_num).

        Returns:
            Tensor of shape (batch, output_size).

        Raises:
            ValueError: if ``x`` is not 3-D.
        """
        # The LSTM is batch_first and we index the last time step, so a
        # batched 3-D input is required (unbatched 2-D input is rejected).
        if x.dim() != 3:
            raise ValueError(
                f"expected 3-D input (batch, seq_len, feature_num), "
                f"got shape {tuple(x.shape)}"
            )
        lstm_out, _ = self.lstm(x)
        last_out = lstm_out[:, -1, :]
        return self.final_linear(last_out)


class GAT_LSTM(nn.Module):
    """Top-level model: one independent SingleGATLSTM per target label.

    Every sub-model receives the same input; their outputs are concatenated
    along dim 1. Despite the name, no graph-attention layer is applied in
    this file — an edge index can be stored via ``set_edge_index`` but
    ``forward`` does not use it.

    Args:
        args: Config object; ``labels_num`` sets the number of sub-models and
            the remaining fields are forwarded to each SingleGATLSTM.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # One independent sub-model per label (args.labels_num of them).
        self.models = nn.ModuleList(
            [SingleGATLSTM(args) for _ in range(args.labels_num)]
        )
        # Defined up front so reading it before set_edge_index() does not
        # raise AttributeError.
        self.edge_index = None

    def set_edge_index(self, edge_index):
        """Store graph connectivity (not consumed by ``forward`` here)."""
        self.edge_index = edge_index

    def forward(self, x):
        """Run every sub-model on ``x`` and concatenate their outputs.

        Args:
            x: Tensor of shape (batch, seq_len, feature_num).

        Returns:
            Tensor of shape (batch, labels_num * output_size).
        """
        return torch.cat([model(x) for model in self.models], dim=1)