# 20min_predict.py
  1. """
  2. 20分钟TMP预测模型
  3. 版本:1.0
  4. 最后更新:2025-10-28
  5. """
  6. import os
  7. import sys
  8. import torch
  9. import pandas as pd
  10. import numpy as np
  11. import joblib
  12. import pywt
  13. from datetime import datetime, timedelta
  14. from torch.utils.data import DataLoader, TensorDataset
  15. from gat_lstm import GAT_LSTM # 导入自定义的GAT-LSTM模型
  16. from tqdm import tqdm
  17. # 添加项目根目录到系统路径,以便导入common模块
  18. project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
  19. if project_root not in sys.path:
  20. sys.path.insert(0, project_root)
  21. from common.utils.logger import setup_logger, log_execution_time
  22. from common.utils.config import Config
  23. def set_seed(seed):
  24. """
  25. 设置全局随机种子,保证实验可重复性
  26. Args:
  27. seed: 随机种子值
  28. Note:
  29. - 设置Python、NumPy、PyTorch的随机种子
  30. - 确保CUDA操作的确定性
  31. - 关闭CUDA的性能优化(以确保可重复性)
  32. """
  33. import random
  34. random.seed(seed) # Python随机数生成器
  35. os.environ['PYTHONHASHSEED'] = str(seed) # Python哈希种子
  36. np.random.seed(seed) # NumPy随机数生成器
  37. torch.manual_seed(seed) # PyTorch CPU随机数生成器
  38. torch.cuda.manual_seed(seed) # 当前GPU随机数生成器
  39. torch.cuda.manual_seed_all(seed) # 所有GPU随机数生成器
  40. torch.backends.cudnn.deterministic = True # 确保CUDA操作确定性
  41. torch.backends.cudnn.benchmark = False # 关闭CUDA性能优化
  42. class Predictor:
  43. """
  44. TMP预测器类
  45. 功能:
  46. - 加载并预处理输入数据
  47. - 加载训练好的GAT-LSTM模型
  48. - 执行预测并保存结果
  49. 使用示例:
  50. predictor = Predictor()
  51. predictions = predictor.predict(df)
  52. predictor.save_predictions(predictions)
  53. """
  54. def __init__(self, config_path='config.yaml'):
  55. """
  56. 初始化预测器
  57. Args:
  58. config_path: 配置文件路径,默认为当前目录下的config.yaml
  59. Raises:
  60. FileNotFoundError: 配置文件或模型文件不存在
  61. Note:
  62. - 从配置文件加载所有参数
  63. - 自动检测并使用GPU(如果可用)
  64. - 加载训练时保存的数据归一化器
  65. """
  66. # 加载配置文件
  67. config_file = os.path.join(os.path.dirname(__file__), config_path)
  68. self.config = Config(config_file)
  69. # 设置日志
  70. log_dir = os.path.join(os.path.dirname(__file__),
  71. os.path.dirname(self.config.get('logging.log_file', 'logs/20min_predict.log')))
  72. os.makedirs(log_dir, exist_ok=True)
  73. log_file = os.path.join(os.path.dirname(__file__),
  74. self.config.get('logging.log_file', 'logs/20min_predict.log'))
  75. self.logger = setup_logger(
  76. name='20min_predict',
  77. level=self.config.get('logging.level', 'INFO'),
  78. log_file=log_file,
  79. format_type=self.config.get('logging.format', 'colored'),
  80. max_bytes=self.config.get('logging.max_bytes', 10*1024*1024),
  81. backup_count=self.config.get('logging.backup_count', 5)
  82. )
  83. self.logger.info("=" * 80)
  84. self.logger.info("初始化20分钟TMP预测器")
  85. self.logger.info("=" * 80)
  86. # 模型参数(从配置文件加载)
  87. self.seq_len = self.config.get('model.seq_len', 10)
  88. self.output_size = self.config.get('model.output_size', 5)
  89. self.labels_num = self.config.get('model.labels_num', 16)
  90. self.feature_num = self.config.get('model.feature_num', 79)
  91. self.step_size = self.config.get('model.step_size', 5)
  92. self.dropout = self.config.get('model.dropout', 0)
  93. self.lr = self.config.get('model.lr', 0.01)
  94. self.num_heads = self.config.get('model.num_heads', 8)
  95. self.hidden_size = self.config.get('model.hidden_size', 64)
  96. self.batch_size = self.config.get('model.batch_size', 512)
  97. self.num_layers = self.config.get('model.num_layers', 1)
  98. self.random_seed = self.config.get('model.random_seed', 1314)
  99. # 数据处理参数
  100. self.resolution = self.config.get('data.resolution', 60)
  101. self.test_start_date = self.config.get('data.test_start_date', '2025-07-01')
  102. self.wavelet = self.config.get('data.wavelet.type', 'db4')
  103. self.level = self.config.get('data.wavelet.level', 3)
  104. self.level_after = self.config.get('data.wavelet.level_after', 4)
  105. self.mode = self.config.get('data.wavelet.mode', 'soft')
  106. # 阈值参数
  107. self.uf_threshold = self.config.get('data.threshold.uf', 0.001)
  108. self.ro_threshold = self.config.get('data.threshold.ro', 0.01)
  109. self.flow_threshold = self.config.get('data.threshold.flow', 1.0)
  110. # 文件路径(相对于当前脚本目录)
  111. script_dir = os.path.dirname(__file__)
  112. self.model_path = os.path.join(script_dir, self.config.get('paths.model_path', '20min_model.pth'))
  113. self.scaler_path = os.path.join(script_dir, self.config.get('paths.scaler_path', '20min_scaler.pkl'))
  114. self.edge_index_path = os.path.join(script_dir, self.config.get('paths.edge_index_path', 'edge_index.pt'))
  115. self.output_csv_path = os.path.join(script_dir, self.config.get('paths.output_csv_path', '20min_predictions.csv'))
  116. # 后处理参数
  117. self.remove_outliers_flag = self.config.get('postprocess.remove_outliers', False)
  118. self.smooth_flag = self.config.get('postprocess.smooth', False)
  119. # 设备配置
  120. use_cuda = self.config.get('device.use_cuda', True)
  121. cuda_device = self.config.get('device.cuda_device', 0)
  122. if use_cuda and torch.cuda.is_available():
  123. self.device = torch.device(f"cuda:{cuda_device}")
  124. self.logger.info(f"使用GPU设备: {self.device} ({torch.cuda.get_device_name(cuda_device)})")
  125. else:
  126. self.device = torch.device("cpu")
  127. self.logger.info("使用CPU设备")
  128. # 设置随机种子
  129. self.logger.info(f"设置随机种子: {self.random_seed}")
  130. set_seed(self.random_seed)
  131. # 加载数据归一化器
  132. if not os.path.exists(self.scaler_path):
  133. self.logger.error(f"归一化器文件不存在: {self.scaler_path}")
  134. raise FileNotFoundError(f"归一化器文件不存在: {self.scaler_path}")
  135. self.logger.info(f"加载数据归一化器: {self.scaler_path}")
  136. self.scaler = joblib.load(self.scaler_path)
  137. # 初始化模型和数据加载器(后续加载)
  138. self.model = None
  139. self.edge_index = None
  140. self.test_loader = None
  141. self.logger.info("预测器初始化完成")
  142. self.logger.info(f"模型参数: seq_len={self.seq_len}, output_size={self.output_size}, "
  143. f"labels_num={self.labels_num}, feature_num={self.feature_num}")
  144. self.logger.info(f"数据参数: resolution={self.resolution}, batch_size={self.batch_size}")
  145. def reorder_columns(self, df):
  146. """
  147. 调整数据列顺序,确保与训练时的特征顺序一致
  148. Args:
  149. df: 输入的DataFrame
  150. Returns:
  151. DataFrame: 列顺序调整后的DataFrame
  152. Note:
  153. - 避免因列顺序不一致导致模型输入特征错位
  154. - 必须包含所有必需的特征列
  155. """
  156. self.logger.debug("开始重排数据列顺序")
  157. desired_order = [
  158. 'index',
  159. 'C.M.FT_ZGJJY1@out','C.M.RO1_FT_JS@out','C.M.RO2_FT_JS@out','C.M.RO3_FT_JS@out',
  160. 'C.M.RO4_FT_JS@out','C.M.UF1_FT_JS@out','C.M.UF2_FT_JS@out','C.M.UF3_FT_JS@out',
  161. 'C.M.UF4_FT_JS@out','C.M.UF_FT_ZCS@out','C.M.FT_ZGJJY2@out','C.M.FT_ZGJJY3@out',
  162. 'C.M.FT_ZGJJY4@out','C.M.RO1_PT_JS@out','C.M.RO2_PT_JS@out','C.M.RO3_PT_JS@out',
  163. 'C.M.UF1_PT_JS@out','C.M.UF2_PT_JS@out','C.M.UF3_PT_JS@out','C.M.UF4_PT_JS@out',
  164. 'C.M.LT_JSC@out','C.M.RO1_PT_CS@out','C.M.RO1_PT_DJ2@out','C.M.RO2_PT_CS@out',
  165. 'C.M.RO2_PT_DJ2@out','C.M.RO3_PT_CS@out','C.M.RO3_PT_DJ2@out','C.M.RO4_PT_CS@out',
  166. 'C.M.RO4_PT_DJ2@out','C.M.RO4_PT_JS@out','C.M.LT_HCl@out','C.M.LT_NaClO@out',
  167. 'C.M.LT_PAC@out','C.M.LT_QSC@out','C.M.RO_Cond_ZCS@out','C.M.RO_TT_ZJS@out',
  168. 'C.M.UF1_JSF_kd@out','C.M.UF2_JSF_kd@out','C.M.UF_GSB4_fre@out','C.M.UF_ORP_ZCS@out',
  169. 'C.M.JYB2_ZGJ1_fre@out','C.M.JYB2_ZGJ2_fre@out','C.M.JYB2_ZGJ3_fre@out','C.M.JYB2_ZGJ4_fre@out',
  170. 'C.M.RO1_GYB_fre@out','C.M.RO2_GYB_fre@out','C.M.RO3_GYB_fre@out','C.M.RO4_GYB_fre@out',
  171. 'C.M.UF3_JSF_kd@out','C.M.UF4_JSF_kd@out','C.M.UF_FXB2_fre@out','C.M.RO1_DJB_fre@out',
  172. 'C.M.RO1_GYBF_kd@out','C.M.RO2_DJB_fre@out','C.M.RO2_GYBF_kd@out','C.M.RO3_DJB_fre@out',
  173. 'C.M.RO3_GYBF_kd@out','C.M.RO4_DJB_fre@out','C.M.RO4_GYBF_kd@out',
  174. 'C.M.UF1_DB@press_PV','C.M.UF2_DB@press_PV','C.M.UF3_DB@press_PV','C.M.UF4_DB@press_PV',
  175. 'UF1Per','UF2Per','UF3Per','UF4Per',
  176. 'C.M.RO1_DB@DPT_1','C.M.RO2_DB@DPT_1','C.M.RO3_DB@DPT_1','C.M.RO4_DB@DPT_1',
  177. 'C.M.RO1_DB@DPT_2','C.M.RO2_DB@DPT_2','C.M.RO3_DB@DPT_2','C.M.RO4_DB@DPT_2',
  178. ]
  179. self.logger.debug(f"原始列: {list(df.columns)}")
  180. self.logger.debug(f"目标列顺序: {desired_order}")
  181. return df.loc[:, desired_order]
  182. def process_date(self, data):
  183. """
  184. 处理日期列,生成周期性时间特征
  185. Args:
  186. data: 输入DataFrame,必须包含'index'或'date'列
  187. Returns:
  188. DataFrame: 包含时间特征的DataFrame
  189. Note:
  190. - 生成分钟级正弦/余弦特征(捕捉每日周期性模式)
  191. - 生成年中日正弦/余弦特征(捕捉年度周期性模式)
  192. - 使用三角函数编码确保时间连续性(避免边界突变)
  193. """
  194. self.logger.debug("开始处理日期特征")
  195. if 'index' in data.columns:
  196. data = data.rename(columns={'index': 'date'})
  197. data['date'] = pd.to_datetime(data['date'])
  198. data['minute_of_day'] = data['date'].dt.hour * 60 + data['date'].dt.minute
  199. data['day_of_year'] = data['date'].dt.dayofyear
  200. # 周期性编码(将时间转换为正弦/余弦值,确保周期性连续)
  201. data['minute_sin'] = np.sin(2 * np.pi * data['minute_of_day'] / 1440) # 分钟正弦特征
  202. data['minute_cos'] = np.cos(2 * np.pi * data['minute_of_day'] / 1440) # 分钟余弦特征
  203. data['day_year_sin'] = np.sin(2 * np.pi * data['day_of_year'] / 366) # 年中日正弦特征
  204. data['day_year_cos'] = np.cos(2 * np.pi * data['day_of_year'] / 366) # 年中日余弦特征
  205. # 移除原始时间列(仅保留编码后的特征)
  206. data.drop(columns=['minute_of_day', 'day_of_year'], inplace=True)
  207. # 调整列顺序:日期 + 时间特征 + 其他特征
  208. time_features = ['minute_sin', 'minute_cos', 'day_year_sin', 'day_year_cos']
  209. other_columns = [col for col in data.columns if col not in ['date'] + time_features]
  210. self.logger.debug(f"生成时间特征: {time_features}")
  211. return data[['date'] + time_features + other_columns]
  212. def scaler_data(self, data):
  213. """
  214. 对数据进行归一化处理
  215. Args:
  216. data: 输入DataFrame
  217. Returns:
  218. DataFrame: 归一化后的DataFrame
  219. Note:
  220. - 使用训练时保存的scaler进行归一化
  221. - 保持与训练数据的归一化方式一致(MinMax 0-1缩放)
  222. - 日期列不参与归一化
  223. """
  224. self.logger.debug("开始数据归一化")
  225. date_col = data[['date']]
  226. data_to_scale = data.drop(columns=['date'])
  227. scaled = self.scaler.transform(data_to_scale)
  228. scaled_df = pd.DataFrame(scaled, columns=data_to_scale.columns)
  229. # 拼接日期列和归一化后的特征列
  230. result = pd.concat([date_col.reset_index(drop=True), scaled_df], axis=1)
  231. self.logger.debug(f"归一化完成,数据形状: {result.shape}")
  232. return result
  233. def remove_outliers(self, predictions):
  234. """
  235. 使用四分位法处理预测结果中的异常值
  236. Args:
  237. predictions: numpy数组,形状为[时间步, 标签数]
  238. Returns:
  239. numpy数组: 处理异常值后的预测结果
  240. Note:
  241. - 异常值定义:小于Q1-1.5*IQR或大于Q3+1.5*IQR的值
  242. - 异常值替换为正常值的平均值(避免极端值影响结果)
  243. - 按列(每个指标)独立处理
  244. """
  245. self.logger.info("开始移除异常值(四分位法)")
  246. cleaned = predictions.copy()
  247. # 遍历每个特征列(16个标签)
  248. for col in range(cleaned.shape[1]):
  249. values = cleaned[:, col]
  250. # 计算四分位数
  251. q1 = np.percentile(values, 25)
  252. q3 = np.percentile(values, 75)
  253. iqr = q3 - q1
  254. # 异常值边界
  255. lower_bound = q1 - 1.5 * iqr
  256. upper_bound = q3 + 1.5 * iqr
  257. # 筛选正常值
  258. normal_values = values[(values >= lower_bound) & (values <= upper_bound)]
  259. # 用正常值的平均值替换异常值
  260. if len(normal_values) > 0:
  261. mean_normal = np.mean(normal_values)
  262. outlier_count = np.sum((values < lower_bound) | (values > upper_bound))
  263. if outlier_count > 0:
  264. self.logger.debug(f"列{col}: 检测到{outlier_count}个异常值,替换为均值{mean_normal:.4f}")
  265. cleaned[(values < lower_bound) | (values > upper_bound), col] = mean_normal
  266. self.logger.info("异常值处理完成")
  267. return cleaned
  268. def smooth_predictions(self, predictions):
  269. """
  270. 对预测结果进行加权平滑处理
  271. Args:
  272. predictions: numpy数组,形状为[时间步, 标签数]
  273. Returns:
  274. numpy数组: 平滑后的预测结果
  275. Note:
  276. - 采用滑动窗口加权平均减少预测波动
  277. - 中间值权重为2,前后邻居权重为1
  278. - 边缘值特殊处理(避免过度平滑)
  279. """
  280. self.logger.info("开始平滑预测结果")
  281. smoothed = predictions.copy()
  282. n_timesteps = predictions.shape[0]
  283. if n_timesteps <= 1:
  284. return smoothed
  285. # 遍历每个特征列
  286. for col in range(predictions.shape[1]):
  287. values = predictions[:, col]
  288. # 第一个值:加权前两个值(避免边缘过度平滑)
  289. smoothed[0, col] = (2 * values[0] + values[1]) / 3
  290. # 中间值:加权前后邻居(核心平滑)
  291. for i in range(1, n_timesteps - 1):
  292. smoothed[i, col] = (values[i-1] + 2 * values[i] + values[i+1]) / 4
  293. # 最后一个值:加权最后两个值(避免边缘过度平滑)
  294. smoothed[-1, col] = (values[-2] + 2 * values[-1]) / 3
  295. self.logger.info("预测结果平滑完成")
  296. return smoothed
  297. def create_test_loader(self, df):
  298. """
  299. 构建测试数据加载器
  300. Args:
  301. df: 预处理后的DataFrame
  302. Returns:
  303. DataLoader: PyTorch数据加载器
  304. Note:
  305. - 将原始时间序列数据转换为模型输入格式
  306. - 构建滑动窗口序列:[样本数, 序列长度, 特征数]
  307. - 确保有足够的历史数据构建输入序列
  308. """
  309. self.logger.info("创建测试数据加载器")
  310. df['date'] = pd.to_datetime(df['date'])
  311. # 计算时间间隔(根据分辨率,单位:分钟)
  312. time_interval = pd.Timedelta(minutes=(4 * self.resolution / 60))
  313. # 计算窗口时间跨度(确保能覆盖输入序列长度+预测步长)
  314. window_time_span = time_interval * (self.seq_len + 20)
  315. # 调整测试集起始时间(确保有足够的历史数据构建输入序列)
  316. adjusted_test_start = pd.to_datetime(self.test_start_date) - window_time_span
  317. # 筛选所需的历史数据
  318. test_df = df[df['date'] >= adjusted_test_start].reset_index(drop=True)
  319. test_df = test_df.drop(columns=['date'])
  320. # 构建监督学习数据集(输入序列+目标序列的占位)
  321. feature_columns = test_df.columns.tolist()
  322. cols = []
  323. # 构建输入序列(历史seq_len个时间步的特征)
  324. for col in feature_columns:
  325. for i in range(self.seq_len - 1, -1, -1):
  326. cols.append(test_df[[col]].shift(i)) # 滞后i步的特征(t-0到t-(seq_len-1))
  327. # 构建目标序列占位(未来output_size个时间步的标签,预测时不使用真实值)
  328. for i in range(1, self.output_size + 1):
  329. for col in feature_columns[-self.labels_num:]:
  330. cols.append(test_df[[col]].shift(-i)) # 超前i步的标签(t+1到t+output_size)
  331. # 合并列并按步长采样,最后取最后一行作为预测输入(最新的历史数据)
  332. dataset = pd.concat(cols, axis=1).iloc[::self.step_size]
  333. dataset = dataset.iloc[[-1]]
  334. # 提取输入特征(前n_features_total列)
  335. n_features_total = self.feature_num * self.seq_len
  336. supervised_data = dataset.iloc[:, :n_features_total]
  337. # 转换为模型输入格式:[样本数, 序列长度, 特征数]
  338. X = supervised_data.values.reshape(-1, self.seq_len, self.feature_num)
  339. X = torch.tensor(X, dtype=torch.float32).to(self.device)
  340. tensor_dataset = TensorDataset(X)
  341. loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=False)
  342. self.logger.info(f"测试数据加载器创建完成,输入形状: {X.shape}")
  343. return loader
  344. @log_execution_time
  345. def load_data(self, df):
  346. """
  347. 数据加载和预处理主流程
  348. Args:
  349. df: 原始输入DataFrame
  350. Note:
  351. - 重排列特征列顺序
  352. - 下采样(根据resolution参数)
  353. - 日期特征工程
  354. - 数据归一化
  355. - 创建测试数据加载器
  356. - 加载图结构边索引
  357. """
  358. self.logger.info("开始加载和预处理数据")
  359. self.logger.info(f"原始数据形状: {df.shape}")
  360. df = self.reorder_columns(df)
  361. self.logger.info(f"下采样率: {self.resolution}")
  362. df = df.iloc[::self.resolution, :].reset_index(drop=True)
  363. self.logger.info(f"下采样后数据形状: {df.shape}")
  364. df = self.process_date(df)
  365. df = self.scaler_data(df)
  366. self.test_loader = self.create_test_loader(df)
  367. if not os.path.exists(self.edge_index_path):
  368. self.logger.error(f"图边索引文件不存在: {self.edge_index_path}")
  369. raise FileNotFoundError(f"图边索引文件不存在: {self.edge_index_path}")
  370. self.logger.info(f"加载图边索引: {self.edge_index_path}")
  371. self.edge_index = torch.load(self.edge_index_path, map_location=self.device, weights_only=True)
  372. self.logger.info("数据加载和预处理完成")
  373. @log_execution_time
  374. def load_model(self):
  375. """
  376. 加载模型结构和预训练权重
  377. Raises:
  378. FileNotFoundError: 模型文件不存在
  379. Note:
  380. - 实例化GAT-LSTM模型
  381. - 加载预训练权重
  382. - 设置为评估模式(关闭dropout和batch normalization)
  383. - 设置图结构边索引
  384. """
  385. if not os.path.exists(self.model_path):
  386. self.logger.error(f"模型文件不存在: {self.model_path}")
  387. raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
  388. self.logger.info("开始加载模型")
  389. self.logger.info(f"模型路径: {self.model_path}")
  390. self.model = GAT_LSTM(self).to(self.device)
  391. if self.edge_index is not None:
  392. self.logger.debug(f"设置图边索引,形状: {self.edge_index.shape}")
  393. self.model.set_edge_index(self.edge_index.to(self.device))
  394. self.model.load_state_dict(torch.load(self.model_path, map_location=self.device, weights_only=True))
  395. self.model.eval()
  396. # 统计模型参数量
  397. total_params = sum(p.numel() for p in self.model.parameters())
  398. trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
  399. self.logger.info(f"模型加载完成 - 总参数量: {total_params:,}, 可训练参数量: {trainable_params:,}")
  400. @log_execution_time
  401. def predict(self, df):
  402. """
  403. 执行预测主流程
  404. Args:
  405. df: 原始输入DataFrame,必须包含'index'列(时间戳)
  406. Returns:
  407. numpy数组: 反归一化后的预测结果,形状为[output_size, labels_num]
  408. Note:
  409. - 自动更新测试起始时间为输入数据最新时间+4分钟
  410. - 执行数据预处理
  411. - 加载模型
  412. - 执行批量预测
  413. - 反归一化预测结果
  414. - 可选的异常值处理和平滑
  415. """
  416. self.logger.info("=" * 80)
  417. self.logger.info("开始预测流程")
  418. self.logger.info("=" * 80)
  419. # 更新测试起始时间为输入数据最新时间+4分钟(预测起始点)
  420. latest_time = pd.to_datetime(df['index']).max()
  421. self.test_start_date = (latest_time + timedelta(minutes=4)).strftime("%Y-%m-%d %H:%M:%S")
  422. self.logger.info(f"输入数据最新时间: {latest_time}")
  423. self.logger.info(f"预测起始时间: {self.test_start_date}")
  424. # 加载和预处理数据
  425. self.load_data(df)
  426. # 加载模型
  427. self.load_model()
  428. # 执行预测
  429. self.logger.info("开始模型推理")
  430. all_predictions = []
  431. with torch.no_grad():
  432. for batch_idx, batch in enumerate(self.test_loader):
  433. inputs = batch[0].to(self.device)
  434. outputs = self.model(inputs)
  435. all_predictions.append(outputs.cpu().numpy())
  436. self.logger.debug(f"批次 {batch_idx + 1} 推理完成,输入形状: {inputs.shape}, 输出形状: {outputs.shape}")
  437. # 拼接所有批次的预测结果,并重塑为[时间步, 标签数]
  438. predictions = np.concatenate(all_predictions, axis=0).reshape(-1, self.labels_num)
  439. self.logger.info(f"模型推理完成,预测结果形状: {predictions.shape}")
  440. # 反归一化(仅对标签列,使用训练时的scaler参数)
  441. self.logger.info("开始反归一化预测结果")
  442. from sklearn.preprocessing import MinMaxScaler
  443. inverse_scaler = MinMaxScaler()
  444. inverse_scaler.min_ = self.scaler.min_[-self.labels_num:]
  445. inverse_scaler.scale_ = self.scaler.scale_[-self.labels_num:]
  446. predictions = inverse_scaler.inverse_transform(predictions)
  447. self.logger.info("反归一化完成")
  448. # 可选:异常值处理和平滑(根据配置文件决定是否启用)
  449. if self.remove_outliers_flag:
  450. predictions = self.remove_outliers(predictions)
  451. if self.smooth_flag:
  452. predictions = self.smooth_predictions(predictions)
  453. self.logger.info(f"预测流程完成,最终预测结果形状: {predictions.shape}")
  454. self.logger.info(f"预测值范围: min={predictions.min():.4f}, max={predictions.max():.4f}, mean={predictions.mean():.4f}")
  455. return predictions
  456. def save_predictions(self, predictions):
  457. """
  458. 保存预测结果为CSV文件
  459. Args:
  460. predictions: 反归一化后的预测结果(numpy数组)
  461. Note:
  462. - 生成时间戳序列
  463. - 添加列名
  464. - 保存为CSV格式
  465. """
  466. self.logger.info("开始保存预测结果")
  467. start_time = datetime.strptime(self.test_start_date, "%Y-%m-%d %H:%M:%S")
  468. time_interval = timedelta(minutes=(4 * self.resolution / 60))
  469. timestamps = [start_time + i * time_interval for i in range(len(predictions))]
  470. # 定义16个预测目标的原始列名
  471. base_columns = [
  472. 'C.M.UF1_DB@press_PV', 'C.M.UF2_DB@press_PV', 'C.M.UF3_DB@press_PV', 'C.M.UF4_DB@press_PV',
  473. 'UF1Per','UF2Per','UF3Per','UF4Per',
  474. 'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
  475. 'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2',
  476. ]
  477. column_names = [
  478. 'C.M.UF1_DB@press_PV', 'C.M.UF2_DB@press_PV', 'C.M.UF3_DB@press_PV', 'C.M.UF4_DB@press_PV',
  479. 'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
  480. 'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2',
  481. 'RO1_CSFlow', 'RO2_CSFlow', 'RO3_CSFlow', 'RO4_CSFlow']
  482. pred_columns = [f'{col}_Predicted' for col in base_columns]
  483. df_result = pd.DataFrame(predictions, columns=pred_columns)
  484. df_result.insert(0, 'date', timestamps)
  485. df_result.to_csv(self.output_csv_path, index=False)
  486. self.logger.info(f"预测结果已保存至: {self.output_csv_path}")
  487. self.logger.info(f"预测时间范围: {timestamps[0]} 至 {timestamps[-1]}")
  488. self.logger.info(f"预测记录数: {len(predictions)}")
