predict.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. import os
  2. import torch
  3. import pandas as pd
  4. import numpy as np
  5. import joblib
  6. from datetime import datetime, timedelta
  7. from torch.utils.data import DataLoader, TensorDataset
  8. from .gat_lstm import GAT_LSTM
  9. from scipy.signal import savgol_filter
  10. from sklearn.preprocessing import MinMaxScaler
  11. from .data_mysql import get_sensor_data
  12. def set_seed(seed):
  13. import random
  14. random.seed(seed)
  15. os.environ['PYTHONHASHSEED'] = str(seed)
  16. np.random.seed(seed)
  17. torch.manual_seed(seed)
  18. torch.cuda.manual_seed(seed)
  19. torch.cuda.manual_seed_all(seed)
  20. torch.backends.cudnn.deterministic = True
  21. torch.backends.cudnn.benchmark = False
  22. class Predictor:
  23. """预测器类,封装了数据处理、模型加载、预测和结果保存的完整流程"""
  24. def __init__(self):
  25. # 模型和数据相关参数
  26. self.seq_len = 4320 # 输入序列长度
  27. self.output_size = 2160 # 预测输出长度
  28. self.labels_num = 8 # 预测目标特征数量
  29. self.feature_num = 16 # 输入特征总数量
  30. self.step_size = 2160 # 滑动窗口步长
  31. self.dropout = 0 # 模型dropout参数
  32. self.lr = 0.01 # 学习率
  33. self.hidden_size = 64 # LSTM隐藏层大小
  34. self.batch_size = 128 # 批处理大小
  35. self.num_layers = 1 # LSTM层数
  36. self.resolution = 60 # 数据时间分辨率(单位:秒)
  37. self.test_start_date = '2025-09-10' # 预测起始日期(动态更新)
  38. self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  39. self.model_path = 'model.pth' # 模型权重路径(可外部修改)
  40. self.output_csv_path = 'predictions.csv' # 结果保存路径(可外部修改)
  41. self.random_seed = 1314 # 随机种子
  42. # 预测结果平滑参数
  43. self.smooth_window = 30
  44. self.ema_alpha = 0.1
  45. self.use_savitzky = True
  46. self.sg_window = 25
  47. self.sg_polyorder = 2
  48. # 初始化设置
  49. set_seed(self.random_seed)
  50. scaler_path = os.path.join(os.path.dirname(__file__), 'scaler.pkl')
  51. self.scaler = joblib.load(scaler_path) # 加载标准化器(确保文件存在)
  52. self.model = None
  53. self.edge_index = None
  54. self.test_loader = None
  55. def reorder_columns(self, df):
  56. """调整DataFrame列顺序以匹配模型训练时的特征顺序"""
  57. desired_order = [
  58. 'index', # 时间索引列
  59. 'C.M.RO1_FT_JS@out','C.M.RO2_FT_JS@out','C.M.RO3_FT_JS@out','C.M.RO4_FT_JS@out',
  60. 'C.M.RO_TT_ZJS@out','C.M.RO_Cond_ZJS@out',
  61. 'C.M.RO1_DB@DPT_1','C.M.RO1_DB@DPT_2',
  62. 'C.M.RO2_DB@DPT_1','C.M.RO2_DB@DPT_2',
  63. 'C.M.RO3_DB@DPT_1','C.M.RO3_DB@DPT_2',
  64. 'C.M.RO4_DB@DPT_1','C.M.RO4_DB@DPT_2',
  65. ]
  66. return df.loc[:, desired_order]
  67. def process_date(self, data):
  68. """处理日期特征,添加年周期正弦/余弦编码"""
  69. if 'index' in data.columns:
  70. data = data.rename(columns={'index': 'date'})
  71. data['date'] = pd.to_datetime(data['date'])
  72. data['day_of_year'] = data['date'].dt.dayofyear
  73. data['day_year_sin'] = np.sin(2 * np.pi * data['day_of_year'] / 366)
  74. data['day_year_cos'] = np.cos(2 * np.pi * data['day_of_year'] / 366)
  75. data.drop(columns=['day_of_year'], inplace=True)
  76. time_features = ['day_year_sin', 'day_year_cos']
  77. other_columns = [col for col in data.columns if col not in ['date'] + time_features]
  78. return data[['date'] + time_features + other_columns]
  79. def scaler_data(self, data):
  80. """使用预训练标准化器处理数据(保留date列)"""
  81. date_col = data[['date']]
  82. data_to_scale = data.drop(columns=['date'])
  83. scaled = self.scaler.transform(data_to_scale)
  84. scaled_df = pd.DataFrame(scaled, columns=data_to_scale.columns)
  85. return pd.concat([date_col.reset_index(drop=True), scaled_df], axis=1)
  86. def create_test_loader(self, df):
  87. """创建测试数据加载器,生成模型输入序列"""
  88. if 'date' in df.columns:
  89. test_data = df.drop(columns=['date']).values
  90. else:
  91. test_data = df.values
  92. X = test_data.reshape(-1, self.seq_len, self.feature_num)
  93. X = torch.tensor(X, dtype=torch.float32).to(self.device)
  94. tensor_dataset = TensorDataset(X)
  95. return DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=False)
  96. def load_data(self, df):
  97. """数据加载与预处理统一接口"""
  98. df = self.reorder_columns(df)
  99. # df = df.iloc[::self.resolution, :].reset_index(drop=True)
  100. df = self.process_date(df)
  101. df = self.scaler_data(df)
  102. self.test_loader = self.create_test_loader(df)
  103. def load_model(self):
  104. """加载预训练模型并设置为评估模式"""
  105. self.model = GAT_LSTM(self).to(self.device)
  106. model_path = os.path.join(os.path.dirname(__file__), 'model.pth' )
  107. self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
  108. self.model.eval()
  109. def moving_average_smooth(self, data):
  110. """滑动平均平滑处理"""
  111. smoothed = []
  112. for i in range(data.shape[1]):
  113. feature = data[:, i]
  114. padded = np.pad(feature, (self.smooth_window//2, self.smooth_window//2), mode='edge')
  115. window = np.ones(self.smooth_window) / self.smooth_window
  116. smoothed_feature = np.convolve(padded, window, mode='valid')
  117. smoothed.append(smoothed_feature.reshape(-1, 1))
  118. return np.concatenate(smoothed, axis=1)
  119. def exponential_smooth(self, data):
  120. """指数移动平均平滑处理"""
  121. smoothed = []
  122. for i in range(data.shape[1]):
  123. feature = data[:, i]
  124. smoothed_feature = np.zeros_like(feature)
  125. smoothed_feature[0] = feature[0]
  126. for t in range(1, len(feature)):
  127. smoothed_feature[t] = self.ema_alpha * feature[t] + (1 - self.ema_alpha) * smoothed_feature[t-1]
  128. smoothed.append(smoothed_feature.reshape(-1, 1))
  129. return np.concatenate(smoothed, axis=1)
  130. def savitzky_golay_smooth(self, data):
  131. """Savitzky-Golay滤波平滑处理"""
  132. smoothed = []
  133. for i in range(data.shape[1]):
  134. feature = data[:, i]
  135. window = min(self.sg_window, len(feature) if len(feature) % 2 == 1 else len(feature)-1)
  136. if window < 3:
  137. smoothed.append(feature.reshape(-1, 1))
  138. continue
  139. smoothed_feature = savgol_filter(feature, window_length=window, polyorder=self.sg_polyorder)
  140. smoothed.append(smoothed_feature.reshape(-1, 1))
  141. return np.concatenate(smoothed, axis=1)
  142. def smooth_predictions(self, predictions):
  143. """组合多步平滑策略处理预测结果"""
  144. smoothed = self.moving_average_smooth(predictions)
  145. smoothed = self.exponential_smooth(smoothed)
  146. if self.use_savitzky and len(predictions) >= self.sg_window:
  147. smoothed = self.savitzky_golay_smooth(smoothed)
  148. return smoothed
  149. def predict(self, df):
  150. # 获取当前或者传入时间的历史数据(180天)
  151. # df = get_sensor_data(start_date=start_date)
  152. """核心预测接口:输入原始数据,返回处理后的预测结果"""
  153. self.test_start_date = (pd.to_datetime(df['index']).max() + timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S")
  154. self.load_data(df)
  155. self.load_model()
  156. all_predictions = []
  157. with torch.no_grad():
  158. for batch in self.test_loader:
  159. inputs = batch[0].to(self.device)
  160. outputs = self.model(inputs)
  161. all_predictions.append(outputs.cpu().numpy())
  162. predictions = np.concatenate(all_predictions, axis=0).reshape(-1, self.labels_num)
  163. # 反标准化处理
  164. inverse_scaler = MinMaxScaler()
  165. inverse_scaler.min_ = self.scaler.min_[-self.labels_num:]
  166. inverse_scaler.scale_ = self.scaler.scale_[-self.labels_num:]
  167. predictions = inverse_scaler.inverse_transform(predictions)
  168. predictions = np.clip(predictions, 0, None)
  169. # 平滑处理
  170. predictions = self.smooth_predictions(predictions)
  171. self.test_start_date = (pd.to_datetime(df['index']).max() + timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")
  172. # 直接返回结果
  173. start_time = datetime.strptime(self.test_start_date, "%Y-%m-%d %H:%M:%S")
  174. time_interval = pd.Timedelta(hours=(self.resolution / 60))
  175. timestamps = [start_time + i * time_interval for i in range(len(predictions))]
  176. base_columns = [
  177. 'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
  178. 'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2',
  179. ]
  180. pred_columns = [f'{col}_pred' for col in base_columns]
  181. df_result = pd.DataFrame(predictions, columns=pred_columns)
  182. df_result.insert(0, 'date', timestamps)
  183. return df_result
  184. def save_predictions(self, predictions):
  185. """保存预测结果到CSV文件"""
  186. start_time = datetime.strptime(self.test_start_date, "%Y-%m-%d %H:%M:%S")
  187. time_interval = pd.Timedelta(hours=(self.resolution / 60))
  188. timestamps = [start_time + i * time_interval for i in range(len(predictions))]
  189. base_columns = [
  190. 'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
  191. 'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2',
  192. ]
  193. pred_columns = [f'{col}_pred' for col in base_columns]
  194. df_result = pd.DataFrame(predictions, columns=pred_columns)
  195. df_result.insert(0, 'date', timestamps)
  196. df_result.to_csv(self.output_csv_path, index=False)
  197. print(f"预测结果保存至:{self.output_csv_path}")
  198. if __name__ == '__main__':
  199. # 获取处理后的数据
  200. # 创建预测器实例并进行预测
  201. predictor = Predictor()
  202. predictions = predictor.predict(start_date =None)
  203. # 保存预测结果到 CSV 文件
  204. predictor.save_predictions(predictions)