# data_trainer.py
from datetime import datetime, timedelta

import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from sklearn.preprocessing import MinMaxScaler
  9. class Trainer:
  10. def __init__(self, model, args, data):
  11. """
  12. 模型训练器类,负责模型训练、验证、保存和评估
  13. 参数:
  14. model: 待训练的模型实例
  15. args: 配置参数
  16. data: 原始数据(用于生成评估的时间戳)
  17. """
  18. self.args = args
  19. self.model = model
  20. self.data = data
  21. # 早停相关参数
  22. self.patience = args.patience
  23. self.min_delta = args.min_delta
  24. self.counter = 0
  25. self.early_stop = False
  26. self.best_val_loss = float('inf')
  27. self.best_model_state = None
  28. self.best_epoch = 0
  29. def train_full_model(self, train_loader, val_loader, optimizer, criterion, scheduler):
  30. """
  31. 联合训练所有8/16个子模型(端到端训练)
  32. 参数:
  33. train_loader: 训练集数据加载器
  34. val_loader: 验证集数据加载器
  35. optimizer: 优化器(如Adam)
  36. criterion: 损失函数(如MSE)
  37. scheduler: 学习率调度器
  38. 返回:
  39. 训练好的模型(加载最佳权重)
  40. """
  41. self.counter = 0
  42. self.best_val_loss = float('inf')
  43. self.early_stop = False
  44. self.best_model_state = None
  45. self.best_epoch = 0
  46. max_epochs = self.args.epochs
  47. for epoch in range(max_epochs):
  48. self.model.train()
  49. running_loss = 0.0
  50. for inputs, targets in train_loader:
  51. inputs = inputs.to(self.args.device)
  52. targets = targets.to(self.args.device) # 整体目标值(包含所有8/16个因变量)
  53. optimizer.zero_grad()
  54. outputs = self.model(inputs) # 整体模型输出
  55. loss = criterion(outputs, targets) # 计算整体损失
  56. loss.backward()
  57. optimizer.step()
  58. running_loss += loss.item()
  59. train_loss = running_loss / len(train_loader)
  60. val_loss = self.validate_full(val_loader, criterion) if val_loader else 0.0
  61. print(f'Epoch {epoch+1}/{max_epochs}, '
  62. f'Train Loss: {train_loss:.6f}, '
  63. f'Val Loss: {val_loss:.6f}, '
  64. f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
  65. # 早停逻辑(基于整体验证损失)
  66. if val_loader:
  67. improved = val_loss < (self.best_val_loss - self.min_delta)
  68. if improved:
  69. self.best_val_loss = val_loss
  70. self.counter = 0
  71. self.best_model_state = self.model.state_dict()
  72. self.best_epoch = epoch
  73. else:
  74. self.counter += 1
  75. if self.counter >= self.patience:
  76. self.early_stop = True
  77. print(f"早停触发")
  78. scheduler.step()
  79. torch.cuda.empty_cache()
  80. if self.early_stop:
  81. break
  82. # 加载最佳状态
  83. if self.best_model_state is not None:
  84. self.model.load_state_dict(self.best_model_state)
  85. print(f"最佳迭代: {self.best_epoch+1}, 最佳验证损失: {self.best_val_loss:.6f}")
  86. return self.model
  87. def validate_full(self, val_loader, criterion):
  88. """
  89. 验证整个模型(计算验证集损失)
  90. 参数:
  91. val_loader: 验证集数据加载器
  92. criterion: 损失函数
  93. 返回:
  94. 平均验证损失
  95. """
  96. self.model.eval()
  97. total_loss = 0.0
  98. with torch.no_grad():
  99. for inputs, targets in val_loader:
  100. inputs = inputs.to(self.args.device)
  101. targets = targets.to(self.args.device) # 整体目标值
  102. outputs = self.model(inputs) # 整体模型输出
  103. loss = criterion(outputs, targets) # 整体损失计算
  104. total_loss += loss.item()
  105. return total_loss / len(val_loader)
  106. def save_model(self):
  107. """保存模型最佳权重到指定路径"""
  108. torch.save(self.model.state_dict(), self.args.model_path)
  109. print(f"模型已保存到:{self.args.model_path}")
  110. def evaluate_model(self, test_loader, criterion):
  111. """
  112. 评估模型在测试集上的性能,计算R方、RMSE、MAPE等指标,并保存结果
  113. 参数:
  114. test_loader: 测试集数据加载器
  115. criterion: 损失函数(用于计算测试损失)
  116. 返回:
  117. 各指标的字典(R方、RMSE、MAPE)
  118. """
  119. self.model.eval()
  120. scaler_path = self.args.scaler_path
  121. scaler = joblib.load(scaler_path)
  122. predictions = []
  123. true_values = []
  124. device = self.args.device
  125. with torch.no_grad():
  126. for inputs, targets in test_loader:
  127. inputs = inputs.to(device)
  128. targets = targets.to(device)
  129. outputs = self.model(inputs)
  130. predictions.append(outputs.cpu().numpy())
  131. true_values.append(targets.cpu().numpy())
  132. predictions = np.concatenate(predictions, axis=0)
  133. true_values = np.concatenate(true_values, axis=0)
  134. # 重塑预测值和真实值形状以匹配反归一化要求
  135. reshaped_predictions = predictions.reshape(predictions.shape[0],
  136. self.args.output_size,
  137. self.args.labels_num)
  138. predictions = reshaped_predictions.reshape(-1, self.args.labels_num)
  139. reshaped_true_values = true_values.reshape(true_values.shape[0],
  140. self.args.output_size,
  141. self.args.labels_num)
  142. true_values = reshaped_true_values.reshape(-1, self.args.labels_num)
  143. # 反归一化(仅对标签列)
  144. column_scaler = MinMaxScaler(feature_range=(0, 1))
  145. column_scaler.min_ = scaler.min_[-self.args.labels_num:]
  146. column_scaler.scale_ = scaler.scale_[-self.args.labels_num:]
  147. true_values = column_scaler.inverse_transform(true_values)
  148. predictions = column_scaler.inverse_transform(predictions)
  149. # 定义列名(8/16个因变量)
  150. full_column_names = [
  151. 'C.M.UF1_DB@press_PV', 'C.M.UF2_DB@press_PV', 'C.M.UF3_DB@press_PV', 'C.M.UF4_DB@press_PV',
  152. 'UF1Per','UF2Per','UF3Per','UF4Per',
  153. 'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
  154. 'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2'
  155. ]
  156. # 根据labels_num选择对应的列名子集
  157. if self.args.labels_num == 16:
  158. column_names = full_column_names
  159. elif self.args.labels_num == 8:
  160. # 取后8个RO相关的列名
  161. # column_names = full_column_names[-8:]
  162. column_names = full_column_names[:8]
  163. else:
  164. # 处理不支持的labels_num值(可选,根据需求调整)
  165. raise ValueError(f"不支持的labels_num值: {self.args.labels_num},仅支持16或8")
  166. # 生成时间序列
  167. start_datetime = datetime.strptime(self.args.test_start_date, "%Y-%m-%d")
  168. time_interval = timedelta(minutes=(4 * self.args.resolution / 60))
  169. total_points = len(predictions)
  170. date_times = [start_datetime + i * time_interval for i in range(total_points)]
  171. # 保存结果到DataFrame
  172. results = pd.DataFrame({'date': date_times})
  173. # 计算评估指标
  174. r2_scores = {}
  175. rmse_scores = {}
  176. mape_scores = {}
  177. metrics_details = []
  178. for i, col_name in enumerate(column_names):
  179. results[f'{col_name}_True'] = true_values[:, i]
  180. results[f'{col_name}_Predicted'] = predictions[:, i]
  181. var_true = true_values[:, i]
  182. var_pred = predictions[:, i]
  183. # 过滤零值(避免除零错误)
  184. non_zero_mask = var_true != 0
  185. var_true_nonzero = var_true[non_zero_mask]
  186. var_pred_nonzero = var_pred[non_zero_mask]
  187. r2 = float('nan')
  188. rmse = float('nan')
  189. mape = float('nan')
  190. if len(var_true_nonzero) > 0:
  191. r2 = r2_score(var_true_nonzero, var_pred_nonzero)
  192. rmse = np.sqrt(np.mean((var_true_nonzero - var_pred_nonzero) ** 2))
  193. mape = np.mean(np.abs((var_true_nonzero - var_pred_nonzero) / np.abs(var_true_nonzero))) * 100
  194. r2_scores[col_name] = r2
  195. rmse_scores[col_name] = rmse
  196. mape_scores[col_name] = mape
  197. detail = f"{col_name}:\n R方 = {r2:.6f}\n RMSE = {rmse:.6f}\n MAPE = {mape:.6f}%"
  198. metrics_details.append(detail)
  199. print(f"{col_name} R方: {r2:.6f}")
  200. else:
  201. metrics_details.append(f"{col_name}: 没有有效数据用于计算指标")
  202. print(f"{col_name} 没有有效数据用于计算R方")
  203. # 计算平均指标
  204. valid_r2 = [score for score in r2_scores.values() if not np.isnan(score)]
  205. valid_rmse = [score for score in rmse_scores.values() if not np.isnan(score)]
  206. valid_mape = [score for score in mape_scores.values() if not np.isnan(score)]
  207. avg_r2 = np.mean(valid_r2) if valid_r2 else float('nan')
  208. avg_rmse = np.mean(valid_rmse) if valid_rmse else float('nan')
  209. avg_mape = np.mean(valid_mape) if valid_mape else float('nan')
  210. avg_detail = f"\n平均指标:\n R方 = {avg_r2:.6f}\n RMSE = {avg_rmse:.6f}\n MAPE = {avg_mape:.6f}%"
  211. if np.isnan(avg_r2):
  212. avg_detail = "\n平均指标: 没有有效的指标可用于计算平均值"
  213. metrics_details.append(avg_detail)
  214. print(avg_detail)
  215. # 保存结果
  216. results.to_csv(self.args.output_csv_path, index=False)
  217. print(f"预测结果已保存到:{self.args.output_csv_path}")
  218. txt_path = self.args.output_csv_path.replace('.csv', '_metrics_results.txt')
  219. with open(txt_path, 'w') as f:
  220. f.write("各变量预测指标结果:\n")
  221. f.write("===================\n\n")
  222. for detail in metrics_details:
  223. f.write(detail + '\n')
  224. print(f"预测指标结果已保存到:{txt_path}")
  225. return r2_scores, rmse_scores, mape_scores