# data_trainer.py
import copy
from datetime import datetime, timedelta

import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from sklearn.preprocessing import MinMaxScaler
  9. class Trainer:
  10. def __init__(self, model, args, data):
  11. self.args = args
  12. self.model = model
  13. self.data = data
  14. # 早停相关参数
  15. self.patience = args.patience
  16. self.min_delta = args.min_delta
  17. self.counter = 0
  18. self.early_stop = False
  19. self.best_val_loss = float('inf')
  20. self.best_model_state = None
  21. self.best_epoch = 0
  22. def train_full_model(self, train_loader, val_loader, optimizer, criterion, scheduler):
  23. """联合训练所有16个子模型"""
  24. self.counter = 0
  25. self.best_val_loss = float('inf')
  26. self.early_stop = False
  27. self.best_model_state = None
  28. self.best_epoch = 0
  29. max_epochs = self.args.epochs
  30. for epoch in range(max_epochs):
  31. self.model.train()
  32. running_loss = 0.0
  33. for inputs, targets in train_loader:
  34. inputs = inputs.to(self.args.device)
  35. targets = targets.to(self.args.device) # 整体目标值(包含所有16个因变量)
  36. optimizer.zero_grad()
  37. outputs = self.model(inputs) # 整体模型输出
  38. loss = criterion(outputs, targets) # 计算整体损失
  39. loss.backward()
  40. optimizer.step()
  41. running_loss += loss.item()
  42. train_loss = running_loss / len(train_loader)
  43. val_loss = self.validate_full(val_loader, criterion) if val_loader else 0.0
  44. print(f'Epoch {epoch+1}/{max_epochs}, '
  45. f'Train Loss: {train_loss:.6f}, '
  46. f'Val Loss: {val_loss:.6f}, '
  47. f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
  48. # 早停逻辑(基于整体验证损失)
  49. if val_loader:
  50. improved = val_loss < (self.best_val_loss - self.min_delta)
  51. if improved:
  52. self.best_val_loss = val_loss
  53. self.counter = 0
  54. self.best_model_state = self.model.state_dict()
  55. self.best_epoch = epoch
  56. else:
  57. self.counter += 1
  58. if self.counter >= self.patience:
  59. self.early_stop = True
  60. print(f"早停触发")
  61. scheduler.step()
  62. torch.cuda.empty_cache()
  63. if self.early_stop:
  64. break
  65. # 加载最佳状态
  66. if self.best_model_state is not None:
  67. self.model.load_state_dict(self.best_model_state)
  68. print(f"最佳迭代: {self.best_epoch+1}, 最佳验证损失: {self.best_val_loss:.6f}")
  69. return self.model
  70. def validate_full(self, val_loader, criterion):
  71. """验证整个模型(包含所有16个子模型)"""
  72. self.model.eval()
  73. total_loss = 0.0
  74. with torch.no_grad():
  75. for inputs, targets in val_loader:
  76. inputs = inputs.to(self.args.device)
  77. targets = targets.to(self.args.device) # 整体目标值
  78. outputs = self.model(inputs) # 整体模型输出
  79. loss = criterion(outputs, targets) # 整体损失计算
  80. total_loss += loss.item()
  81. return total_loss / len(val_loader)
  82. def save_model(self):
  83. torch.save(self.model.state_dict(), self.args.model_path)
  84. print(f"模型已保存到:{self.args.model_path}")
  85. def evaluate_model(self, test_loader, criterion):
  86. self.model.eval()
  87. scaler_path = 'scaler.pkl'
  88. scaler = joblib.load(scaler_path)
  89. predictions = []
  90. true_values = []
  91. device = self.args.device
  92. with torch.no_grad():
  93. for inputs, targets in test_loader:
  94. inputs = inputs.to(device)
  95. targets = targets.to(device)
  96. outputs = self.model(inputs)
  97. predictions.append(outputs.cpu().numpy())
  98. true_values.append(targets.cpu().numpy())
  99. predictions = np.concatenate(predictions, axis=0)
  100. true_values = np.concatenate(true_values, axis=0)
  101. # 重塑预测值和真实值形状以匹配反归一化要求
  102. reshaped_predictions = predictions.reshape(predictions.shape[0],
  103. self.args.output_size,
  104. self.args.labels_num)
  105. predictions = reshaped_predictions.reshape(-1, self.args.labels_num)
  106. reshaped_true_values = true_values.reshape(true_values.shape[0],
  107. self.args.output_size,
  108. self.args.labels_num)
  109. true_values = reshaped_true_values.reshape(-1, self.args.labels_num)
  110. # 反归一化(仅对标签列)
  111. column_scaler = MinMaxScaler(feature_range=(0, 1))
  112. column_scaler.min_ = scaler.min_[-self.args.labels_num:]
  113. column_scaler.scale_ = scaler.scale_[-self.args.labels_num:]
  114. true_values = column_scaler.inverse_transform(true_values)
  115. predictions = column_scaler.inverse_transform(predictions)
  116. # 定义列名(16个因变量)
  117. column_names = [
  118. 'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
  119. 'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2',
  120. ]
  121. # 生成时间序列
  122. start_datetime = datetime.strptime(self.args.test_start_date, "%Y-%m-%d")
  123. time_interval = timedelta(minutes=(4 * self.args.resolution / 60))
  124. total_points = len(predictions)
  125. date_times = [start_datetime + i * time_interval for i in range(total_points)]
  126. # 保存结果到DataFrame
  127. results = pd.DataFrame({'date': date_times})
  128. # 计算评估指标
  129. r2_scores = {}
  130. rmse_scores = {}
  131. mape_scores = {}
  132. metrics_details = []
  133. for i, col_name in enumerate(column_names):
  134. results[f'{col_name}_True'] = true_values[:, i]
  135. results[f'{col_name}_Predicted'] = predictions[:, i]
  136. var_true = true_values[:, i]
  137. var_pred = predictions[:, i]
  138. # 过滤零值(避免除零错误)
  139. non_zero_mask = var_true != 0
  140. var_true_nonzero = var_true[non_zero_mask]
  141. var_pred_nonzero = var_pred[non_zero_mask]
  142. r2 = float('nan')
  143. rmse = float('nan')
  144. mape = float('nan')
  145. if len(var_true_nonzero) > 0:
  146. r2 = r2_score(var_true_nonzero, var_pred_nonzero)
  147. rmse = np.sqrt(np.mean((var_true_nonzero - var_pred_nonzero) ** 2))
  148. mape = np.mean(np.abs((var_true_nonzero - var_pred_nonzero) / np.abs(var_true_nonzero))) * 100
  149. r2_scores[col_name] = r2
  150. rmse_scores[col_name] = rmse
  151. mape_scores[col_name] = mape
  152. detail = f"{col_name}:\n R方 = {r2:.6f}\n RMSE = {rmse:.6f}\n MAPE = {mape:.6f}%"
  153. metrics_details.append(detail)
  154. print(f"{col_name} R方: {r2:.6f}")
  155. else:
  156. metrics_details.append(f"{col_name}: 没有有效数据用于计算指标")
  157. print(f"{col_name} 没有有效数据用于计算R方")
  158. # 计算平均指标
  159. valid_r2 = [score for score in r2_scores.values() if not np.isnan(score)]
  160. valid_rmse = [score for score in rmse_scores.values() if not np.isnan(score)]
  161. valid_mape = [score for score in mape_scores.values() if not np.isnan(score)]
  162. avg_r2 = np.mean(valid_r2) if valid_r2 else float('nan')
  163. avg_rmse = np.mean(valid_rmse) if valid_rmse else float('nan')
  164. avg_mape = np.mean(valid_mape) if valid_mape else float('nan')
  165. avg_detail = f"\n平均指标:\n R方 = {avg_r2:.6f}\n RMSE = {avg_rmse:.6f}\n MAPE = {avg_mape:.6f}%"
  166. if np.isnan(avg_r2):
  167. avg_detail = "\n平均指标: 没有有效的指标可用于计算平均值"
  168. metrics_details.append(avg_detail)
  169. print(avg_detail)
  170. # 保存结果
  171. results.to_csv(self.args.output_csv_path, index=False)
  172. print(f"预测结果已保存到:{self.args.output_csv_path}")
  173. txt_path = self.args.output_csv_path.replace('.csv', '_metrics_results.txt')
  174. with open(txt_path, 'w') as f:
  175. f.write("各变量预测指标结果:\n")
  176. f.write("===================\n\n")
  177. for detail in metrics_details:
  178. f.write(detail + '\n')
  179. print(f"预测指标结果已保存到:{txt_path}")
  180. return r2_scores, rmse_scores, mape_scores