Преглед изворни кода

增加代码逻辑,避免Nan值

zhanghao пре 5 месеци
родитељ
комит
cd43f42db8
1 измењених фајлова са 5 додато и 2 уклоњено
  1. 5 2
      models/pressure-predictor/gat-lstm_model/20min/predict.py

+ 5 - 2
models/pressure-predictor/gat-lstm_model/20min/predict.py

@@ -209,6 +209,9 @@ class Predictor:
         self.remove_outliers_flag = self.config.get('postprocess.remove_outliers', False)
         self.smooth_flag = self.config.get('postprocess.smooth', False)
         
+        # 预测目标列名
+        self.target_columns = self.config.get('target_columns', [])
+        
         # 设备配置
         use_cuda = self.config.get('device.use_cuda', True)
         cuda_device = self.config.get('device.cuda_device', 0)
@@ -237,7 +240,7 @@ class Predictor:
         
         self.logger.info("预测器初始化完成")
         
-	def ensure_min_rows(self, df):
+    def ensure_min_rows(self, df):
         """
         确保数据至少有指定行数,不足则进行前后补充
         向前补充:使用最早的数据向前扩展
@@ -404,7 +407,7 @@ class Predictor:
         
         return loader
       
-	def get_recent_values_as_fallback(self):
+    def get_recent_values_as_fallback(self):
         """从原始输入数据中获取最近的output_size条记录作为备用输出"""
         # 确保原始数据已保存
         if self.raw_input_data is None: