|
|
@@ -192,6 +192,7 @@ class Predictor:
|
|
|
self.level = self.config.get('data.wavelet.level', 3)
|
|
|
self.level_after = self.config.get('data.wavelet.level_after', 4)
|
|
|
self.mode = self.config.get('data.wavelet.mode', 'soft')
|
|
|
+ self.min_rows = self.config.get('data.min_rows', 600)
|
|
|
|
|
|
# 阈值参数
|
|
|
self.uf_threshold = self.config.get('data.threshold.uf', 0.001)
|
|
|
@@ -232,9 +233,57 @@ class Predictor:
|
|
|
self.model = None
|
|
|
self.edge_index = None
|
|
|
self.test_loader = None
|
|
|
+ self.raw_input_data = None
|
|
|
|
|
|
self.logger.info("预测器初始化完成")
|
|
|
|
|
|
+ def ensure_min_rows(self, df):
|
|
|
+ """
|
|
|
+ 确保数据至少有指定行数,不足则进行前后补充
|
|
|
+ 向前补充:使用最早的数据向前扩展
|
|
|
+ 向后补充:使用最新的数据向后扩展
|
|
|
+ """
|
|
|
+ current_rows = len(df)
|
|
|
+ if current_rows >= self.min_rows:
|
|
|
+ return df
|
|
|
+
|
|
|
+ # 计算需要补充的行数
|
|
|
+ need_rows = self.min_rows - current_rows
|
|
|
+ self.logger.info(f"数据行数不足{self.min_rows}行(当前{current_rows}行),需要补充{need_rows}行")
|
|
|
+
|
|
|
+ # 计算时间间隔(假设数据是均匀采样的)
|
|
|
+ time_col = 'index'
|
|
|
+ df[time_col] = pd.to_datetime(df[time_col])
|
|
|
+ time_diff = (df[time_col].iloc[1] - df[time_col].iloc[0]).total_seconds()
|
|
|
+
|
|
|
+ # 向前补充(使用最早的数据)
|
|
|
+ forward_rows = need_rows // 2
|
|
|
+ if forward_rows > 0:
|
|
|
+ earliest_data = df.iloc[0:1].copy()
|
|
|
+ forward_data = []
|
|
|
+ for i in range(1, forward_rows + 1):
|
|
|
+ new_row = earliest_data.copy()
|
|
|
+ new_row[time_col] = earliest_data[time_col] - timedelta(seconds=time_diff * i)
|
|
|
+ forward_data.append(new_row)
|
|
|
+ forward_df = pd.concat(forward_data, ignore_index=True)
|
|
|
+ df = pd.concat([forward_df, df], ignore_index=True)
|
|
|
+
|
|
|
+ # 检查是否还需要向后补充
|
|
|
+ current_rows = len(df)
|
|
|
+ if current_rows < self.min_rows:
|
|
|
+ backward_rows = self.min_rows - current_rows
|
|
|
+ latest_data = df.iloc[-1:].copy()
|
|
|
+ backward_data = []
|
|
|
+ for i in range(1, backward_rows + 1):
|
|
|
+ new_row = latest_data.copy()
|
|
|
+ new_row[time_col] = latest_data[time_col] + timedelta(seconds=time_diff * i)
|
|
|
+ backward_data.append(new_row)
|
|
|
+ backward_df = pd.concat(backward_data, ignore_index=True)
|
|
|
+ df = pd.concat([df, backward_df], ignore_index=True)
|
|
|
+
|
|
|
+ self.logger.info(f"数据补充完成,当前行数:{len(df)}行")
|
|
|
+ return df
|
|
|
+
|
|
|
def reorder_columns(self, df):
|
|
|
"""调整数据列顺序,确保与训练时的特征顺序一致"""
|
|
|
desired_order = [
|
|
|
@@ -354,7 +403,27 @@ class Predictor:
|
|
|
loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=False)
|
|
|
|
|
|
return loader
|
|
|
-
|
|
|
+
|
|
|
+ def get_recent_values_as_fallback(self):
|
|
|
+ """从原始输入数据中获取最近的output_size条记录作为备用输出"""
|
|
|
+ # 确保原始数据已保存
|
|
|
+ if self.raw_input_data is None:
|
|
|
+ raise ValueError("原始输入数据未保存,无法获取备用值")
|
|
|
+
|
|
|
+ # 按时间排序并取最近的output_size条
|
|
|
+ recent_data = self.raw_input_data.sort_values('index').tail(self.output_size)
|
|
|
+
|
|
|
+ # 若数据不足,用最后一条补充
|
|
|
+ if len(recent_data) < self.output_size:
|
|
|
+ last_row = recent_data.iloc[-1:] if not recent_data.empty else pd.DataFrame(
|
|
|
+ {col: [0.0] for col in self.target_columns}, index=[0])
|
|
|
+ while len(recent_data) < self.output_size:
|
|
|
+ recent_data = pd.concat([recent_data, last_row], ignore_index=True)
|
|
|
+
|
|
|
+ # 提取目标列值并返回
|
|
|
+ fallback_values = recent_data[self.target_columns].values
|
|
|
+ return fallback_values
|
|
|
+
|
|
|
@log_execution_time
|
|
|
def load_data(self, df):
|
|
|
"""数据加载和预处理"""
|
|
|
@@ -427,6 +496,12 @@ class Predictor:
|
|
|
"""执行预测"""
|
|
|
self.logger.info("[预测流程] 开始")
|
|
|
|
|
|
+ # 保存原始输入数据用于可能的降级策略
|
|
|
+ self.raw_input_data = df.copy()
|
|
|
+
|
|
|
+ # 确保数据行数不少于指定行数
|
|
|
+ df = self.ensure_min_rows(df)
|
|
|
+
|
|
|
try:
|
|
|
# 更新测试起始时间
|
|
|
latest_time = pd.to_datetime(df['index']).max()
|
|
|
@@ -474,6 +549,11 @@ class Predictor:
|
|
|
self.logger.error(f"[反归一化] 失败: {e}")
|
|
|
raise
|
|
|
|
|
|
+ # 检查是否有NaN值,有则使用备用值
|
|
|
+ if np.isnan(predictions).any():
|
|
|
+ self.logger.warning("[预测结果] 发现NaN值,使用最近值作为备用")
|
|
|
+ predictions = self.get_recent_values_as_fallback()
|
|
|
+
|
|
|
# 可选后处理
|
|
|
if self.remove_outliers_flag:
|
|
|
self.logger.info("[后处理] 执行异常值移除")
|