# data_preprocessor.py
import os
import torch
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm  # progress bar
from sklearn.preprocessing import MinMaxScaler  # feature scaling
from torch.utils.data import DataLoader, TensorDataset  # PyTorch data loading utilities
from concurrent.futures import ThreadPoolExecutor  # multithreaded parallel reads


class DataPreprocessor:
    @staticmethod
    def load_and_process_data(args, data):
        """
        Load and process the data, split it into train/val/test sets,
        and build the corresponding data loaders.
        :param args: configuration (date ranges, sequence length, etc.)
        :param data: raw data (DataFrame)
        :return: train/val/test DataLoaders and the raw data
        """
        # Parse the date column
        data['date'] = pd.to_datetime(data['date'])
        time_interval = pd.Timedelta(hours=(args.resolution / 60))
        window_time_span = time_interval * (args.seq_len + 314)
        # Split into train/val/test sets (start dates shifted to fit the sliding window)
        val_start_date = pd.to_datetime(args.val_start_date)
        test_start_date = pd.to_datetime(args.test_start_date)
        # Move the val/test start dates back by the window span so that
        # complete input sequences can be generated
        adjusted_val_start = val_start_date - window_time_span
        adjusted_test_start = test_start_date - window_time_span
        # Build boolean masks for the train/val/test splits
        train_mask = (data['date'] >= pd.to_datetime(args.train_start_date)) & \
                     (data['date'] <= pd.to_datetime(args.train_end_date))
        val_mask = (data['date'] >= adjusted_val_start) & \
                   (data['date'] <= pd.to_datetime(args.val_end_date))
        test_mask = (data['date'] >= adjusted_test_start) & \
                    (data['date'] <= pd.to_datetime(args.test_end_date))
        # Apply the masks and reset the indices
        train_data = data[train_mask].reset_index(drop=True)
        val_data = data[val_mask].reset_index(drop=True)
        test_data = data[test_mask].reset_index(drop=True)
        # Drop the date column for modeling
        train_data = train_data.drop(columns=['date'])
        val_data = val_data.drop(columns=['date'])
        test_data = test_data.drop(columns=['date'])
        # Build supervised datasets (input sequences + target sequences)
        train_supervised = DataPreprocessor.create_supervised_dataset(args, train_data, 1)
        val_supervised = DataPreprocessor.create_supervised_dataset(args, val_data, 1)
        test_supervised = DataPreprocessor.create_supervised_dataset(args, test_data, args.step_size)
        # Wrap the datasets into DataLoaders
        train_loader = DataPreprocessor.load_data(args, train_supervised, shuffle=True)
        val_loader = DataPreprocessor.load_data(args, val_supervised, shuffle=False)
        test_loader = DataPreprocessor.load_data(args, test_supervised, shuffle=False)
        return train_loader, val_loader, test_loader, data  # raw data is returned for downstream use
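
    # Worked example of the window-span arithmetic above, with hypothetical
    # args.resolution = 60 and args.seq_len = 96:
    #   time_interval    = Timedelta(hours=60/60) = 1 hour per step
    #   window_time_span = 1 hour * (96 + 314)    = 410 hours ≈ 17.1 days
    # so the validation and test splits begin roughly 17 days before their
    # nominal start dates, giving the first sliding window enough history.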

    @staticmethod
    def read_and_combine_csv_files(args):
        """
        Read and combine multiple CSV files (multithreaded for speed).
        :param args: configuration (data directory, file naming pattern, etc.)
        :return: combined and preprocessed DataFrame
        """
        current_dir = os.path.dirname(__file__)
        parent_dir = os.path.dirname(current_dir)
        args.data_dir = os.path.join(parent_dir, args.data_dir)

        def read_file(file_count):
            """Inner helper: read a single CSV file."""
            file_name = args.file_pattern.format(file_count)
            file_path = os.path.join(args.data_dir, file_name)
            return pd.read_csv(file_path)

        # Build the list of file indices (start_files through end_files)
        file_indices = list(range(args.start_files, args.end_files + 1))
        # Read the files in parallel (speeds up large reads)
        max_workers = os.cpu_count()  # one thread per CPU core
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(tqdm(executor.map(read_file, file_indices),
                                total=len(file_indices),
                                desc="Reading files"))
        # Concatenate everything and reset the index
        all_data = pd.concat(results, ignore_index=True)
        # Downsample according to the resolution
        chunk = all_data.iloc[::args.resolution, :].reset_index(drop=True)
        # Process date and time features
        chunk = DataPreprocessor.process_date(chunk)
        # Normalize
        chunk = DataPreprocessor.scaler_data(chunk)
        return chunk
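
    # For example, with a hypothetical args.file_pattern = 'part_{}.csv',
    # args.start_files = 1 and args.end_files = 3, the code above reads
    # part_1.csv, part_2.csv and part_3.csv from args.data_dir in parallel.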

    @staticmethod
    def process_date(data):
        """
        Process the date column and generate cyclical time features (annual cycle).
        :param data: DataFrame with an 'index' column holding date strings
        :return: DataFrame with the time features added
        """
        data = data.rename(columns={'index': 'date'})
        data['date'] = pd.to_datetime(data['date'])
        # Generate cyclical time features
        data['day_of_year'] = data['date'].dt.dayofyear
        data['day_year_sin'] = np.sin(2 * np.pi * data['day_of_year'] / 366)
        data['day_year_cos'] = np.cos(2 * np.pi * data['day_of_year'] / 366)
        # Drop the raw day-of-year column, keeping only the encoded features
        data.drop(columns=['day_of_year'], inplace=True)
        # Reorder columns: date + time features + remaining features
        time_features = ['day_year_sin', 'day_year_cos']
        other_columns = [col for col in data.columns if col != 'date' and col not in time_features]
        data = data[['date'] + time_features + other_columns]
        return data
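
    # Illustrative check of the cyclical encoding above: day 183 of a 366-day
    # cycle sits half-way around the circle, so
    #   day_year_sin = sin(2*pi*183/366) = sin(pi) ≈ 0
    #   day_year_cos = cos(2*pi*183/366) = cos(pi) = -1
    # while days 1 and 366 map to nearly identical (sin, cos) pairs, which is
    # the point of the encoding: December 31 and January 1 end up close together.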

    @staticmethod
    def scaler_data(data):
        """
        Normalize the data (all columns except the date column) and save the scaler.
        :param data: DataFrame with a 'date' column
        :return: normalized DataFrame
        """
        date_col = data[['date']]
        data_to_scale = data.drop(columns=['date'])
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled_data = scaler.fit_transform(data_to_scale)
        joblib.dump(scaler, 'scaler.pkl')  # persist the scaler for later inverse transforms
        scaled_data = pd.DataFrame(scaled_data, columns=data_to_scale.columns)
        # Re-attach the date column to the scaled data
        scaled_data = pd.concat([date_col.reset_index(drop=True), scaled_data], axis=1)
        return scaled_data
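
    # A minimal sketch of recovering original units from model output, assuming
    # predictions cover the same columns (and column order) the scaler was fitted
    # on; `preds_scaled` is a hypothetical (n_samples, n_features) array:
    #
    #   scaler = joblib.load('scaler.pkl')
    #   preds = scaler.inverse_transform(preds_scaled)
    #
    # Note that MinMaxScaler.inverse_transform expects all fitted columns, so
    # predictions for only a subset of columns would need padding before inverting.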

    @staticmethod
    def create_supervised_dataset(args, data, step_size):
        """
        Convert the time series into supervised-learning format
        (input sequences + target sequences).
        :param args: configuration (sequence length, output length, etc.)
        :param data: input data (DataFrame without the date column)
        :param step_size: sliding-window stride
        :return: supervised dataset (DataFrame)
        """
        data = pd.DataFrame(data)
        cols = []
        col_names = []
        feature_columns = data.columns.tolist()
        # Input sequence: lags t-(seq_len-1) through t-0, grouped per feature
        for col in feature_columns:
            for i in range(args.seq_len - 1, -1, -1):
                cols.append(data[[col]].shift(i))
                col_names.append(f"{col}(t-{i})")
        # Target sequence: only the last labels_num columns are predicted
        target_columns = feature_columns[-args.labels_num:]
        for i in range(1, args.output_size + 1):
            for col in target_columns:
                cols.append(data[[col]].shift(-i))
                col_names.append(f"{col}(t+{i})")
        # Combine and clean up
        dataset = pd.concat(cols, axis=1)
        dataset.columns = col_names
        dataset = dataset.iloc[::step_size, :]  # subsample by the stride
        dataset = dataset.dropna()  # drop rows with missing values from the shifts
        return dataset
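
    # Illustrative column layout for hypothetical args with seq_len=2,
    # output_size=1, labels_num=1 and feature columns ['x', 'y']
    # (so 'y' is the target):
    #
    #   x(t-1), x(t-0), y(t-1), y(t-0), y(t+1)
    #
    # i.e. all lags of each feature come first, grouped per feature, followed
    # by the future targets; load_data below relies on exactly this order.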

    @staticmethod
    def load_data(args, dataset, shuffle):
        """
        Convert the supervised dataset into PyTorch tensors and build a DataLoader.
        :param args: configuration
        :param dataset: supervised dataset (DataFrame)
        :param shuffle: whether to shuffle the data
        :return: DataLoader
        """
        input_length = args.seq_len
        n_features = args.feature_num
        labels_num = args.labels_num
        n_features_total = n_features * input_length  # total input width
        n_labels_total = args.output_size * labels_num  # total target width
        # Split inputs and targets
        X = dataset.values[:, :n_features_total]
        y = dataset.values[:, n_features_total:n_features_total + n_labels_total]
        # Columns are grouped per feature (see create_supervised_dataset), so
        # reshape feature-major first, then swap the last two axes to obtain
        # the intended layout [samples, seq_len, features]
        X = X.reshape(X.shape[0], n_features, input_length).transpose(0, 2, 1)
        X = torch.tensor(X, dtype=torch.float32).to(args.device)
        y = torch.tensor(y, dtype=torch.float32).to(args.device)
        # Build the dataset and the loader
        dataset_tensor = TensorDataset(X, y)
        generator = torch.Generator()
        generator.manual_seed(args.random_seed)  # fixed seed for reproducible shuffling
        data_loader = DataLoader(
            dataset_tensor,
            batch_size=args.batch_size,
            shuffle=shuffle,
            generator=generator
        )
        return data_loader
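

# A minimal end-to-end usage sketch. Every value below is illustrative: the
# directory layout, file pattern, date ranges, and feature counts are
# assumptions, not taken from this repository's actual configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        data_dir='data',              # resolved relative to the parent directory
        file_pattern='part_{}.csv',   # hypothetical naming scheme; {} receives the file index
        start_files=1, end_files=3,   # reads part_1.csv .. part_3.csv
        resolution=60,                # downsample stride; also treated as minutes per step
        seq_len=96, output_size=24,   # 96 input steps, 24 forecast steps
        feature_num=5, labels_num=1,  # total feature columns; the last 1 is the target
        step_size=24,                 # test-set sliding-window stride
        train_start_date='2018-01-01', train_end_date='2019-12-31',
        val_start_date='2020-01-01', val_end_date='2020-06-30',
        test_start_date='2020-07-01', test_end_date='2020-12-31',
        batch_size=64, random_seed=42,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )

    raw = DataPreprocessor.read_and_combine_csv_files(args)
    train_loader, val_loader, test_loader, raw = DataPreprocessor.load_and_process_data(args, raw)
    xb, yb = next(iter(train_loader))
    # Expect [batch, seq_len, feature_num] and [batch, output_size * labels_num]
    print(xb.shape, yb.shape)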