|
@@ -0,0 +1,271 @@
|
|
|
|
|
+# predict.py
|
|
|
|
|
+import os
|
|
|
|
|
+import torch
|
|
|
|
|
+import joblib
|
|
|
|
|
+import pandas as pd
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+from datetime import datetime, timedelta
|
|
|
|
|
+from gat_lstm import GAT_LSTM
|
|
|
|
|
+
|
|
|
|
|
+class RealTimePredictor:
|
|
|
|
|
+ def __init__(self, model_path='model.pth', scaler_path='scaler.pkl', device=None):
|
|
|
|
|
+ """
|
|
|
|
|
+ 初始化预测器
|
|
|
|
|
+ """
|
|
|
|
|
+ # 1. 参数配置 (与训练 args.py 保持一致)
|
|
|
|
|
+ self.seq_len = 10 # 输入序列长度
|
|
|
|
|
+ self.feature_num = 42 # 输入特征数 (4时间编码 + 38业务特征)
|
|
|
|
|
+ self.labels_num = 4 # 输出标签数
|
|
|
|
|
+ self.hidden_size = 64
|
|
|
|
|
+ self.num_layers = 1
|
|
|
|
|
+ self.output_size = 5 # 预测未来 5 步
|
|
|
|
|
+ self.dropout = 0
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 设备与资源加载
|
|
|
|
|
+ self.device = device if device else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
+ self.model_path = model_path
|
|
|
|
|
+ self.scaler_path = scaler_path
|
|
|
|
|
+
|
|
|
|
|
+ # 加载归一化器
|
|
|
|
|
+ if not os.path.exists(self.scaler_path):
|
|
|
|
|
+ raise FileNotFoundError(f"未找到归一化文件: {self.scaler_path},请确保已完成训练。")
|
|
|
|
|
+ self.scaler = joblib.load(self.scaler_path)
|
|
|
|
|
+
|
|
|
|
|
+ # 加载模型
|
|
|
|
|
+ self._load_model()
|
|
|
|
|
+
|
|
|
|
|
+ # 定义必须存在的列名 (39个,包含index,顺序必须固定)
|
|
|
|
|
+ self.required_columns = [
|
|
|
|
|
+ 'index',
|
|
|
|
|
+ "AR.1#UF_JSFLOW_O", # 1#UF进水流量
|
|
|
|
|
+ "AR.2#UF_JSFLOW_O", # 2#UF进水流量
|
|
|
|
|
+ "AR.1#RO_JSFLOW_O", # 1#RO进水流量
|
|
|
|
|
+ "AR.2#RO_JSFLOW_O", # 2#RO进水流量
|
|
|
|
|
+ "AR.1#UF_JSPRESS_O", # 1#UF进水压力
|
|
|
|
|
+ "AR.2#UF_JSPRESS_O", # 2#UF进水压力
|
|
|
|
|
+ "AR.1#RO_JSPRESS_O", # 1#RO进水压力
|
|
|
|
|
+ "AR.2#RO_JSPRESS_O", # 2#RO进水压力
|
|
|
|
|
+ "AR.1#RO_EDJSPRESS_O", # 1#RO二段进水压力
|
|
|
|
|
+ "AR.1#RO_SDJSPRESS_O", # 1#RO三段进水压力
|
|
|
|
|
+ "AR.2#RO_EDJSPRESS_O", # 2#RO二段进水压力
|
|
|
|
|
+ "AR.2#RO_SDJSPRESS_O", # 2#RO三段进水压力
|
|
|
|
|
+ "AR.ZJS_TEMP_O", # 进水温度
|
|
|
|
|
+ "AR.ZJS_ZD_O", # UF进水浊度
|
|
|
|
|
+ "AR.RO_JSDD_O", # RO进水电导
|
|
|
|
|
+ "AR.RO_JSORP_O", # RO进水ORP
|
|
|
|
|
+ "AR.RO_JSPH_O", # RO进水PH
|
|
|
|
|
+ "AR.1#UF_V_FB_O", # 1#UF调节阀开度反馈
|
|
|
|
|
+ "AR.2#UF_V_FB_O", # 2#UF调节阀开度反馈
|
|
|
|
|
+ "AR.1#UFBWB_FRE_FB_O", # 1#UF反洗泵频率反馈
|
|
|
|
|
+ "AR.2#UFBWB_FRE_FB_O", # 2#UF反洗泵频率反馈
|
|
|
|
|
+ "AR.1#RODJB_FRE_FB_O", # 1#RO段间泵频率反馈
|
|
|
|
|
+ "AR.1#ROGYB_FRE_FB_O", # 1#RO高压泵频率反馈
|
|
|
|
|
+ "AR.1#RODJB_CZ_O", # 1#RO段间泵测振反馈
|
|
|
|
|
+ "AR.1#ROGYB_CZ_O", # 1#RO高压泵测振反馈
|
|
|
|
|
+ "AR.2#RODJB_CZ_O", # 2#RO段间泵测振反馈
|
|
|
|
|
+ "AR.2#ROGYB_CZ_O", # 2#RO高压泵测振反馈
|
|
|
|
|
+ "AR.ROGSB_FRE_FB_O", # RO供水泵频率反馈
|
|
|
|
|
+ "AR.UFGSB_FRE_FB_O", # UF供水泵频率反馈
|
|
|
|
|
+ "AR.V_UF1_TJV_KD_FB", # UF1调节阀开度反馈
|
|
|
|
|
+ "AR.V_UF2_TJV_KD_FB", # UF2调节阀开度反馈
|
|
|
|
|
+ "AR.CS_LEVEL_O", # RO产水箱液位
|
|
|
|
|
+ "AR.UF_CSLEVEL_O", # UF产水箱液位
|
|
|
|
|
+ "AR.UF1_SSD_KMYC", # UF1跨膜压差
|
|
|
|
|
+ "AR.UF2_SSD_KMYC", # UF2跨膜压差
|
|
|
|
|
+ "AR.RO1_2D_YC", # RO1二段压差
|
|
|
|
|
+ "AR.PUBLIC_BY_REAL_1", # RO1三段压差
|
|
|
|
|
+ "1#RO_CSFLOW", # 1#RO产水流量
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ # 用于防空值兜底机制的变量
|
|
|
|
|
+ self.raw_input_data = None
|
|
|
|
|
+ self.target_columns = self.required_columns[-self.labels_num:]
|
|
|
|
|
+
|
|
|
|
|
+ def _load_model(self):
|
|
|
|
|
+ """内部方法:加载模型权重"""
|
|
|
|
|
+ class ModelArgs: pass
|
|
|
|
|
+ args = ModelArgs()
|
|
|
|
|
+ args.feature_num = self.feature_num
|
|
|
|
|
+ args.hidden_size = self.hidden_size
|
|
|
|
|
+ args.num_layers = self.num_layers
|
|
|
|
|
+ args.output_size = self.output_size
|
|
|
|
|
+ args.labels_num = self.labels_num
|
|
|
|
|
+ args.dropout = self.dropout
|
|
|
|
|
+
|
|
|
|
|
+ self.model = GAT_LSTM(args).to(self.device)
|
|
|
|
|
+
|
|
|
|
|
+ # 加载 edge_index.pt
|
|
|
|
|
+ if os.path.exists('edge_index.pt'):
|
|
|
|
|
+ edge_index = torch.load('edge_index.pt', map_location=self.device, weights_only=True)
|
|
|
|
|
+ self.model.set_edge_index(edge_index)
|
|
|
|
|
+
|
|
|
|
|
+ if not os.path.exists(self.model_path):
|
|
|
|
|
+ raise FileNotFoundError(f"未找到模型权重文件: {self.model_path}")
|
|
|
|
|
+
|
|
|
|
|
+ state_dict = torch.load(self.model_path, map_location=self.device, weights_only=True)
|
|
|
|
|
+ self.model.load_state_dict(state_dict)
|
|
|
|
|
+ self.model.eval()
|
|
|
|
|
+
|
|
|
|
|
+ def _preprocess(self, df):
|
|
|
|
|
+ """数据预处理:补全、排序、生成时间特征、整体归一化"""
|
|
|
|
|
+ data = df.copy()
|
|
|
|
|
+
|
|
|
|
|
+ # 1. 统一时间列名
|
|
|
|
|
+ if 'datetime' in data.columns:
|
|
|
|
|
+ data = data.rename(columns={'datetime': 'index'})
|
|
|
|
|
+ if 'index' not in data.columns:
|
|
|
|
|
+ data['index'] = pd.date_range(end=datetime.now(), periods=len(data), freq='min')
|
|
|
|
|
+ data['index'] = pd.to_datetime(data['index'])
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 补全长度 (Padding)
|
|
|
|
|
+ if len(data) < self.seq_len:
|
|
|
|
|
+ pad_len = self.seq_len - len(data)
|
|
|
|
|
+ first_row = data.iloc[0:1]
|
|
|
|
|
+ pads = pd.concat([first_row] * pad_len, ignore_index=True)
|
|
|
|
|
+ start_time = data['index'].iloc[0]
|
|
|
|
|
+ for i in range(pad_len):
|
|
|
|
|
+ pads.at[i, 'index'] = start_time - timedelta(minutes=(pad_len-i))
|
|
|
|
|
+ data = pd.concat([pads, data], ignore_index=True)
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 列筛选排序 (提取业务数据,不含index)
|
|
|
|
|
+ try:
|
|
|
|
|
+ # required_columns[0] 是 'index',我们取后面的业务列
|
|
|
|
|
+ business_cols = self.required_columns[1:]
|
|
|
|
|
+ data_business = data[business_cols].copy()
|
|
|
|
|
+
|
|
|
|
|
+ # 策略: 前向填充 -> 后向填充 -> 填充为0
|
|
|
|
|
+ data_business = data_business.ffill().bfill().fillna(0.0)
|
|
|
|
|
+ # ==========================================
|
|
|
|
|
+
|
|
|
|
|
+ except KeyError:
|
|
|
|
|
+ missing = list(set(self.required_columns) - set(data.columns))
|
|
|
|
|
+ raise ValueError(f"缺少列: {missing}")
|
|
|
|
|
+
|
|
|
|
|
+ # 4. 生成时间特征
|
|
|
|
|
+ date_col = data['index']
|
|
|
|
|
+ minute_of_day = date_col.dt.hour * 60 + date_col.dt.minute
|
|
|
|
|
+ day_of_year = date_col.dt.dayofyear
|
|
|
|
|
+
|
|
|
|
|
+ time_features = pd.DataFrame({
|
|
|
|
|
+ 'minute_sin': np.sin(2 * np.pi * minute_of_day / 1440),
|
|
|
|
|
+ 'minute_cos': np.cos(2 * np.pi * minute_of_day / 1440),
|
|
|
|
|
+ 'day_year_sin': np.sin(2 * np.pi * day_of_year / 366),
|
|
|
|
|
+ 'day_year_cos': np.cos(2 * np.pi * day_of_year / 366)
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ # 5. 拼接:[时间特征 + 业务特征]
|
|
|
|
|
+ # 注意:训练时的顺序是 time_features + other_columns
|
|
|
|
|
+ # 必须重置索引以避免拼接错位
|
|
|
|
|
+ data_to_scale = pd.concat([
|
|
|
|
|
+ time_features.reset_index(drop=True),
|
|
|
|
|
+ data_business.reset_index(drop=True)
|
|
|
|
|
+ ], axis=1)
|
|
|
|
|
+
|
|
|
|
|
+ # 6. 整体归一化
|
|
|
|
|
+ # 此时 columns 应该包含: minute_sin, minute_cos..., AR.1#UF_JSFLOW_O...
|
|
|
|
|
+ # 顺序和名字必须与 fit 时一致
|
|
|
|
|
+ scaled_array = self.scaler.transform(data_to_scale)
|
|
|
|
|
+
|
|
|
|
|
+ return scaled_array
|
|
|
|
|
+
|
|
|
|
|
+ # --- 备用防空值兜底函数 ---
|
|
|
|
|
+ def get_recent_values_as_fallback(self):
|
|
|
|
|
+ """从原始输入数据中获取最近的output_size条记录作为备用输出,避免输出空值"""
|
|
|
|
|
+ if self.raw_input_data is None or self.raw_input_data.empty:
|
|
|
|
|
+ return np.zeros((self.output_size, self.labels_num))
|
|
|
|
|
+
|
|
|
|
|
+ df_copy = self.raw_input_data.copy()
|
|
|
|
|
+
|
|
|
|
|
+ # 统一时间列格式,防止报错
|
|
|
|
|
+ if 'datetime' in df_copy.columns:
|
|
|
|
|
+ df_copy = df_copy.rename(columns={'datetime': 'index'})
|
|
|
|
|
+ if 'index' not in df_copy.columns:
|
|
|
|
|
+ df_copy['index'] = pd.date_range(end=datetime.now(), periods=len(df_copy), freq='min')
|
|
|
|
|
+ df_copy['index'] = pd.to_datetime(df_copy['index'])
|
|
|
|
|
+
|
|
|
|
|
+ # 按时间排序并取最近的output_size条
|
|
|
|
|
+ recent_data = df_copy.sort_values('index').tail(self.output_size)
|
|
|
|
|
+
|
|
|
|
|
+ # 若数据不足,用最后一条补充
|
|
|
|
|
+ if len(recent_data) < self.output_size:
|
|
|
|
|
+ last_row = recent_data.iloc[-1:] if not recent_data.empty else pd.DataFrame(
|
|
|
|
|
+ {col: [0.0] for col in self.target_columns}, index=[0])
|
|
|
|
|
+ while len(recent_data) < self.output_size:
|
|
|
|
|
+ recent_data = pd.concat([recent_data, last_row], ignore_index=True)
|
|
|
|
|
+
|
|
|
|
|
+ # 确保提取的兜底数据中没有空值 (NaN)
|
|
|
|
|
+ recent_data[self.target_columns] = recent_data[self.target_columns].ffill().bfill().fillna(0.0)
|
|
|
|
|
+
|
|
|
|
|
+ # 提取目标列值并返回
|
|
|
|
|
+ try:
|
|
|
|
|
+ fallback_values = recent_data[self.target_columns].values
|
|
|
|
|
+ except KeyError:
|
|
|
|
|
+ # 极度异常情况兜底(输入中缺少目标列)
|
|
|
|
|
+ fallback_values = np.zeros((self.output_size, self.labels_num))
|
|
|
|
|
+
|
|
|
|
|
+ return fallback_values
|
|
|
|
|
+
|
|
|
|
|
+ def predict(self, df):
|
|
|
|
|
+ """
|
|
|
|
|
+ 返回: List[List[float]]
|
|
|
|
|
+ 格式: [[t+1时刻的4个值], [t+2时刻的4个值], ..., [t+5时刻的4个值]]
|
|
|
|
|
+ """
|
|
|
|
|
+ # --- 保存原始输入数据用于可能的降级策略 ---
|
|
|
|
|
+ self.raw_input_data = df.copy()
|
|
|
|
|
+
|
|
|
|
|
+ # 1. 预处理 (返回的是归一化后的 numpy 数组)
|
|
|
|
|
+ processed_data = self._preprocess(df)
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 取最后 seq_len 个时间步构建 Tensor
|
|
|
|
|
+ input_seq = processed_data[-self.seq_len:]
|
|
|
|
|
+ input_tensor = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0).to(self.device)
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 推理
|
|
|
|
|
+ with torch.no_grad():
|
|
|
|
|
+ output = self.model(input_tensor)
|
|
|
|
|
+
|
|
|
|
|
+ # 4. 反归一化
|
|
|
|
|
+ # 输出形状调整为 (5, 4) -> 5个步长, 4个变量
|
|
|
|
|
+ preds = output.cpu().numpy().reshape(self.output_size, self.labels_num)
|
|
|
|
|
+
|
|
|
|
|
+ # 获取最后4列的归一化参数 (目标变量)
|
|
|
|
|
+ target_min = self.scaler.min_[-self.labels_num:]
|
|
|
|
|
+ target_scale = self.scaler.scale_[-self.labels_num:]
|
|
|
|
|
+
|
|
|
|
|
+ real_preds = (preds - target_min) / target_scale
|
|
|
|
|
+ real_preds = np.abs(real_preds)
|
|
|
|
|
+
|
|
|
|
|
+ # --- 空值/NaN 检测与兜底机制 ---
|
|
|
|
|
+ # 如果模型因极端情况输出 NaN 或者 inf 无穷大,触发历史数据兜底
|
|
|
|
|
+ if np.isnan(real_preds).any() or np.isinf(real_preds).any():
|
|
|
|
|
+ real_preds = self.get_recent_values_as_fallback()
|
|
|
|
|
+
|
|
|
|
|
+ # 5. 返回纯数值列表
|
|
|
|
|
+ return real_preds.tolist()
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ # 测试代码
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 初始化
|
|
|
|
|
+ predictor = RealTimePredictor()
|
|
|
|
|
+
|
|
|
|
|
+ # 生成模拟数据
|
|
|
|
|
+ mock_data = pd.DataFrame()
|
|
|
|
|
+ mock_data['index'] = pd.date_range(end=datetime.now(), periods=15, freq='min')
|
|
|
|
|
+ for col in predictor.required_columns[1:]:
|
|
|
|
|
+ mock_data[col] = np.random.rand(15) * 10
|
|
|
|
|
+
|
|
|
|
|
+ # 人为制造空值测试鲁棒性
|
|
|
|
|
+ mock_data.loc[3:6, "AR.1#UF_JSFLOW_O"] = np.nan
|
|
|
|
|
+ mock_data.loc[12, predictor.target_columns[0]] = np.nan
|
|
|
|
|
+
|
|
|
|
|
+ # 预测
|
|
|
|
|
+ result = predictor.predict(mock_data)
|
|
|
|
|
+
|
|
|
|
|
+ print("预测结果 (5x4 数组):")
|
|
|
|
|
+ print(result)
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"Error: {e}")
|
|
|
|
|
+ import traceback
|
|
|
|
|
+ traceback.print_exc()
|