# data_preprocessor.py
  1. # data_preprocessor.py
  2. import os
  3. import torch
  4. import joblib
  5. import numpy as np
  6. import pandas as pd
  7. from tqdm import tqdm
  8. from sklearn.preprocessing import MinMaxScaler
  9. from torch.utils.data import DataLoader, TensorDataset
  10. from concurrent.futures import ThreadPoolExecutor
  11. class DataPreprocessor:
  12. """数据预处理类"""
  13. # 定义必须保留的列
  14. COLUMNS_TO_KEEP = [
  15. 'index',
  16. "water_in", # 进水量
  17. "water_out", # 外供水流量
  18. "RO1_TYL", # RO1脱盐率
  19. "RO2_TYL", # RO2脱盐率
  20. "UF1Per", # UF1渗透率
  21. "UF2Per", # UF2渗透率
  22. "2#RODJB_Eff", # 2#RO段间泵效率
  23. "1#RODJB_Eff", # 1#RO段间泵效率
  24. "2#ROGYB_Eff", # 2#RO高压泵效率
  25. "1#ROGYB_Eff", # 1#RO高压泵效率
  26. "ROHSL", # 反渗透回收率
  27. "ns=3;s=1#RO_CSDD_O", # 1#RO产水电导
  28. "ns=3;s=1#RO_CSPRESS_O", # 1#RO产水压力
  29. "ns=3;s=1#RO_EDCSFLOW_O", # 1#RO二段产水流量
  30. "ns=3;s=1#RO_EDJSPRESS_O", # 1#RO二段进水压力
  31. "ns=3;s=1#RO_EDNSPRESS_O", # 1#RO二段浓水压力
  32. "ns=3;s=1#RO_JSFLOW_O", # 1#RO进水流量
  33. "ns=3;s=1#RO_JSPRESS_O", # 1#RO进水压力
  34. "ns=3;s=1#RO_NSFLOW_O", # 1#RO浓水流量
  35. "ns=3;s=1#RO_SDCSFLOW_O", # 1#RO三段产水流量
  36. "ns=3;s=1#RO_SDJSPRESS_O", # 1#RO三段进水压力
  37. "ns=3;s=1#RO_SDNSPRESS_O", # 1#RO三段浓水压力
  38. "ns=3;s=1#RODJB_CUR_FB_O", # 1#RO段间泵电流反馈
  39. "ns=3;s=1#RODJB_CZ_O", # 1#RO段间泵测振反馈
  40. "ns=3;s=1#RODJB_FRE_FB_O", # 1#RO段间泵频率反馈
  41. "ns=3;s=1#ROGYB_CUR_FB_O", # 1#RO高压泵电流反馈
  42. "ns=3;s=1#ROGYB_CZ_O", # 1#RO高压泵测振反馈
  43. "ns=3;s=1#ROGYB_FRE_FB_O", # 1#RO高压泵频率反馈
  44. "ns=3;s=1#UF_CSPRESS_O", # 1#UF产水压力
  45. "ns=3;s=1#UF_JSFLOW_O", # 1#UF进水流量
  46. "ns=3;s=1#UF_JSPRESS_O", # 1#UF进水压力
  47. "ns=3;s=1#UF_V_FB_O", # 1#UF调节阀开度反馈
  48. "ns=3;s=1#UFBWB_CUR_FB_O", # 1#UF反洗泵电流反馈
  49. "ns=3;s=1#UFBWB_FRE_FB_O", # 1#UF反洗泵频率反馈
  50. "ns=3;s=2#RO_CSDD_O", # 2#RO产水电导
  51. "ns=3;s=2#RO_CSPRESS_O", # 2#RO产水压力
  52. "ns=3;s=2#RO_EDCSFLOW_O", # 2#RO二段产水流量
  53. "ns=3;s=2#RO_EDJSPRESS_O", # 2#RO二段进水压力
  54. "ns=3;s=2#RO_EDNSPRESS_O", # 2#RO二段浓水压力
  55. "ns=3;s=2#RO_JSFLOW_O", # 2#RO进水流量
  56. "ns=3;s=2#RO_JSPRESS_O", # 2#RO进水压力
  57. "ns=3;s=2#RO_NSFLOW_O", # 2#RO浓水流量
  58. "ns=3;s=2#RO_SDCSFLOW_O", # 2#RO三段产水流量
  59. "ns=3;s=2#RO_SDJSPRESS_O", # 2#RO三段进水压力
  60. "ns=3;s=2#RO_SDNSPRESS_O", # 2#RO三段浓水压力
  61. "ns=3;s=2#RODJB_CUR_FB_O", # 2#RO段间泵电流反馈
  62. "ns=3;s=2#RODJB_CZ_O", # 2#RO段间泵测振反馈
  63. "ns=3;s=2#RODJB_FRE_FB_O", # 2#RO段间泵频率反馈
  64. "ns=3;s=2#ROGYB_CUR_FB_O", # 2#RO高压泵电流反馈
  65. "ns=3;s=2#ROGYB_CZ_O", # 2#RO高压泵测振反馈
  66. "ns=3;s=2#ROGYB_FRE_FB_O", # 2#RO高压泵频率反馈
  67. "ns=3;s=2#UF_CSPRESS_O", # 2#UF产水压力
  68. "ns=3;s=2#UF_JSFLOW_O", # 2#UF进水流量
  69. "ns=3;s=2#UF_JSPRESS_O", # 2#UF进水压力
  70. "ns=3;s=2#UF_V_FB_O", # 2#UF调节阀开度反馈
  71. "ns=3;s=2#UFBWB_CUR_FB_O", # 2#UF反洗泵电流反馈
  72. "ns=3;s=2#UFBWB_FRE_FB_O", # 2#UF反洗泵频率反馈
  73. "ns=3;s=RO_JSDD_O", # RO进水电导
  74. "ns=3;s=RO_JSORP_O", # RO进水ORP
  75. "ns=3;s=RO_JSPH_O", # RO进水PH
  76. "ns=3;s=RO1_1DUAN_CS_FLOW", # RO1一段产水流量
  77. "ns=3;s=ZJS_PRESS_O", # 进水压力
  78. "ns=3;s=ZJS_TEMP_O", # 进水温度
  79. "ns=3;s=ZJS_ZD_O", # UF进水浊度
  80. "ns=3;s=PUBLIC_RO1_MTL", # RO1膜通量
  81. "ns=3;s=PUBLIC_RO2_MTL", # RO2膜通量
  82. "ns=3;s=UF1_SSD_KMYC", # UF1跨膜压差
  83. "ns=3;s=UF2_SSD_KMYC", # UF2跨膜压差
  84. "ns=3;s=RO1_1D_YC", # RO1一段压差
  85. "ns=3;s=RO1_2D_YC", # RO1二段压差
  86. "ns=3;s=RO2_1D_YC", # RO2一段压差
  87. "ns=3;s=RO2_2D_YC", # RO2二段压差
  88. "ns=3;s=PUBLIC_BY_REAL_1", # RO1三段压差
  89. "ns=3;s=PUBLIC_BY_REAL_2", # RO2三段压差
  90. ]
  91. @staticmethod
  92. def load_and_process_data(args, data):
  93. """加载并处理数据,划分训练/验证/测试集"""
  94. # 处理日期
  95. data['date'] = pd.to_datetime(data['date'])
  96. time_interval = pd.Timedelta(minutes=(4 * args.resolution / 60))
  97. window_time_span = time_interval * (args.seq_len + 1)
  98. val_start_date = pd.to_datetime(args.val_start_date)
  99. test_start_date = pd.to_datetime(args.test_start_date)
  100. # 调整时间窗口
  101. adjusted_val_start = val_start_date - window_time_span
  102. adjusted_test_start = test_start_date - window_time_span
  103. train_mask = (data['date'] >= pd.to_datetime(args.train_start_date)) & \
  104. (data['date'] <= pd.to_datetime(args.train_end_date))
  105. val_mask = (data['date'] >= adjusted_val_start) & \
  106. (data['date'] <= pd.to_datetime(args.val_end_date))
  107. test_mask = (data['date'] >= adjusted_test_start) & \
  108. (data['date'] <= pd.to_datetime(args.test_end_date))
  109. train_data = data[train_mask].reset_index(drop=True)
  110. val_data = data[val_mask].reset_index(drop=True)
  111. test_data = data[test_mask].reset_index(drop=True)
  112. train_data = train_data.drop(columns=['date'])
  113. val_data = val_data.drop(columns=['date'])
  114. test_data = test_data.drop(columns=['date'])
  115. # 创建数据集
  116. train_supervised = DataPreprocessor.create_supervised_dataset(args, train_data, 1)
  117. val_supervised = DataPreprocessor.create_supervised_dataset(args, val_data, 1)
  118. test_supervised = DataPreprocessor.create_supervised_dataset(args, test_data, args.step_size)
  119. # 转换为DataLoader
  120. train_loader = DataPreprocessor.load_data(args, train_supervised, shuffle=True)
  121. val_loader = DataPreprocessor.load_data(args, val_supervised, shuffle=False)
  122. test_loader = DataPreprocessor.load_data(args, test_supervised, shuffle=False)
  123. return train_loader, val_loader, test_loader, data
  124. @staticmethod
  125. def read_and_combine_csv_files(args):
  126. """读取文件并进行特征筛选和预处理"""
  127. current_dir = os.path.dirname(__file__)
  128. parent_dir = os.path.dirname(current_dir)
  129. args.data_dir = os.path.join(parent_dir, args.data_dir)
  130. def read_file(file_count):
  131. file_name = args.file_pattern.format(file_count)
  132. file_path = os.path.join(args.data_dir, file_name)
  133. try:
  134. df = pd.read_csv(file_path)
  135. # 确保只读取需要的列,若列不存在则会报错提示
  136. return df[DataPreprocessor.COLUMNS_TO_KEEP]
  137. except KeyError as e:
  138. print(f"文件 {file_name} 中缺少列: {e}")
  139. raise
  140. file_indices = list(range(args.start_files, args.end_files + 1))
  141. max_workers = os.cpu_count()
  142. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  143. results = list(tqdm(executor.map(read_file, file_indices),
  144. total=len(file_indices),
  145. desc="正在读取文件"))
  146. all_data = pd.concat(results, ignore_index=True)
  147. # 确保列顺序一致
  148. all_data = all_data[DataPreprocessor.COLUMNS_TO_KEEP]
  149. # 下采样
  150. chunk = all_data.iloc[::args.resolution, :].reset_index(drop=True)
  151. # 处理特征
  152. chunk = DataPreprocessor.process_date(chunk, args)
  153. chunk = DataPreprocessor.scaler_data(chunk, args)
  154. return chunk
  155. @staticmethod
  156. def process_date(data, args):
  157. data = data.rename(columns={'index': 'date'})
  158. data['date'] = pd.to_datetime(data['date'])
  159. time_features = []
  160. # 固定生成分钟级和日级特征,保持与Predictor一致
  161. data['minute_of_day'] = data['date'].dt.hour * 60 + data['date'].dt.minute
  162. data['minute_sin'] = np.sin(2 * np.pi * data['minute_of_day'] / 1440)
  163. data['minute_cos'] = np.cos(2 * np.pi * data['minute_of_day'] / 1440)
  164. data['day_of_year'] = data['date'].dt.dayofyear
  165. data['day_year_sin'] = np.sin(2 * np.pi * data['day_of_year'] / 366)
  166. data['day_year_cos'] = np.cos(2 * np.pi * data['day_of_year'] / 366)
  167. time_features.extend(['minute_sin', 'minute_cos', 'day_year_sin', 'day_year_cos'])
  168. data.drop(columns=['minute_of_day', 'day_of_year'], inplace=True)
  169. other_columns = [col for col in data.columns if col not in ['date'] and col not in time_features]
  170. data = data[['date'] + time_features + other_columns]
  171. return data
  172. @staticmethod
  173. def scaler_data(data, args):
  174. date_col = data[['date']]
  175. data_to_scale = data.drop(columns=['date'])
  176. scaler = MinMaxScaler(feature_range=(0, 1))
  177. scaled_data = scaler.fit_transform(data_to_scale)
  178. joblib.dump(scaler, args.scaler_path)
  179. scaled_data = pd.DataFrame(scaled_data, columns=data_to_scale.columns)
  180. scaled_data = pd.concat([date_col.reset_index(drop=True), scaled_data], axis=1)
  181. return scaled_data
  182. @staticmethod
  183. def create_supervised_dataset(args, data, step_size):
  184. data = pd.DataFrame(data)
  185. cols = []
  186. col_names = []
  187. feature_columns = data.columns.tolist()
  188. # 输入序列
  189. for col in feature_columns:
  190. for i in range(args.seq_len - 1, -1, -1):
  191. cols.append(data[[col]].shift(i))
  192. col_names.append(f"{col}(t-{i})")
  193. # 目标序列 (取最后labels_num列)
  194. target_columns = feature_columns[-args.labels_num:]
  195. for i in range(1, args.output_size + 1):
  196. for col in target_columns:
  197. cols.append(data[[col]].shift(-i))
  198. col_names.append(f"{col}(t+{i})")
  199. dataset = pd.concat(cols, axis=1)
  200. dataset.columns = col_names
  201. dataset = dataset.iloc[::step_size, :]
  202. dataset.dropna(inplace=True)
  203. return dataset
  204. @staticmethod
  205. def load_data(args, dataset, shuffle):
  206. input_length = args.seq_len
  207. n_features = args.feature_num
  208. labels_num = args.labels_num
  209. n_features_total = n_features * input_length
  210. n_labels_total = args.output_size * labels_num
  211. X = dataset.values[:, :n_features_total]
  212. y = dataset.values[:, n_features_total:n_features_total + n_labels_total]
  213. X = X.reshape(X.shape[0], input_length, n_features)
  214. X = torch.tensor(X, dtype=torch.float32).to(args.device)
  215. y = torch.tensor(y, dtype=torch.float32).to(args.device)
  216. dataset_tensor = TensorDataset(X, y)
  217. generator = torch.Generator()
  218. generator.manual_seed(args.random_seed)
  219. return DataLoader(dataset_tensor, batch_size=args.batch_size, shuffle=shuffle, generator=generator)