|
|
@@ -209,6 +209,9 @@ class Predictor:
|
|
|
self.remove_outliers_flag = self.config.get('postprocess.remove_outliers', False)
|
|
|
self.smooth_flag = self.config.get('postprocess.smooth', False)
|
|
|
|
|
|
+ # 预测目标列名
|
|
|
+ self.target_columns = self.config.get('target_columns', [])
|
|
|
+
|
|
|
# 设备配置
|
|
|
use_cuda = self.config.get('device.use_cuda', True)
|
|
|
cuda_device = self.config.get('device.cuda_device', 0)
|
|
|
@@ -237,7 +240,7 @@ class Predictor:
|
|
|
|
|
|
self.logger.info("预测器初始化完成")
|
|
|
|
|
|
- def ensure_min_rows(self, df):
|
|
|
+ def ensure_min_rows(self, df):
|
|
|
"""
|
|
|
确保数据至少有指定行数,不足则进行前后补充
|
|
|
向前补充:使用最早的数据向前扩展
|
|
|
@@ -404,7 +407,7 @@ class Predictor:
|
|
|
|
|
|
return loader
|
|
|
|
|
|
- def get_recent_values_as_fallback(self):
|
|
|
+ def get_recent_values_as_fallback(self):
|
|
|
"""从原始输入数据中获取最近的output_size条记录作为备用输出"""
|
|
|
# 确保原始数据已保存
|
|
|
if self.raw_input_data is None:
|