predict.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # predict.py
  2. import os
  3. import torch
  4. import joblib
  5. import argparse
  6. import pandas as pd
  7. import numpy as np
  8. from datetime import datetime, timedelta
  9. from config import config
  10. class RealTimePredictor:
  11. def __init__(self):
  12. self.device = torch.device(f"cuda:{config.DEVICE_ID}" if torch.cuda.is_available() else "cpu")
  13. if not os.path.exists(config.SCALER_PATH):
  14. raise FileNotFoundError(f"未找到归一化文件: {config.SCALER_PATH}")
  15. self.scaler = joblib.load(config.SCALER_PATH)
  16. from gat_lstm import GAT_LSTM
  17. self.model = GAT_LSTM().to(self.device)
  18. self.model.load_state_dict(torch.load(config.MODEL_PATH, map_location=self.device, weights_only=True))
  19. self.model.eval()
  20. def _preprocess(self, df):
  21. data = df.copy()
  22. if 'datetime' in data.columns: data = data.rename(columns={'datetime': 'index'})
  23. if 'index' not in data.columns:
  24. data['index'] = pd.date_range(end=datetime.now(), periods=len(data), freq='min')
  25. data['index'] = pd.to_datetime(data['index'])
  26. if len(data) < config.SEQ_LEN:
  27. pad_len = config.SEQ_LEN - len(data)
  28. pads = pd.concat([data.iloc[0:1]] * pad_len, ignore_index=True)
  29. for i in range(pad_len):
  30. pads.at[i, 'index'] = data['index'].iloc[0] - timedelta(minutes=(pad_len-i))
  31. data = pd.concat([pads, data], ignore_index=True)
  32. business_cols = config.REQUIRED_COLUMNS[1:]
  33. data_business = data[business_cols]
  34. date_col = data['index']
  35. minute_of_day = date_col.dt.hour * 60 + date_col.dt.minute
  36. day_of_year = date_col.dt.dayofyear
  37. time_features = pd.DataFrame({
  38. 'minute_sin': np.sin(2 * np.pi * minute_of_day / 1440),
  39. 'minute_cos': np.cos(2 * np.pi * minute_of_day / 1440),
  40. 'day_year_sin': np.sin(2 * np.pi * day_of_year / 366),
  41. 'day_year_cos': np.cos(2 * np.pi * day_of_year / 366)
  42. })
  43. data_to_scale = pd.concat([time_features.reset_index(drop=True), data_business.reset_index(drop=True)], axis=1)
  44. return self.scaler.transform(data_to_scale)
  45. def predict(self, df):
  46. processed_data = self._preprocess(df)
  47. input_seq = processed_data[-config.SEQ_LEN:]
  48. input_tensor = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0).to(self.device)
  49. with torch.no_grad():
  50. output = self.model(input_tensor)
  51. preds = output.cpu().numpy().reshape(config.OUTPUT_SIZE, config.LABELS_NUM)
  52. target_min = self.scaler.min_[-config.LABELS_NUM:]
  53. target_scale = self.scaler.scale_[-config.LABELS_NUM:]
  54. real_preds = np.abs((preds - target_min) / target_scale)
  55. return real_preds.tolist()
  56. if __name__ == "__main__":
  57. parser = argparse.ArgumentParser()
  58. parser.add_argument('-p', '--plant', required=True)
  59. args = parser.parse_args()
  60. config.load(args.plant)
  61. predictor = RealTimePredictor()
  62. mock_data = pd.DataFrame()
  63. mock_data['index'] = pd.date_range(end=datetime.now(), periods=15, freq='min')
  64. for col in config.REQUIRED_COLUMNS[1:]:
  65. mock_data[col] = np.random.rand(15) * 10
  66. print(predictor.predict(mock_data))