data_trainer.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
# data_trainer.py
import copy
import os
from datetime import datetime, timedelta

import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from sklearn.preprocessing import MinMaxScaler
  9. class Trainer:
  10. def __init__(self, model, args, data):
  11. """
  12. 初始化训练器
  13. :param model: 待训练的模型
  14. :param args: 配置参数(训练轮数、早停参数等)
  15. :param data: 原始数据(用于生成评估时的时间戳)
  16. """
  17. self.args = args
  18. self.model = model
  19. self.data = data
  20. # 早停相关参数
  21. self.patience = args.patience
  22. self.min_delta = args.min_delta
  23. self.counter = 0
  24. self.early_stop = False
  25. self.best_val_loss = float('inf')
  26. self.best_model_state = None
  27. self.best_epoch = 0
  28. def train_full_model(self, train_loader, val_loader, optimizer, criterion, scheduler):
  29. """
  30. 联合训练所有子模型(端到端训练)
  31. :param train_loader: 训练数据加载器
  32. :param val_loader: 验证数据加载器
  33. :param optimizer: 优化器(如Adam)
  34. :param criterion: 损失函数(如MSE)
  35. :param scheduler: 学习率调度器
  36. :return: 训练好的模型(加载最佳权重)
  37. """
  38. self.counter = 0
  39. self.best_val_loss = float('inf')
  40. self.early_stop = False
  41. self.best_model_state = None
  42. self.best_epoch = 0
  43. max_epochs = self.args.epochs
  44. for epoch in range(max_epochs):
  45. self.model.train()
  46. running_loss = 0.0
  47. for inputs, targets in train_loader:
  48. inputs = inputs.to(self.args.device)
  49. targets = targets.to(self.args.device) # 整体目标值(包含所有16个因变量)
  50. optimizer.zero_grad()
  51. outputs = self.model(inputs) # 整体模型输出
  52. loss = criterion(outputs, targets) # 计算整体损失
  53. loss.backward()
  54. optimizer.step()
  55. running_loss += loss.item()
  56. train_loss = running_loss / len(train_loader)
  57. val_loss = self.validate_full(val_loader, criterion) if val_loader else 0.0
  58. print(f'Epoch {epoch+1}/{max_epochs}, '
  59. f'Train Loss: {train_loss:.6f}, '
  60. f'Val Loss: {val_loss:.6f}, '
  61. f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
  62. # 早停逻辑(基于整体验证损失)
  63. if val_loader:
  64. improved = val_loss < (self.best_val_loss - self.min_delta)
  65. if improved:
  66. self.best_val_loss = val_loss
  67. self.counter = 0
  68. self.best_model_state = self.model.state_dict()
  69. self.best_epoch = epoch
  70. else:
  71. self.counter += 1
  72. if self.counter >= self.patience:
  73. self.early_stop = True
  74. print(f"早停触发")
  75. scheduler.step()
  76. torch.cuda.empty_cache()
  77. if self.early_stop:
  78. break
  79. # 加载最佳状态
  80. if self.best_model_state is not None:
  81. self.model.load_state_dict(self.best_model_state)
  82. print(f"最佳迭代: {self.best_epoch+1}, 最佳验证损失: {self.best_val_loss:.6f}")
  83. return self.model
  84. def validate_full(self, val_loader, criterion):
  85. """
  86. 验证整个模型(计算验证损失)
  87. :param val_loader: 验证数据加载器
  88. :param criterion: 损失函数
  89. :return: 平均验证损失
  90. """
  91. self.model.eval()
  92. total_loss = 0.0
  93. with torch.no_grad():
  94. for inputs, targets in val_loader:
  95. inputs = inputs.to(self.args.device)
  96. targets = targets.to(self.args.device) # 整体目标值
  97. outputs = self.model(inputs) # 整体模型输出
  98. loss = criterion(outputs, targets) # 整体损失计算
  99. total_loss += loss.item()
  100. return total_loss / len(val_loader)
  101. def save_model(self):
  102. """保存模型权重到指定路径"""
  103. torch.save(self.model.state_dict(), self.args.model_path)
  104. print(f"模型已保存到:{self.args.model_path}")
  105. def evaluate_model(self, test_loader, criterion):
  106. """
  107. 评估模型在测试集上的性能,计算R²、RMSE、MAPE等指标并保存结果
  108. :param test_loader: 测试数据加载器
  109. :param criterion: 损失函数
  110. :return: 各指标字典(R²、RMSE、MAPE)
  111. """
  112. self.model.eval()
  113. scaler_path = 'scaler.pkl'
  114. scaler = joblib.load(scaler_path)
  115. predictions = []
  116. true_values = []
  117. device = self.args.device
  118. with torch.no_grad():
  119. for inputs, targets in test_loader:
  120. inputs = inputs.to(device)
  121. targets = targets.to(device)
  122. outputs = self.model(inputs)
  123. predictions.append(outputs.cpu().numpy())
  124. true_values.append(targets.cpu().numpy())
  125. predictions = np.concatenate(predictions, axis=0)
  126. true_values = np.concatenate(true_values, axis=0)
  127. # 重塑预测值和真实值形状以匹配反归一化要求
  128. reshaped_predictions = predictions.reshape(predictions.shape[0],
  129. self.args.output_size,
  130. self.args.labels_num)
  131. predictions = reshaped_predictions.reshape(-1, self.args.labels_num)
  132. reshaped_true_values = true_values.reshape(true_values.shape[0],
  133. self.args.output_size,
  134. self.args.labels_num)
  135. true_values = reshaped_true_values.reshape(-1, self.args.labels_num)
  136. # 反归一化(仅对标签列)
  137. column_scaler = MinMaxScaler(feature_range=(0, 1))
  138. column_scaler.min_ = scaler.min_[-self.args.labels_num:]
  139. column_scaler.scale_ = scaler.scale_[-self.args.labels_num:]
  140. true_values = column_scaler.inverse_transform(true_values)
  141. predictions = column_scaler.inverse_transform(predictions)
  142. # 定义列名(16个因变量)
  143. column_names = [
  144. 'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
  145. 'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2',
  146. ]
  147. # 生成时间序列
  148. start_datetime = datetime.strptime(self.args.test_start_date, "%Y-%m-%d")
  149. time_interval = timedelta(minutes=(4 * self.args.resolution / 60))
  150. total_points = len(predictions)
  151. date_times = [start_datetime + i * time_interval for i in range(total_points)]
  152. # 保存结果到DataFrame
  153. results = pd.DataFrame({'date': date_times})
  154. # 计算评估指标
  155. r2_scores = {}
  156. rmse_scores = {}
  157. mape_scores = {}
  158. metrics_details = []
  159. for i, col_name in enumerate(column_names):
  160. results[f'{col_name}_True'] = true_values[:, i]
  161. results[f'{col_name}_Predicted'] = predictions[:, i]
  162. var_true = true_values[:, i]
  163. var_pred = predictions[:, i]
  164. # 过滤零值(避免除零错误)
  165. non_zero_mask = var_true != 0
  166. var_true_nonzero = var_true[non_zero_mask]
  167. var_pred_nonzero = var_pred[non_zero_mask]
  168. r2 = float('nan')
  169. rmse = float('nan')
  170. mape = float('nan')
  171. if len(var_true_nonzero) > 0:
  172. r2 = r2_score(var_true_nonzero, var_pred_nonzero)
  173. rmse = np.sqrt(np.mean((var_true_nonzero - var_pred_nonzero) ** 2))
  174. mape = np.mean(np.abs((var_true_nonzero - var_pred_nonzero) / np.abs(var_true_nonzero))) * 100
  175. r2_scores[col_name] = r2
  176. rmse_scores[col_name] = rmse
  177. mape_scores[col_name] = mape
  178. detail = f"{col_name}:\n R方 = {r2:.6f}\n RMSE = {rmse:.6f}\n MAPE = {mape:.6f}%"
  179. metrics_details.append(detail)
  180. print(f"{col_name} R方: {r2:.6f}")
  181. else:
  182. metrics_details.append(f"{col_name}: 没有有效数据用于计算指标")
  183. print(f"{col_name} 没有有效数据用于计算R方")
  184. # 计算平均指标
  185. valid_r2 = [score for score in r2_scores.values() if not np.isnan(score)]
  186. valid_rmse = [score for score in rmse_scores.values() if not np.isnan(score)]
  187. valid_mape = [score for score in mape_scores.values() if not np.isnan(score)]
  188. avg_r2 = np.mean(valid_r2) if valid_r2 else float('nan')
  189. avg_rmse = np.mean(valid_rmse) if valid_rmse else float('nan')
  190. avg_mape = np.mean(valid_mape) if valid_mape else float('nan')
  191. avg_detail = f"\n平均指标:\n R方 = {avg_r2:.6f}\n RMSE = {avg_rmse:.6f}\n MAPE = {avg_mape:.6f}%"
  192. if np.isnan(avg_r2):
  193. avg_detail = "\n平均指标: 没有有效的指标可用于计算平均值"
  194. metrics_details.append(avg_detail)
  195. print(avg_detail)
  196. # 保存结果
  197. results.to_csv(self.args.output_csv_path, index=False)
  198. print(f"预测结果已保存到:{self.args.output_csv_path}")
  199. txt_path = self.args.output_csv_path.replace('.csv', '_metrics_results.txt')
  200. with open(txt_path, 'w') as f:
  201. f.write("各变量预测指标结果:\n")
  202. f.write("===================\n\n")
  203. for detail in metrics_details:
  204. f.write(detail + '\n')
  205. print(f"预测指标结果已保存到:{txt_path}")
  206. return r2_scores, rmse_scores, mape_scores