Explorar o código

1:新增config 文件

wmy hai 5 meses
pai
achega
f730c08855
Modificáronse 1 ficheiros con 80 adicións e 0 borrados
  1. 80 0
      models/pressure-predictor/gat-lstm_model/config.yaml

+ 80 - 0
models/pressure-predictor/gat-lstm_model/config.yaml

@@ -0,0 +1,80 @@
+# GAT-LSTM 20分钟TMP预测模型配置文件
+# 此配置文件可被其他Python文件通过 common.utils.config 模块调用
+
+# 模型参数配置
+model:
+  seq_len: 10              # 输入序列长度(历史时间步)
+  output_size: 5           # 预测步长(未来预测的时间步数)
+  labels_num: 16           # 预测目标数量(16个待预测的指标)
+  feature_num: 79          # 输入特征总维度
+  step_size: 5             # 数据采样步长(每隔step_size取一个样本)
+  dropout: 0.0             # dropout概率(防止过拟合)
+  lr: 0.01                 # 学习率(训练时使用,预测时仅作参数记录)
+  num_heads: 8             # GAT注意力头数(模型结构参数)
+  hidden_size: 64          # LSTM隐藏层维度
+  batch_size: 512          # 批处理大小
+  num_layers: 1            # LSTM层数
+  random_seed: 1314        # 随机种子(保证实验可重复性)
+
+# 数据处理参数
+data:
+  resolution: 60           # 数据分辨率(原始数据每隔60条取一条,下采样)
+  test_start_date: '2025-07-01'  # 测试集起始日期(初始值,会动态更新)
+  
+  # 小波变换相关参数(预留,未实际使用)
+  wavelet:
+    type: 'db4'            # 小波变换类型
+    level: 3               # 小波分解层数
+    level_after: 4         # 后续小波处理层数
+    mode: 'soft'           # 小波阈值模式
+  
+  # 数据质量阈值(预留)
+  threshold:
+    uf: 0.001              # UF指标阈值
+    ro: 0.01               # RO指标阈值
+    flow: 1.0              # 流量阈值
+
+# 文件路径配置
+paths:
+  model_path: '20min_model.pth'              # 模型权重文件路径
+  scaler_path: '20min_scaler.pkl'            # 数据归一化器文件路径
+  edge_index_path: 'edge_index.pt'           # 图结构边索引文件路径
+  output_csv_path: '20min_predictions.csv'   # 预测结果保存路径
+
+# 日志配置
+logging:
+  level: 'INFO'                              # 日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL
+  format: 'colored'                          # 日志格式:colored(彩色), json(JSON格式), standard(标准格式)
+  log_file: 'logs/20min_predict.log'         # 日志文件路径(相对于模型目录)
+  max_bytes: 10485760                        # 日志文件最大大小(10MB)
+  backup_count: 5                            # 日志文件备份数量
+
+# 预测结果后处理配置
+postprocess:
+  remove_outliers: false   # 是否移除异常值(四分位法)
+  smooth: false            # 是否平滑预测结果(加权平均)
+
+# GPU配置
+device:
+  use_cuda: true           # 是否使用CUDA(如果可用)
+  cuda_device: 0           # CUDA设备编号
+
+# 预测目标列名(16个待预测的指标)
+target_columns:
+  - 'C.M.UF1_DB@press_PV'  # UF1压差
+  - 'C.M.UF2_DB@press_PV'  # UF2压差
+  - 'C.M.UF3_DB@press_PV'  # UF3压差
+  - 'C.M.UF4_DB@press_PV'  # UF4压差
+  - 'UF1Per'               # UF1渗透率
+  - 'UF2Per'               # UF2渗透率
+  - 'UF3Per'               # UF3渗透率
+  - 'UF4Per'               # UF4渗透率
+  - 'C.M.RO1_DB@DPT_1'     # RO1一段压差
+  - 'C.M.RO2_DB@DPT_1'     # RO2一段压差
+  - 'C.M.RO3_DB@DPT_1'     # RO3一段压差
+  - 'C.M.RO4_DB@DPT_1'     # RO4一段压差
+  - 'C.M.RO1_DB@DPT_2'     # RO1二段压差
+  - 'C.M.RO2_DB@DPT_2'     # RO2二段压差
+  - 'C.M.RO3_DB@DPT_2'     # RO3二段压差
+  - 'C.M.RO4_DB@DPT_2'     # RO4二段压差
+