if __name__ == '__main__':
    """
    Entry point: run the 20-minute TMP prediction.

    Usage:
        1. Prepare the input data (JSON format).
        2. Run this script.
        3. Inspect the output (saved to 20min_predictions.csv).

    Input format:
        - A JSON file containing the historical time series.
        - Must include an 'index' column (timestamps) plus every required
          feature column (see Predictor.reorder_columns).
    """
    import json
    import os
    import pandas as pd
    from datetime import timedelta
    try:
        # Initialize the predictor (loads the config file automatically).
        predictor = Predictor()
        # Read the JSON input file.
        # NOTE(review): hard-coded local path — adjust to your environment.
        json_file_path = '/Users/wmy/Downloads/pp.json'
        if not os.path.exists(json_file_path):
            predictor.logger.error(f"输入文件不存在: {json_file_path}")
            raise FileNotFoundError(f"未找到文件: {json_file_path}")
        predictor.logger.info(f"读取输入文件: {json_file_path}")
        # Parse the JSON into a DataFrame.
        with open(json_file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        df = pd.DataFrame(json_data)
        predictor.logger.info(f"成功读取输入数据,数据形状: {df.shape}")
        # Run the prediction and persist the results.
        predictions = predictor.predict(df)
        predictor.save_predictions(predictions)
        predictor.logger.info("=" * 80)
        predictor.logger.info("预测任务全部完成!")
        predictor.logger.info("=" * 80)
    except Exception as e:
        # Log through the predictor's logger when it exists; otherwise the
        # failure happened before logging was configured.
        if 'predictor' in locals():
            predictor.logger.error(f"预测过程发生错误: {str(e)}", exc_info=True)
        else:
            print(f"初始化预测器时发生错误: {str(e)}")
        raise