فهرست منبع

增加代码逻辑,避免Nan值

zhanghao 5 ماه پیش
والد
کامیت
8a37ce6eb2
1فایلهای تغییر یافته به همراه81 افزوده شده و 1 حذف شده
  1. 81 1
      models/pressure-predictor/gat-lstm_model/20min/predict.py

+ 81 - 1
models/pressure-predictor/gat-lstm_model/20min/predict.py

@@ -192,6 +192,7 @@ class Predictor:
         self.level = self.config.get('data.wavelet.level', 3)
         self.level_after = self.config.get('data.wavelet.level_after', 4)
         self.mode = self.config.get('data.wavelet.mode', 'soft')
+        self.min_rows = self.config.get('data.min_rows', 600)
         
         # 阈值参数
         self.uf_threshold = self.config.get('data.threshold.uf', 0.001)
@@ -232,9 +233,57 @@ class Predictor:
         self.model = None
         self.edge_index = None
         self.test_loader = None
+        self.raw_input_data = None
         
         self.logger.info("预测器初始化完成")
         
+	def ensure_min_rows(self, df):
+        """
+        确保数据至少有指定行数,不足则进行前后补充
+        向前补充:使用最早的数据向前扩展
+        向后补充:使用最新的数据向后扩展
+        """
+        current_rows = len(df)
+        if current_rows >= self.min_rows:
+            return df
+        
+        # 计算需要补充的行数
+        need_rows = self.min_rows - current_rows
+        self.logger.info(f"数据行数不足{self.min_rows}行(当前{current_rows}行),需要补充{need_rows}行")
+        
+        # 计算时间间隔(假设数据是均匀采样的)
+        time_col = 'index'
+        df[time_col] = pd.to_datetime(df[time_col])
+        time_diff = (df[time_col].iloc[1] - df[time_col].iloc[0]).total_seconds()
+        
+        # 向前补充(使用最早的数据)
+        forward_rows = need_rows // 2
+        if forward_rows > 0:
+            earliest_data = df.iloc[0:1].copy()
+            forward_data = []
+            for i in range(1, forward_rows + 1):
+                new_row = earliest_data.copy()
+                new_row[time_col] = earliest_data[time_col] - timedelta(seconds=time_diff * i)
+                forward_data.append(new_row)
+            forward_df = pd.concat(forward_data, ignore_index=True)
+            df = pd.concat([forward_df, df], ignore_index=True)
+        
+        # 检查是否还需要向后补充
+        current_rows = len(df)
+        if current_rows < self.min_rows:
+            backward_rows = self.min_rows - current_rows
+            latest_data = df.iloc[-1:].copy()
+            backward_data = []
+            for i in range(1, backward_rows + 1):
+                new_row = latest_data.copy()
+                new_row[time_col] = latest_data[time_col] + timedelta(seconds=time_diff * i)
+                backward_data.append(new_row)
+            backward_df = pd.concat(backward_data, ignore_index=True)
+            df = pd.concat([df, backward_df], ignore_index=True)
+        
+        self.logger.info(f"数据补充完成,当前行数:{len(df)}行")
+        return df
+      
     def reorder_columns(self, df):
         """调整数据列顺序,确保与训练时的特征顺序一致"""
         desired_order = [
@@ -354,7 +403,27 @@ class Predictor:
         loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=False)
         
         return loader
-
+      
+	def get_recent_values_as_fallback(self):
+        """从原始输入数据中获取最近的output_size条记录作为备用输出"""
+        # 确保原始数据已保存
+        if self.raw_input_data is None:
+            raise ValueError("原始输入数据未保存,无法获取备用值")
+            
+        # 按时间排序并取最近的output_size条
+        recent_data = self.raw_input_data.sort_values('index').tail(self.output_size)
+        
+        # 若数据不足,用最后一条补充
+        if len(recent_data) < self.output_size:
+            last_row = recent_data.iloc[-1:] if not recent_data.empty else pd.DataFrame(
+                {col: [0.0] for col in self.target_columns}, index=[0])
+            while len(recent_data) < self.output_size:
+                recent_data = pd.concat([recent_data, last_row], ignore_index=True)
+        
+        # 提取目标列值并返回
+        fallback_values = recent_data[self.target_columns].values
+        return fallback_values
+      
     @log_execution_time
     def load_data(self, df):
         """数据加载和预处理"""
@@ -427,6 +496,12 @@ class Predictor:
         """执行预测"""
         self.logger.info("[预测流程] 开始")
         
+        # 保存原始输入数据用于可能的降级策略
+        self.raw_input_data = df.copy()
+        
+        # 确保数据行数不少于指定行数
+        df = self.ensure_min_rows(df)
+        
         try:
             # 更新测试起始时间
             latest_time = pd.to_datetime(df['index']).max()
@@ -474,6 +549,11 @@ class Predictor:
             self.logger.error(f"[反归一化] 失败: {e}")
             raise
         
+        # 检查是否有NaN值,有则使用备用值
+        if np.isnan(predictions).any():
+            self.logger.warning("[预测结果] 发现NaN值,使用最近值作为备用")
+            predictions = self.get_recent_values_as_fallback()
+            
         # 可选后处理
         if self.remove_outliers_flag:
             self.logger.info("[后处理] 执行异常值移除")