predict.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # predict.py
  2. import os
  3. import torch
  4. import joblib
  5. import pandas as pd
  6. import numpy as np
  7. from datetime import datetime, timedelta
  8. from gat_lstm import GAT_LSTM
  9. class RealTimePredictor:
  10. def __init__(self, model_path='model.pth', scaler_path='scaler.pkl', device=None):
  11. """
  12. 初始化预测器
  13. """
  14. # 1. 参数配置 (与训练 args.py 保持一致)
  15. self.seq_len = 10 # 输入序列长度
  16. self.feature_num = 78 # 输入特征数 (4时间编码 + 38业务特征)
  17. self.labels_num = 8 # 输出标签数
  18. self.hidden_size = 64
  19. self.num_layers = 1
  20. self.output_size = 5 # 预测未来 5 步
  21. self.dropout = 0
  22. # 2. 设备与资源加载
  23. self.device = device if device else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  24. self.model_path = model_path
  25. self.scaler_path = scaler_path
  26. # 加载归一化器
  27. if not os.path.exists(self.scaler_path):
  28. raise FileNotFoundError(f"未找到归一化文件: {self.scaler_path},请确保已完成训练。")
  29. self.scaler = joblib.load(self.scaler_path)
  30. # 加载模型
  31. self._load_model()
  32. # 定义必须存在的列名 (75个)
  33. self.required_columns = [
  34. 'index',
  35. "water_in", # 进水量
  36. "water_out", # 外供水流量
  37. "RO1_TYL", # RO1脱盐率
  38. "RO2_TYL", # RO2脱盐率
  39. "UF1Per", # UF1渗透率
  40. "UF2Per", # UF2渗透率
  41. "2#RODJB_Eff", # 2#RO段间泵效率
  42. "1#RODJB_Eff", # 1#RO段间泵效率
  43. "2#ROGYB_Eff", # 2#RO高压泵效率
  44. "1#ROGYB_Eff", # 1#RO高压泵效率
  45. "ROHSL", # 反渗透回收率
  46. "ns=3;s=1#RO_CSDD_O", # 1#RO产水电导
  47. "ns=3;s=1#RO_CSPRESS_O", # 1#RO产水压力
  48. "ns=3;s=1#RO_EDCSFLOW_O", # 1#RO二段产水流量
  49. "ns=3;s=1#RO_EDJSPRESS_O", # 1#RO二段进水压力
  50. "ns=3;s=1#RO_EDNSPRESS_O", # 1#RO二段浓水压力
  51. "ns=3;s=1#RO_JSFLOW_O", # 1#RO进水流量
  52. "ns=3;s=1#RO_JSPRESS_O", # 1#RO进水压力
  53. "ns=3;s=1#RO_NSFLOW_O", # 1#RO浓水流量
  54. "ns=3;s=1#RO_SDCSFLOW_O", # 1#RO三段产水流量
  55. "ns=3;s=1#RO_SDJSPRESS_O", # 1#RO三段进水压力
  56. "ns=3;s=1#RO_SDNSPRESS_O", # 1#RO三段浓水压力
  57. "ns=3;s=1#RODJB_CUR_FB_O", # 1#RO段间泵电流反馈
  58. "ns=3;s=1#RODJB_CZ_O", # 1#RO段间泵测振反馈
  59. "ns=3;s=1#RODJB_FRE_FB_O", # 1#RO段间泵频率反馈
  60. "ns=3;s=1#ROGYB_CUR_FB_O", # 1#RO高压泵电流反馈
  61. "ns=3;s=1#ROGYB_CZ_O", # 1#RO高压泵测振反馈
  62. "ns=3;s=1#ROGYB_FRE_FB_O", # 1#RO高压泵频率反馈
  63. "ns=3;s=1#UF_CSPRESS_O", # 1#UF产水压力
  64. "ns=3;s=1#UF_JSFLOW_O", # 1#UF进水流量
  65. "ns=3;s=1#UF_JSPRESS_O", # 1#UF进水压力
  66. "ns=3;s=1#UF_V_FB_O", # 1#UF调节阀开度反馈
  67. "ns=3;s=1#UFBWB_CUR_FB_O", # 1#UF反洗泵电流反馈
  68. "ns=3;s=1#UFBWB_FRE_FB_O", # 1#UF反洗泵频率反馈
  69. "ns=3;s=2#RO_CSDD_O", # 2#RO产水电导
  70. "ns=3;s=2#RO_CSPRESS_O", # 2#RO产水压力
  71. "ns=3;s=2#RO_EDCSFLOW_O", # 2#RO二段产水流量
  72. "ns=3;s=2#RO_EDJSPRESS_O", # 2#RO二段进水压力
  73. "ns=3;s=2#RO_EDNSPRESS_O", # 2#RO二段浓水压力
  74. "ns=3;s=2#RO_JSFLOW_O", # 2#RO进水流量
  75. "ns=3;s=2#RO_JSPRESS_O", # 2#RO进水压力
  76. "ns=3;s=2#RO_NSFLOW_O", # 2#RO浓水流量
  77. "ns=3;s=2#RO_SDCSFLOW_O", # 2#RO三段产水流量
  78. "ns=3;s=2#RO_SDJSPRESS_O", # 2#RO三段进水压力
  79. "ns=3;s=2#RO_SDNSPRESS_O", # 2#RO三段浓水压力
  80. "ns=3;s=2#RODJB_CUR_FB_O", # 2#RO段间泵电流反馈
  81. "ns=3;s=2#RODJB_CZ_O", # 2#RO段间泵测振反馈
  82. "ns=3;s=2#RODJB_FRE_FB_O", # 2#RO段间泵频率反馈
  83. "ns=3;s=2#ROGYB_CUR_FB_O", # 2#RO高压泵电流反馈
  84. "ns=3;s=2#ROGYB_CZ_O", # 2#RO高压泵测振反馈
  85. "ns=3;s=2#ROGYB_FRE_FB_O", # 2#RO高压泵频率反馈
  86. "ns=3;s=2#UF_CSPRESS_O", # 2#UF产水压力
  87. "ns=3;s=2#UF_JSFLOW_O", # 2#UF进水流量
  88. "ns=3;s=2#UF_JSPRESS_O", # 2#UF进水压力
  89. "ns=3;s=2#UF_V_FB_O", # 2#UF调节阀开度反馈
  90. "ns=3;s=2#UFBWB_CUR_FB_O", # 2#UF反洗泵电流反馈
  91. "ns=3;s=2#UFBWB_FRE_FB_O", # 2#UF反洗泵频率反馈
  92. "ns=3;s=RO_JSDD_O", # RO进水电导
  93. "ns=3;s=RO_JSORP_O", # RO进水ORP
  94. "ns=3;s=RO_JSPH_O", # RO进水PH
  95. "ns=3;s=RO1_1DUAN_CS_FLOW", # RO1一段产水流量
  96. "ns=3;s=ZJS_PRESS_O", # 进水压力
  97. "ns=3;s=ZJS_TEMP_O", # 进水温度
  98. "ns=3;s=ZJS_ZD_O", # UF进水浊度
  99. "ns=3;s=PUBLIC_RO1_MTL", # RO1膜通量
  100. "ns=3;s=PUBLIC_RO2_MTL", # RO2膜通量
  101. "ns=3;s=UF1_SSD_KMYC", # UF1跨膜压差
  102. "ns=3;s=UF2_SSD_KMYC", # UF2跨膜压差
  103. "ns=3;s=RO1_1D_YC", # RO1一段压差
  104. "ns=3;s=RO1_2D_YC", # RO1二段压差
  105. "ns=3;s=RO2_1D_YC", # RO2一段压差
  106. "ns=3;s=RO2_2D_YC", # RO2二段压差
  107. "ns=3;s=PUBLIC_BY_REAL_1", # RO1三段压差
  108. "ns=3;s=PUBLIC_BY_REAL_2", # RO2三段压差
  109. ]
  110. # --- 用于防空值兜底机制的变量 ---
  111. self.raw_input_data = None
  112. # 目标列名自动推导(最后 labels_num 个列)
  113. self.target_columns = self.required_columns[-self.labels_num:]
  114. def _load_model(self):
  115. """内部方法:加载模型权重"""
  116. class ModelArgs: pass
  117. args = ModelArgs()
  118. args.feature_num = self.feature_num
  119. args.hidden_size = self.hidden_size
  120. args.num_layers = self.num_layers
  121. args.output_size = self.output_size
  122. args.labels_num = self.labels_num
  123. args.dropout = self.dropout
  124. self.model = GAT_LSTM(args).to(self.device)
  125. # 加载 edge_index.pt
  126. if os.path.exists('edge_index.pt'):
  127. edge_index = torch.load('edge_index.pt', map_location=self.device, weights_only=True)
  128. self.model.set_edge_index(edge_index)
  129. if not os.path.exists(self.model_path):
  130. raise FileNotFoundError(f"未找到模型权重文件: {self.model_path}")
  131. state_dict = torch.load(self.model_path, map_location=self.device, weights_only=True)
  132. self.model.load_state_dict(state_dict)
  133. self.model.eval()
  134. def _preprocess(self, df):
  135. """数据预处理:补全、排序、生成时间特征、整体归一化"""
  136. data = df.copy()
  137. # 1. 统一时间列名
  138. if 'datetime' in data.columns:
  139. data = data.rename(columns={'datetime': 'index'})
  140. if 'index' not in data.columns:
  141. data['index'] = pd.date_range(end=datetime.now(), periods=len(data), freq='min')
  142. data['index'] = pd.to_datetime(data['index'])
  143. # 2. 补全长度 (Padding)
  144. if len(data) < self.seq_len:
  145. pad_len = self.seq_len - len(data)
  146. first_row = data.iloc[0:1]
  147. pads = pd.concat([first_row] * pad_len, ignore_index=True)
  148. start_time = data['index'].iloc[0]
  149. for i in range(pad_len):
  150. pads.at[i, 'index'] = start_time - timedelta(minutes=(pad_len-i))
  151. data = pd.concat([pads, data], ignore_index=True)
  152. # 3. 列筛选排序 (提取业务数据,不含index)
  153. try:
  154. # required_columns[0] 是 'index',我们取后面的业务列
  155. business_cols = self.required_columns[1:]
  156. data_business = data[business_cols]
  157. # 策略: 前向填充 -> 后向填充 -> 填充为0
  158. data_business = data_business.ffill().bfill().fillna(0.0)
  159. except KeyError:
  160. missing = list(set(self.required_columns) - set(data.columns))
  161. raise ValueError(f"缺少列: {missing}")
  162. # 4. 生成时间特征
  163. date_col = data['index']
  164. minute_of_day = date_col.dt.hour * 60 + date_col.dt.minute
  165. day_of_year = date_col.dt.dayofyear
  166. time_features = pd.DataFrame({
  167. 'minute_sin': np.sin(2 * np.pi * minute_of_day / 1440),
  168. 'minute_cos': np.cos(2 * np.pi * minute_of_day / 1440),
  169. 'day_year_sin': np.sin(2 * np.pi * day_of_year / 366),
  170. 'day_year_cos': np.cos(2 * np.pi * day_of_year / 366)
  171. })
  172. # 5. 拼接:[时间特征 + 业务特征]
  173. # 注意:训练时的顺序是 time_features + other_columns
  174. # 必须重置索引以避免拼接错位
  175. data_to_scale = pd.concat([
  176. time_features.reset_index(drop=True),
  177. data_business.reset_index(drop=True)
  178. ], axis=1)
  179. # 6. 整体归一化
  180. # 此时 columns 应该包含: minute_sin, minute_cos..., AR.1#UF_JSFLOW_O...
  181. # 顺序和名字必须与 fit 时一致
  182. scaled_array = self.scaler.transform(data_to_scale)
  183. return scaled_array
  184. # --- 备用防空值兜底函数 ---
  185. def get_recent_values_as_fallback(self):
  186. """从原始输入数据中获取最近的output_size条记录作为备用输出,避免输出空值"""
  187. if self.raw_input_data is None or self.raw_input_data.empty:
  188. return np.zeros((self.output_size, self.labels_num))
  189. df_copy = self.raw_input_data.copy()
  190. # 统一时间列格式,防止报错
  191. if 'datetime' in df_copy.columns:
  192. df_copy = df_copy.rename(columns={'datetime': 'index'})
  193. if 'index' not in df_copy.columns:
  194. df_copy['index'] = pd.date_range(end=datetime.now(), periods=len(df_copy), freq='min')
  195. df_copy['index'] = pd.to_datetime(df_copy['index'])
  196. # 按时间排序并取最近的output_size条
  197. recent_data = df_copy.sort_values('index').tail(self.output_size)
  198. # 若数据不足,用最后一条补充
  199. if len(recent_data) < self.output_size:
  200. last_row = recent_data.iloc[-1:] if not recent_data.empty else pd.DataFrame(
  201. {col: [0.0] for col in self.target_columns}, index=[0])
  202. while len(recent_data) < self.output_size:
  203. recent_data = pd.concat([recent_data, last_row], ignore_index=True)
  204. # 确保提取的兜底数据中没有空值 (NaN)
  205. recent_data[self.target_columns] = recent_data[self.target_columns].ffill().bfill().fillna(0.0)
  206. # 提取目标列值并返回
  207. try:
  208. fallback_values = recent_data[self.target_columns].values
  209. except KeyError:
  210. # 极度异常情况兜底(输入中缺少目标列)
  211. fallback_values = np.zeros((self.output_size, self.labels_num))
  212. return fallback_values
  213. def predict(self, df):
  214. """
  215. 返回: List[List[float]]
  216. 格式: [[t+1时刻的4个值], [t+2时刻的4个值], ..., [t+5时刻的4个值]]
  217. """
  218. # --- 保存原始输入数据用于可能的降级策略 ---
  219. self.raw_input_data = df.copy()
  220. # 1. 预处理 (返回的是归一化后的 numpy 数组)
  221. processed_data = self._preprocess(df)
  222. # 2. 取最后 seq_len 个时间步构建 Tensor
  223. input_seq = processed_data[-self.seq_len:]
  224. input_tensor = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0).to(self.device)
  225. # 3. 推理
  226. with torch.no_grad():
  227. output = self.model(input_tensor)
  228. # 4. 反归一化
  229. # 输出形状调整为 (5, 4) -> 5个步长, 4个变量
  230. preds = output.cpu().numpy().reshape(self.output_size, self.labels_num)
  231. # 获取最后4列的归一化参数 (目标变量)
  232. target_min = self.scaler.min_[-self.labels_num:]
  233. target_scale = self.scaler.scale_[-self.labels_num:]
  234. real_preds = (preds - target_min) / target_scale
  235. real_preds = np.abs(real_preds)
  236. # --- 空值/NaN 检测与兜底机制 ---
  237. # 如果模型因极端情况输出 NaN 或者 inf 无穷大,触发历史数据兜底
  238. if np.isnan(real_preds).any() or np.isinf(real_preds).any():
  239. real_preds = self.get_recent_values_as_fallback()
  240. # 5. 返回纯数值列表
  241. return real_preds.tolist()
  242. if __name__ == "__main__":
  243. # 测试代码
  244. try:
  245. # 初始化
  246. predictor = RealTimePredictor()
  247. # 生成模拟数据
  248. mock_data = pd.DataFrame()
  249. mock_data['index'] = pd.date_range(end=datetime.now(), periods=15, freq='min')
  250. for col in predictor.required_columns[1:]:
  251. mock_data[col] = np.random.rand(15) * 10
  252. # 人为制造一些空值进行测试
  253. # mock_data.loc[5:7, 'water_in'] = np.nan
  254. # mock_data.loc[12, predictor.target_columns[0]] = np.nan
  255. # 预测
  256. result = predictor.predict(mock_data)
  257. print("预测结果 (5x8 数组):")
  258. print(result)
  259. except Exception as e:
  260. print(f"Error: {e}")
  261. import traceback
  262. traceback.print_exc()