# predict.py
"""Real-time inference entry point for the GAT_LSTM model.

Loads the fitted scaler and trained model weights located by ``config`` and
turns a recent window of observations into de-normalized predictions.
"""
import os
import torch
import joblib
import argparse
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from config import config


class RealTimePredictor:
    """Wrap a trained ``GAT_LSTM`` model and its scaler for single-window inference.

    All paths, sizes and the device id are read from the project-level
    ``config`` object, which must already be loaded (the ``__main__`` block
    calls ``config.load`` before constructing this class).
    """

    def __init__(self):
        """Select a device, then load the scaler and the model weights.

        Raises:
            FileNotFoundError: if ``config.SCALER_PATH`` or
                ``config.MODEL_PATH`` does not exist.
        """
        self.device = torch.device(
            f"cuda:{config.DEVICE_ID}" if torch.cuda.is_available() else "cpu"
        )

        if not os.path.exists(config.SCALER_PATH):
            raise FileNotFoundError(f"未找到归一化文件: {config.SCALER_PATH}")
        # joblib.load unpickles arbitrary objects: only point SCALER_PATH at
        # trusted files.
        self.scaler = joblib.load(config.SCALER_PATH)

        # Explicit check (mirroring the scaler check) so a missing model file
        # fails with a clear message instead of deep inside torch.load.
        if not os.path.exists(config.MODEL_PATH):
            raise FileNotFoundError(f"未找到模型文件: {config.MODEL_PATH}")

        # Deferred import; presumably so gat_lstm only runs after config has
        # been loaded — TODO confirm.
        from gat_lstm import GAT_LSTM

        self.model = GAT_LSTM().to(self.device)
        # weights_only=True restricts torch.load to tensor/state-dict data.
        self.model.load_state_dict(
            torch.load(config.MODEL_PATH, map_location=self.device, weights_only=True)
        )
        self.model.eval()

    def _preprocess(self, df):
        """Build the scaled feature matrix for ``df``.

        Steps:
        1. Normalise the timestamp column to ``'index'``. If no timestamp
           column exists, synthesise one-minute timestamps ending now.
        2. Left-pad to at least ``config.SEQ_LEN`` rows by repeating the first
           row at one-minute steps back in time.
        3. Prepend four cyclic time features.
        4. Apply the fitted scaler.

        Args:
            df: pandas DataFrame containing the business columns
                ``config.REQUIRED_COLUMNS[1:]`` and optionally a
                ``'datetime'`` or ``'index'`` timestamp column. ``df`` itself
                is not modified.

        Returns:
            The array returned by ``self.scaler.transform``: one row per
            (possibly padded) input row, time features first.

        Raises:
            ValueError: if ``df`` has no rows.
            KeyError: if a required business column is missing.
        """
        # An empty frame has no first row to pad from.
        if len(df) == 0:
            raise ValueError("输入数据为空")

        data = df.copy()
        # Only rename when it would not create a second 'index' column; a
        # duplicate would make data['index'] a DataFrame and break parsing.
        if 'datetime' in data.columns and 'index' not in data.columns:
            data = data.rename(columns={'datetime': 'index'})
        if 'index' not in data.columns:
            # No timestamps supplied: assume consecutive minutes ending now.
            data['index'] = pd.date_range(end=datetime.now(), periods=len(data), freq='min')
        data['index'] = pd.to_datetime(data['index'])

        # Presumably REQUIRED_COLUMNS[0] is the timestamp column — confirm in config.
        business_cols = config.REQUIRED_COLUMNS[1:]
        missing = [col for col in business_cols if col not in data.columns]
        if missing:
            raise KeyError(f"缺少必需的列: {missing}")

        if len(data) < config.SEQ_LEN:
            pad_len = config.SEQ_LEN - len(data)
            first_ts = data['index'].iloc[0]
            # Repeat the first row pad_len times. Stamp the copies pad_len..1
            # minutes before the first real timestamp so they precede it.
            pads = pd.concat([data.iloc[0:1]] * pad_len, ignore_index=True)
            pads['index'] = [first_ts - timedelta(minutes=pad_len - i) for i in range(pad_len)]
            data = pd.concat([pads, data], ignore_index=True)

        date_col = data['index']
        minute_of_day = date_col.dt.hour * 60 + date_col.dt.minute
        day_of_year = date_col.dt.dayofyear
        # Sin/cos encodings make the features wrap around: 23:59 lands next to
        # 00:00, and the end of the year next to its start.
        time_features = pd.DataFrame({
            'minute_sin': np.sin(2 * np.pi * minute_of_day / 1440),
            'minute_cos': np.cos(2 * np.pi * minute_of_day / 1440),
            'day_year_sin': np.sin(2 * np.pi * day_of_year / 366),
            'day_year_cos': np.cos(2 * np.pi * day_of_year / 366),
        })

        # The column order (time features, then business columns) must match
        # the order the scaler was fit with — TODO confirm against training.
        data_to_scale = pd.concat(
            [
                time_features.reset_index(drop=True),
                data[business_cols].reset_index(drop=True),
            ],
            axis=1,
        )
        return self.scaler.transform(data_to_scale)

    def predict(self, df):
        """Predict from the most recent ``config.SEQ_LEN`` rows of ``df``.

        Args:
            df: input DataFrame accepted by ``_preprocess``. Only the last
                ``config.SEQ_LEN`` rows (after padding) are fed to the model.

        Returns:
            Nested list of shape ``(config.OUTPUT_SIZE, config.LABELS_NUM)``
            holding de-normalized, non-negative predictions.

        Raises:
            ValueError, KeyError: propagated from ``_preprocess``.
        """
        processed_data = self._preprocess(df)
        input_seq = processed_data[-config.SEQ_LEN:]
        # unsqueeze(0) adds a batch dimension of size 1.
        input_tensor = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(input_tensor)
        preds = output.cpu().numpy().reshape(config.OUTPUT_SIZE, config.LABELS_NUM)

        # Invert the scaling for the target columns only. This assumes a
        # scikit-learn MinMaxScaler (X_scaled = X * scale_ + min_) whose last
        # LABELS_NUM features are the targets — TODO confirm against training.
        target_min = self.scaler.min_[-config.LABELS_NUM:]
        target_scale = self.scaler.scale_[-config.LABELS_NUM:]
        # NOTE(review): np.abs mirrors negative predictions to positive values.
        # If the goal is only to forbid negatives, np.clip(..., 0, None) is
        # probably what was intended. Kept as-is to preserve current outputs.
        real_preds = np.abs((preds - target_min) / target_scale)
        return real_preds.tolist()


if __name__ == "__main__":
    # Smoke test: load the given plant's config and predict on random data.
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--plant', required=True)
    args = parser.parse_args()
    config.load(args.plant)

    predictor = RealTimePredictor()

    # 15 rows of one-minute mock data. If SEQ_LEN > 15, _preprocess left-pads it.
    mock_data = pd.DataFrame()
    mock_data['index'] = pd.date_range(end=datetime.now(), periods=15, freq='min')
    for col in config.REQUIRED_COLUMNS[1:]:
        mock_data[col] = np.random.rand(15) * 10
    print(predictor.predict(mock_data))