predict.py

  1. """
  2. 20分钟TMP预测模型
  3. 版本:1.0
  4. 最后更新:2025-10-28
  5. """
  6. import os
  7. import sys
  8. import torch
  9. import pandas as pd
  10. import numpy as np
  11. import joblib
  12. import pywt
  13. from datetime import datetime, timedelta
  14. from torch.utils.data import DataLoader, TensorDataset
  15. from tqdm import tqdm
# Add the parent directory to sys.path so shared modules can be imported
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Import the GAT-LSTM model from the shared directory
sys.path.insert(0, os.path.join(parent_dir, 'shared'))
from gat_lstm import GAT_LSTM

# Try to import the common module; fall back to the standard library if unavailable
try:
    project_root = os.path.abspath(os.path.join(parent_dir, '../..'))
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    from common.utils.logger import setup_logger, log_execution_time
    from common.utils.config import Config
except ImportError:
    # Standard-library fallback
    import logging
    import yaml
    from functools import wraps
    import time
    def setup_logger(name, level='INFO', log_file=None, format_type='colored',
                     max_bytes=10485760, backup_count=5):
        """Set up a logger with a console handler and an optional rotating file handler."""
        logger = logging.getLogger(name)
        logger.setLevel(getattr(logging, level))
        # Avoid attaching duplicate handlers on repeated calls
        if logger.handlers:
            return logger
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
        # File handler
        if log_file:
            from logging.handlers import RotatingFileHandler
            file_handler = RotatingFileHandler(log_file, maxBytes=max_bytes, backupCount=backup_count)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        # Do not propagate records to the root logger
        logger.propagate = False
        return logger
    def log_execution_time(func):
        """Simplified execution-time decorator."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            if hasattr(args[0], 'logger'):
                args[0].logger.info(f"{func.__name__} took {end_time - start_time:.2f}s")
            return result
        return wrapper
    class Config:
        """Minimal config wrapper supporting dotted-key lookups."""
        def __init__(self, config_file):
            with open(config_file, 'r', encoding='utf-8') as f:
                self.config = yaml.safe_load(f)

        def get(self, key, default=None):
            keys = key.split('.')
            value = self.config
            for k in keys:
                if isinstance(value, dict):
                    value = value.get(k)
                else:
                    return default
                if value is None:
                    return default
            return value
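
# Example of the dotted-key lookup above (assumed config layout):
# Config('config.yaml').get('model.seq_len', 10) walks {'model': {'seq_len': 10}}
# one key at a time and returns the default as soon as any level is missing.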
def set_seed(seed):
    """
    Set the global random seed so experiments are reproducible.

    Args:
        seed: random seed value

    Note:
        - Seeds the Python, NumPy, and PyTorch generators
        - Forces deterministic CUDA operations
        - Disables cuDNN autotuning (to guarantee reproducibility)
    """
    import random
    random.seed(seed)                          # Python RNG
    os.environ['PYTHONHASHSEED'] = str(seed)   # Python hash seed
    np.random.seed(seed)                       # NumPy RNG
    torch.manual_seed(seed)                    # PyTorch CPU RNG
    torch.cuda.manual_seed(seed)               # current GPU RNG
    torch.cuda.manual_seed_all(seed)           # all GPU RNGs
    torch.backends.cudnn.deterministic = True  # deterministic CUDA kernels
    torch.backends.cudnn.benchmark = False     # disable cuDNN autotuning
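
# Note: the torch.cuda.manual_seed* calls are silently ignored on CPU-only
# builds, and the two cudnn flags trade some GPU throughput for bit-for-bit
# reproducible runs.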
class Predictor:
    """
    TMP predictor.

    Responsibilities:
        - Load and preprocess the input data
        - Load the trained GAT-LSTM model
        - Run inference and save the results

    Example:
        predictor = Predictor()
        predictions = predictor.predict(df)
        result_df = predictor.save_predictions(predictions, start_date)
    """
    def __init__(self, config_path='../config.yaml'):
        """
        Initialize the predictor.

        Args:
            config_path: path to the config file, relative to the gat-lstm_model
                root (currently unused; config.yaml is resolved from the parent
                directory below)

        Raises:
            FileNotFoundError: if the config file or model file is missing

        Note:
            - All parameters are loaded from the config file
            - A GPU is detected and used automatically when available
            - The data scaler saved during training is loaded here
        """
        # Load the config file (config.yaml in the parent directory)
        current_dir = os.path.dirname(__file__)
        parent_dir = os.path.dirname(current_dir)
        config_file = os.path.join(parent_dir, 'config.yaml')
        self.config = Config(config_file)
        # Log directory (logs/ under the gat-lstm_model root)
        log_dir = os.path.join(parent_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, '20min_predict.log')
        self.logger = setup_logger(
            name='20min_predict',
            level=self.config.get('logging.level', 'INFO'),
            log_file=log_file,
            format_type=self.config.get('logging.format', 'colored'),
            max_bytes=self.config.get('logging.max_bytes', 10*1024*1024),
            backup_count=self.config.get('logging.backup_count', 5)
        )
        self.logger.info("Initializing the 20-minute TMP predictor")
        # Model parameters (from the config file)
        self.seq_len = self.config.get('model.seq_len', 10)
        self.output_size = self.config.get('model.output_size', 5)
        self.labels_num = self.config.get('model.labels_num', 16)
        self.feature_num = self.config.get('model.feature_num', 79)
        self.step_size = self.config.get('model.step_size', 5)
        self.dropout = self.config.get('model.dropout', 0)
        self.lr = self.config.get('model.lr', 0.01)
        self.num_heads = self.config.get('model.num_heads', 8)
        self.hidden_size = self.config.get('model.hidden_size', 64)
        self.batch_size = self.config.get('model.batch_size', 512)
        self.num_layers = self.config.get('model.num_layers', 1)
        self.random_seed = self.config.get('model.random_seed', 1314)
        # Data-processing parameters
        self.resolution = self.config.get('data.resolution', 60)
        self.test_start_date = self.config.get('data.test_start_date', '2025-07-01')
        self.wavelet = self.config.get('data.wavelet.type', 'db4')
        self.level = self.config.get('data.wavelet.level', 3)
        self.level_after = self.config.get('data.wavelet.level_after', 4)
        self.mode = self.config.get('data.wavelet.mode', 'soft')
        self.min_rows = self.config.get('data.min_rows', 600)
        # Threshold parameters
        self.uf_threshold = self.config.get('data.threshold.uf', 0.001)
        self.ro_threshold = self.config.get('data.threshold.ro', 0.01)
        self.flow_threshold = self.config.get('data.threshold.flow', 1.0)
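        # With the defaults above, one prediction step spans
        # 4 * resolution / 60 = 4 minutes (the formula used in
        # create_test_loader and save_predictions), so output_size=5 steps
        # cover the 20-minute horizon this model is named for.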
        # File paths (relative to the 20min directory)
        self.model_path = os.path.join(current_dir, '20min_model.pth')
        self.scaler_path = os.path.join(current_dir, '20min_scaler.pkl')
        self.edge_index_path = os.path.join(parent_dir, 'shared', 'edge_index.pt')
        self.output_csv_path = os.path.join(current_dir, '20min_predictions.csv')
        # Post-processing switches
        self.remove_outliers_flag = self.config.get('postprocess.remove_outliers', False)
        self.smooth_flag = self.config.get('postprocess.smooth', False)
        # Names of the prediction target columns
        self.target_columns = self.config.get('target_columns', [])
        # Device configuration
        use_cuda = self.config.get('device.use_cuda', True)
        cuda_device = self.config.get('device.cuda_device', 0)
        if use_cuda and torch.cuda.is_available():
            self.device = torch.device(f"cuda:{cuda_device}")
            self.logger.info(f"Device: GPU-{torch.cuda.get_device_name(cuda_device)}")
        else:
            self.device = torch.device("cpu")
            self.logger.info("Device: CPU")
        set_seed(self.random_seed)
        # Load the data scaler saved during training
        if not os.path.exists(self.scaler_path):
            self.logger.error(f"Scaler file not found: {self.scaler_path}")
            raise FileNotFoundError(f"Scaler file not found: {self.scaler_path}")
        self.scaler = joblib.load(self.scaler_path)
        # Model and data loader are initialized later
        self.model = None
        self.edge_index = None
        self.test_loader = None
        self.raw_input_data = None
        self.logger.info("Predictor initialized")
    def ensure_min_rows(self, df):
        """
        Ensure the data has at least `min_rows` rows, padding both ends if not.

        Forward padding: extend backwards in time by repeating the earliest row.
        Backward padding: extend forwards in time by repeating the latest row.
        """
        current_rows = len(df)
        if current_rows >= self.min_rows:
            return df
        # Number of rows to add
        need_rows = self.min_rows - current_rows
        self.logger.info(f"Only {current_rows} rows (minimum {self.min_rows}); padding {need_rows} rows")
        # Sampling interval in seconds (assumes uniformly sampled data)
        time_col = 'index'
        df[time_col] = pd.to_datetime(df[time_col])
        time_diff = (df[time_col].iloc[1] - df[time_col].iloc[0]).total_seconds()
        # Pad backwards in time using the earliest row
        forward_rows = need_rows // 2
        if forward_rows > 0:
            earliest_data = df.iloc[0:1].copy()
            forward_data = []
            # Iterate from the farthest offset inwards so the padded rows
            # come out in chronological order
            for i in range(forward_rows, 0, -1):
                new_row = earliest_data.copy()
                new_row[time_col] = earliest_data[time_col] - timedelta(seconds=time_diff * i)
                forward_data.append(new_row)
            forward_df = pd.concat(forward_data, ignore_index=True)
            df = pd.concat([forward_df, df], ignore_index=True)
        # Pad forwards in time if still short
        current_rows = len(df)
        if current_rows < self.min_rows:
            backward_rows = self.min_rows - current_rows
            latest_data = df.iloc[-1:].copy()
            backward_data = []
            for i in range(1, backward_rows + 1):
                new_row = latest_data.copy()
                new_row[time_col] = latest_data[time_col] + timedelta(seconds=time_diff * i)
                backward_data.append(new_row)
            backward_df = pd.concat(backward_data, ignore_index=True)
            df = pd.concat([df, backward_df], ignore_index=True)
        self.logger.info(f"Padding complete; {len(df)} rows now")
        return df
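    # Padding arithmetic for the defaults: with min_rows=600 and a 500-row
    # input, need_rows=100, so 50 copies of the earliest row are prepended,
    # then the recheck finds 550 < 600 and appends 50 copies of the latest row.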
    def reorder_columns(self, df):
        """Reorder columns to match the feature order used during training."""
        desired_order = [
            'index',
            'C.M.FT_ZGJJY1@out','C.M.RO1_FT_JS@out','C.M.RO2_FT_JS@out','C.M.RO3_FT_JS@out',
            'C.M.RO4_FT_JS@out','C.M.UF1_FT_JS@out','C.M.UF2_FT_JS@out','C.M.UF3_FT_JS@out',
            'C.M.UF4_FT_JS@out','C.M.UF_FT_ZCS@out','C.M.FT_ZGJJY2@out','C.M.FT_ZGJJY3@out',
            'C.M.FT_ZGJJY4@out','C.M.RO1_PT_JS@out','C.M.RO2_PT_JS@out','C.M.RO3_PT_JS@out',
            'C.M.UF1_PT_JS@out','C.M.UF2_PT_JS@out','C.M.UF3_PT_JS@out','C.M.UF4_PT_JS@out',
            'C.M.LT_JSC@out','C.M.RO1_PT_CS@out','C.M.RO1_PT_DJ2@out','C.M.RO2_PT_CS@out',
            'C.M.RO2_PT_DJ2@out','C.M.RO3_PT_CS@out','C.M.RO3_PT_DJ2@out','C.M.RO4_PT_CS@out',
            'C.M.RO4_PT_DJ2@out','C.M.RO4_PT_JS@out','C.M.LT_HCl@out','C.M.LT_NaClO@out',
            'C.M.LT_PAC@out','C.M.LT_QSC@out','C.M.RO_Cond_ZCS@out','C.M.RO_TT_ZJS@out',
            'C.M.UF1_JSF_kd@out','C.M.UF2_JSF_kd@out','C.M.UF_GSB4_fre@out','C.M.UF_ORP_ZCS@out',
            'C.M.JYB2_ZGJ1_fre@out','C.M.JYB2_ZGJ2_fre@out','C.M.JYB2_ZGJ3_fre@out','C.M.JYB2_ZGJ4_fre@out',
            'C.M.RO1_GYB_fre@out','C.M.RO2_GYB_fre@out','C.M.RO3_GYB_fre@out','C.M.RO4_GYB_fre@out',
            'C.M.UF3_JSF_kd@out','C.M.UF4_JSF_kd@out','C.M.UF_FXB2_fre@out','C.M.RO1_DJB_fre@out',
            'C.M.RO1_GYBF_kd@out','C.M.RO2_DJB_fre@out','C.M.RO2_GYBF_kd@out','C.M.RO3_DJB_fre@out',
            'C.M.RO3_GYBF_kd@out','C.M.RO4_DJB_fre@out','C.M.RO4_GYBF_kd@out',
            'C.M.UF1_DB@press_PV','C.M.UF2_DB@press_PV','C.M.UF3_DB@press_PV','C.M.UF4_DB@press_PV',
            'UF1Per','UF2Per','UF3Per','UF4Per',
            'C.M.RO1_DB@DPT_1','C.M.RO2_DB@DPT_1','C.M.RO3_DB@DPT_1','C.M.RO4_DB@DPT_1',
            'C.M.RO1_DB@DPT_2','C.M.RO2_DB@DPT_2','C.M.RO3_DB@DPT_2','C.M.RO4_DB@DPT_2',
        ]
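        # The 75 sensor tags above plus the 4 cyclic time features added in
        # process_date make up the feature_num=79 model inputs; the final
        # labels_num=16 columns double as the prediction targets (see the
        # feature_columns[-self.labels_num:] slice in create_test_loader).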
        return df.loc[:, desired_order]
    def process_date(self, data):
        """Process the date column and generate cyclic time features."""
        if 'index' in data.columns:
            data = data.rename(columns={'index': 'date'})
        data['date'] = pd.to_datetime(data['date'])
        data['minute_of_day'] = data['date'].dt.hour * 60 + data['date'].dt.minute
        data['day_of_year'] = data['date'].dt.dayofyear
        # Cyclic (sin/cos) encoding
        data['minute_sin'] = np.sin(2 * np.pi * data['minute_of_day'] / 1440)
        data['minute_cos'] = np.cos(2 * np.pi * data['minute_of_day'] / 1440)
        data['day_year_sin'] = np.sin(2 * np.pi * data['day_of_year'] / 366)
        data['day_year_cos'] = np.cos(2 * np.pi * data['day_of_year'] / 366)
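        # Worked example of the encoding: 12:00 noon gives minute_of_day=720,
        # so minute_sin = sin(2*pi*720/1440) = sin(pi) ~ 0 and
        # minute_cos = cos(pi) = -1; 23:59 and 00:00 land next to each other on
        # the circle, whereas a raw 0-1439 feature puts them at opposite extremes.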
        data.drop(columns=['minute_of_day', 'day_of_year'], inplace=True)
        # Reorder columns: date first, then the time features, then the rest
        time_features = ['minute_sin', 'minute_cos', 'day_year_sin', 'day_year_cos']
        other_columns = [col for col in data.columns if col not in ['date'] + time_features]
        return data[['date'] + time_features + other_columns]
    def scaler_data(self, data):
        """Normalize the data with the scaler saved during training."""
        date_col = data[['date']]
        data_to_scale = data.drop(columns=['date'])
        scaled = self.scaler.transform(data_to_scale)
        scaled_df = pd.DataFrame(scaled, columns=data_to_scale.columns)
        result = pd.concat([date_col.reset_index(drop=True), scaled_df], axis=1)
        return result
    def remove_outliers(self, predictions):
        """Replace outliers in the predictions using the interquartile-range rule."""
        cleaned = predictions.copy()
        for col in range(cleaned.shape[1]):
            values = cleaned[:, col]
            q1 = np.percentile(values, 25)
            q3 = np.percentile(values, 75)
            iqr = q3 - q1
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr
            normal_values = values[(values >= lower_bound) & (values <= upper_bound)]
            if len(normal_values) > 0:
                mean_normal = np.mean(normal_values)
                cleaned[(values < lower_bound) | (values > upper_bound), col] = mean_normal
        return cleaned
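    # IQR rule example: q1=1.0 and q3=3.0 give iqr=2.0 and bounds [-2.0, 6.0];
    # any value outside that band is replaced by the mean of the in-band values.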
    def smooth_predictions(self, predictions):
        """Apply weighted smoothing to the predictions."""
        smoothed = predictions.copy()
        n_timesteps = predictions.shape[0]
        if n_timesteps <= 1:
            return smoothed
        for col in range(predictions.shape[1]):
            values = predictions[:, col]
            smoothed[0, col] = (2 * values[0] + values[1]) / 3
            for i in range(1, n_timesteps - 1):
                smoothed[i, col] = (values[i-1] + 2 * values[i] + values[i+1]) / 4
            smoothed[-1, col] = (values[-2] + 2 * values[-1]) / 3
        return smoothed
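    # This is a [1, 2, 1]/4 weighted moving average on interior points, with
    # truncated [2, 1]/3 and [1, 2]/3 kernels at the first and last timestep.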
    def create_test_loader(self, df):
        """Build the test data loader."""
        df['date'] = pd.to_datetime(df['date'])
        time_interval = pd.Timedelta(minutes=(4 * self.resolution / 60))
        window_time_span = time_interval * (self.seq_len + 20)
        adjusted_test_start = pd.to_datetime(self.test_start_date) - window_time_span
        test_df = df[df['date'] >= adjusted_test_start].reset_index(drop=True)
        test_df = test_df.drop(columns=['date'])
        # Build the supervised-learning dataset from shifted copies:
        # seq_len lagged frames of every feature, then output_size future
        # frames of the labels_num target columns
        feature_columns = test_df.columns.tolist()
        cols = []
        for col in feature_columns:
            for i in range(self.seq_len - 1, -1, -1):
                cols.append(test_df[[col]].shift(i))
        for i in range(1, self.output_size + 1):
            for col in feature_columns[-self.labels_num:]:
                cols.append(test_df[[col]].shift(-i))
        dataset = pd.concat(cols, axis=1).iloc[::self.step_size]
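        # Each row of `dataset` now holds feature_num * seq_len lagged inputs
        # followed by output_size * labels_num future slots (NaN at the series
        # tail, courtesy of shift(-i)); only the input block is used below, and
        # only the most recent window is kept for one-shot prediction.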
        dataset = dataset.iloc[[-1]]
        n_features_total = self.feature_num * self.seq_len
        supervised_data = dataset.iloc[:, :n_features_total]
        X = supervised_data.values.reshape(-1, self.seq_len, self.feature_num)
        X = torch.tensor(X, dtype=torch.float32).to(self.device)
        tensor_dataset = TensorDataset(X)
        loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=False)
        return loader
    def get_recent_values_as_fallback(self):
        """Return the most recent output_size records from the raw input as fallback output."""
        # The raw input must have been saved first
        if self.raw_input_data is None:
            raise ValueError("Raw input data was not saved; cannot build fallback values")
        # Sort by time and take the most recent output_size rows
        recent_data = self.raw_input_data.sort_values('index').tail(self.output_size)
        # If there are too few rows, repeat the last one
        if len(recent_data) < self.output_size:
            last_row = recent_data.iloc[-1:] if not recent_data.empty else pd.DataFrame(
                {col: [0.0] for col in self.target_columns}, index=[0])
            while len(recent_data) < self.output_size:
                recent_data = pd.concat([recent_data, last_row], ignore_index=True)
        # Extract and return the target-column values
        fallback_values = recent_data[self.target_columns].values
        return fallback_values
    @log_execution_time
    def load_data(self, df):
        """Load and preprocess the data."""
        self.logger.info(f"[Data] Raw shape: {df.shape}, columns: {len(df.columns)}")
        try:
            df = self.reorder_columns(df)
            self.logger.info("[Reorder] Done")
        except Exception as e:
            self.logger.error(f"[Reorder] Failed: {e}")
            raise
        df = df.iloc[::self.resolution, :].reset_index(drop=True)
        self.logger.info(f"[Downsample] Rate={self.resolution}, shape after sampling: {df.shape}")
        try:
            df = self.process_date(df)
            self.logger.info("[Time features] Done")
        except Exception as e:
            self.logger.error(f"[Time features] Failed: {e}")
            raise
        try:
            df = self.scaler_data(df)
            self.logger.info("[Normalize] Done")
        except Exception as e:
            self.logger.error(f"[Normalize] Failed: {e}")
            raise
        try:
            self.test_loader = self.create_test_loader(df)
            self.logger.info("[Data loader] Created")
        except Exception as e:
            self.logger.error(f"[Data loader] Creation failed: {e}")
            raise
        if not os.path.exists(self.edge_index_path):
            self.logger.error(f"[Graph] Edge-index file not found: {self.edge_index_path}")
            raise FileNotFoundError(f"Graph edge-index file not found: {self.edge_index_path}")
        self.edge_index = torch.load(self.edge_index_path, map_location=self.device, weights_only=True)
        self.logger.info(f"[Graph] Edge index loaded, shape: {self.edge_index.shape}")
    @log_execution_time
    def load_model(self):
        """Load the model and its pretrained weights."""
        if not os.path.exists(self.model_path):
            self.logger.error(f"[Model] File not found: {self.model_path}")
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        try:
            self.logger.info("[Model] Building model structure")
            self.model = GAT_LSTM(self).to(self.device)
            if self.edge_index is not None:
                self.model.set_edge_index(self.edge_index.to(self.device))
                self.logger.info("[Model] Graph edge index set")
            self.model.load_state_dict(torch.load(self.model_path, map_location=self.device, weights_only=True))
            self.model.eval()
            total_params = sum(p.numel() for p in self.model.parameters())
            self.logger.info(f"[Model] Loaded - parameters: {total_params:,}")
        except Exception as e:
            self.logger.error(f"[Model] Loading failed: {e}")
            raise
    @log_execution_time
    def predict(self, df):
        """Run the prediction pipeline."""
        self.logger.info("[Predict] Start")
        # Keep the raw input around for the fallback strategy
        self.raw_input_data = df.copy()
        # Make sure there are enough rows
        df = self.ensure_min_rows(df)
        try:
            # Update the test start time
            latest_time = pd.to_datetime(df['index']).max()
            self.test_start_date = (latest_time + timedelta(minutes=4)).strftime("%Y-%m-%d %H:%M:%S")
            self.logger.info(f"[Timing] Latest input time: {latest_time}, prediction starts at: {self.test_start_date}")
        except Exception as e:
            self.logger.error(f"[Timing] Computation failed: {e}")
            raise
        # Load and preprocess the data
        self.load_data(df)
        # Load the model
        self.load_model()
        # Inference
        try:
            self.logger.info("[Inference] Start")
            all_predictions = []
            batch_count = 0
            with torch.no_grad():
                for batch in self.test_loader:
                    inputs = batch[0].to(self.device)
                    outputs = self.model(inputs)
                    all_predictions.append(outputs.cpu().numpy())
                    batch_count += 1
            self.logger.info(f"[Inference] Done - batches: {batch_count}")
            predictions = np.concatenate(all_predictions, axis=0).reshape(-1, self.labels_num)
            self.logger.info(f"[Inference] Raw prediction shape: {predictions.shape}")
        except Exception as e:
            self.logger.error(f"[Inference] Failed: {e}")
            raise
        # Inverse normalization
        try:
            self.logger.info("[Denormalize] Start")
            from sklearn.preprocessing import MinMaxScaler
            inverse_scaler = MinMaxScaler()
            inverse_scaler.min_ = self.scaler.min_[-self.labels_num:]
            inverse_scaler.scale_ = self.scaler.scale_[-self.labels_num:]
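            # The training scaler covers all features; slicing min_ and scale_
            # to the last labels_num entries yields a scaler that inverts only
            # the target columns (MinMaxScaler stores x_scaled = x * scale_ + min_,
            # so inverse_transform computes (x_scaled - min_) / scale_).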
            predictions = inverse_scaler.inverse_transform(predictions)
            self.logger.info(f"[Denormalize] Done - range: [{predictions.min():.2f}, {predictions.max():.2f}]")
        except Exception as e:
            self.logger.error(f"[Denormalize] Failed: {e}")
            raise
        # Fall back to recent values if the output contains NaNs
        if np.isnan(predictions).any():
            self.logger.warning("[Predictions] NaN values found; falling back to recent values")
            predictions = self.get_recent_values_as_fallback()
        # Optional post-processing
        if self.remove_outliers_flag:
            self.logger.info("[Postprocess] Removing outliers")
            predictions = self.remove_outliers(predictions)
        if self.smooth_flag:
            self.logger.info("[Postprocess] Smoothing")
            predictions = self.smooth_predictions(predictions)
        self.logger.info(f"[Predict] Done - final shape: {predictions.shape}, range: [{predictions.min():.2f}, {predictions.max():.2f}]")
        return predictions
    def save_predictions(self, predictions, start_date=None, output_path=None):
        """Save the predictions to CSV and return them as a DataFrame."""
        try:
            if start_date is None:
                start_date = self.test_start_date
            self.logger.info(f"[Save] Prediction start time: {start_date}, points: {len(predictions)}")
            start_time = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
            time_interval = timedelta(minutes=(4 * self.resolution / 60))
            timestamps = [start_time + i * time_interval for i in range(len(predictions))]
            base_columns = self.target_columns
            pred_columns = [f'{col}_Predicted' for col in base_columns]
            df_result = pd.DataFrame(predictions, columns=pred_columns)
            df_result.insert(0, 'index', timestamps)
            save_path = output_path if output_path else self.output_csv_path
            df_result.to_csv(save_path, index=False)
            self.logger.info(f"[Save] Done - file: {save_path}, time range: {timestamps[0]} to {timestamps[-1]}")
            return df_result
        except Exception as e:
            self.logger.error(f"[Save] Failed: {e}")
            raise
if __name__ == '__main__':
    # Entry point: run the 20-minute TMP prediction
    import json
    try:
        predictor = Predictor()
        json_file_path = '/Users/wmy/Downloads/pp.json'
        if not os.path.exists(json_file_path):
            predictor.logger.error(f"Input file not found: {json_file_path}")
            raise FileNotFoundError(f"File not found: {json_file_path}")
        predictor.logger.info(f"Reading input file: {json_file_path}")
        with open(json_file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        df = pd.DataFrame(json_data)
        predictions = predictor.predict(df)
        predictor.save_predictions(predictions)
        predictor.logger.info("Prediction task finished")
    except Exception as e:
        if 'predictor' in locals():
            predictor.logger.error(f"Prediction failed: {str(e)}", exc_info=True)
        else:
            print(f"Initialization failed: {str(e)}")
        raise