# data_preprocessor.py
  1. # data_preprocessor.py
  2. import os
  3. import torch
  4. import joblib
  5. import numpy as np
  6. import pandas as pd
  7. from tqdm import tqdm
  8. from sklearn.preprocessing import MinMaxScaler
  9. from torch.utils.data import DataLoader, TensorDataset
  10. from concurrent.futures import ThreadPoolExecutor
  11. from config import config
  12. class DataPreprocessor:
  13. """数据预处理类"""
  14. @staticmethod
  15. def load_and_process_data(data):
  16. data['date'] = pd.to_datetime(data['date'])
  17. time_interval = pd.Timedelta(minutes=(4 * config.RESOLUTION / 60))
  18. window_time_span = time_interval * (config.SEQ_LEN + 1)
  19. val_start_date = pd.to_datetime(config.VAL_START_DATE)
  20. test_start_date = pd.to_datetime(config.TEST_START_DATE)
  21. adjusted_val_start = val_start_date - window_time_span
  22. adjusted_test_start = test_start_date - window_time_span
  23. train_mask = (data['date'] >= pd.to_datetime(config.TRAIN_START_DATE)) & \
  24. (data['date'] <= pd.to_datetime(config.TRAIN_END_DATE))
  25. val_mask = (data['date'] >= adjusted_val_start) & \
  26. (data['date'] <= pd.to_datetime(config.VAL_END_DATE))
  27. test_mask = (data['date'] >= adjusted_test_start) & \
  28. (data['date'] <= pd.to_datetime(config.TEST_END_DATE))
  29. train_data = data[train_mask].reset_index(drop=True).drop(columns=['date'])
  30. val_data = data[val_mask].reset_index(drop=True).drop(columns=['date'])
  31. test_data = data[test_mask].reset_index(drop=True).drop(columns=['date'])
  32. train_supervised = DataPreprocessor.create_supervised_dataset(train_data, 1)
  33. val_supervised = DataPreprocessor.create_supervised_dataset(val_data, 1)
  34. test_supervised = DataPreprocessor.create_supervised_dataset(test_data, config.STEP_SIZE)
  35. train_loader = DataPreprocessor.load_data(train_supervised, shuffle=True)
  36. val_loader = DataPreprocessor.load_data(val_supervised, shuffle=False)
  37. test_loader = DataPreprocessor.load_data(test_supervised, shuffle=False)
  38. return train_loader, val_loader, test_loader, data
  39. @staticmethod
  40. def read_and_combine_csv_files():
  41. def read_file(file_count):
  42. file_name = config.FILE_PATTERN.format(file_count)
  43. file_path = os.path.join(config.DATA_DIR, file_name)
  44. try:
  45. df = pd.read_csv(file_path)
  46. return df[config.REQUIRED_COLUMNS]
  47. except KeyError as e:
  48. print(f"文件 {file_name} 中缺少列: {e}")
  49. raise
  50. file_indices = list(range(config.START_FILES, config.END_FILES + 1))
  51. with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
  52. results = list(tqdm(executor.map(read_file, file_indices),
  53. total=len(file_indices), desc="正在读取文件"))
  54. all_data = pd.concat(results, ignore_index=True)
  55. all_data = all_data[config.REQUIRED_COLUMNS]
  56. chunk = all_data.iloc[::config.RESOLUTION, :].reset_index(drop=True)
  57. chunk = DataPreprocessor.process_date(chunk)
  58. chunk = DataPreprocessor.scaler_data(chunk)
  59. return chunk
  60. @staticmethod
  61. def process_date(data):
  62. data = data.rename(columns={'index': 'date'})
  63. data['date'] = pd.to_datetime(data['date'])
  64. minute_of_day = data['date'].dt.hour * 60 + data['date'].dt.minute
  65. day_of_year = data['date'].dt.dayofyear
  66. time_features = ['minute_sin', 'minute_cos', 'day_year_sin', 'day_year_cos']
  67. data['minute_sin'] = np.sin(2 * np.pi * minute_of_day / 1440)
  68. data['minute_cos'] = np.cos(2 * np.pi * minute_of_day / 1440)
  69. data['day_year_sin'] = np.sin(2 * np.pi * day_of_year / 366)
  70. data['day_year_cos'] = np.cos(2 * np.pi * day_of_year / 366)
  71. other_columns = [col for col in data.columns if col not in ['date'] + time_features]
  72. return data[['date'] + time_features + other_columns]
  73. @staticmethod
  74. def scaler_data(data):
  75. date_col = data[['date']]
  76. data_to_scale = data.drop(columns=['date'])
  77. scaler = MinMaxScaler(feature_range=(0, 1))
  78. scaled_data = scaler.fit_transform(data_to_scale)
  79. joblib.dump(scaler, config.SCALER_PATH)
  80. scaled_data = pd.DataFrame(scaled_data, columns=data_to_scale.columns)
  81. return pd.concat([date_col.reset_index(drop=True), scaled_data], axis=1)
  82. @staticmethod
  83. def create_supervised_dataset(data, step_size):
  84. data = pd.DataFrame(data)
  85. cols, col_names = [], []
  86. feature_columns = data.columns.tolist()
  87. for col in feature_columns:
  88. for i in range(config.SEQ_LEN - 1, -1, -1):
  89. cols.append(data[[col]].shift(i))
  90. col_names.append(f"{col}(t-{i})")
  91. target_columns = feature_columns[-config.LABELS_NUM:]
  92. for i in range(1, config.OUTPUT_SIZE + 1):
  93. for col in target_columns:
  94. cols.append(data[[col]].shift(-i))
  95. col_names.append(f"{col}(t+{i})")
  96. dataset = pd.concat(cols, axis=1)
  97. dataset.columns = col_names
  98. dataset = dataset.iloc[::step_size, :]
  99. dataset.dropna(inplace=True)
  100. return dataset
  101. @staticmethod
  102. def load_data(dataset, shuffle):
  103. n_features_total = config.FEATURE_NUM * config.SEQ_LEN
  104. n_labels_total = config.OUTPUT_SIZE * config.LABELS_NUM
  105. X = dataset.values[:, :n_features_total]
  106. y = dataset.values[:, n_features_total:n_features_total + n_labels_total]
  107. X = X.reshape(X.shape[0], config.SEQ_LEN, config.FEATURE_NUM)
  108. device = torch.device(f"cuda:{config.DEVICE_ID}" if torch.cuda.is_available() else "cpu")
  109. X = torch.tensor(X, dtype=torch.float32).to(device)
  110. y = torch.tensor(y, dtype=torch.float32).to(device)
  111. dataset_tensor = TensorDataset(X, y)
  112. generator = torch.Generator()
  113. generator.manual_seed(config.RANDOM_SEED)
  114. return DataLoader(dataset_tensor, batch_size=config.BATCH_SIZE, shuffle=shuffle, generator=generator)