# data_trainer.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from tqdm import tqdm
  7. import os
  8. from sklearn.metrics import mean_absolute_error, mean_squared_error
  9. import logging
  10. class DataTrainer:
  11. def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
  12. self.model = model
  13. self.args = args
  14. self.preprocessor = preprocessor
  15. self.device = args.device
  16. self.logger = logger if logger is not None else self._default_logger()
  17. self.criterion = nn.MSELoss()
  18. self.optimizer = optimizer if optimizer is not None else optim.Adam(
  19. model.parameters(),
  20. lr=args.lr,
  21. weight_decay=args.weight_decay
  22. )
  23. self.scheduler = scheduler
  24. self.train_losses = []
  25. self.val_losses = []
  26. self.train_mae = []
  27. self.val_mae = []
  28. self.best_val_loss = float('inf')
  29. self.early_stop_counter = 0
  30. self.model_save_path = os.path.join('models', 'best_model.pth')
  31. self.final_model_path = os.path.join('models', 'final_model.pth')
  32. if not os.path.exists('models'):
  33. os.makedirs('models')
  34. if not os.path.exists('plots'):
  35. os.makedirs('plots')
  36. def _default_logger(self):
  37. logger = logging.getLogger('DefaultLogger')
  38. logger.setLevel(logging.INFO)
  39. console_handler = logging.StreamHandler()
  40. console_handler.setLevel(logging.INFO)
  41. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  42. console_handler.setFormatter(formatter)
  43. if not logger.handlers:
  44. logger.addHandler(console_handler)
  45. return logger
  46. @torch.enable_grad()
  47. def train_epoch(self, train_loader, adj, max_batches=None):
  48. """
  49. 支持通过 max_batches 限制本 epoch 使用的批次数(用于 RL 快速评估)。
  50. """
  51. self.model.train()
  52. total_loss = 0.0
  53. all_outputs = []
  54. all_targets = []
  55. adj = adj.to(self.device)
  56. for batch_idx, (data, target) in enumerate(train_loader):
  57. if max_batches is not None and batch_idx >= max_batches:
  58. break
  59. data, target = data.to(self.device), target.to(self.device)
  60. data = data.unsqueeze(-1) # (batch_size, 145, 1)
  61. self.optimizer.zero_grad()
  62. output = self.model(data, adj) # (batch_size, 145, 47)
  63. output = output.mean(dim=1) # (batch_size, 47)
  64. loss = self.criterion(output, target)
  65. loss.backward()
  66. if self.args.grad_clip > 0:
  67. torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
  68. self.optimizer.step()
  69. total_loss += loss.item()
  70. all_outputs.append(output.detach().cpu().numpy())
  71. all_targets.append(target.detach().cpu().numpy())
  72. if len(all_outputs) == 0:
  73. # 在极小数据/极小 batch 情况下的保护
  74. return float('inf'), float('inf')
  75. all_outputs = np.vstack(all_outputs)
  76. all_targets = np.vstack(all_targets)
  77. mae = mean_absolute_error(all_targets, all_outputs)
  78. avg_loss = total_loss / (min(len(train_loader), max_batches) if max_batches else len(train_loader))
  79. self.train_losses.append(avg_loss)
  80. self.train_mae.append(mae)
  81. return avg_loss, mae
  82. @torch.no_grad()
  83. def validate(self, val_loader, adj, max_batches=None):
  84. """
  85. 支持通过 max_batches 限制验证批次数(用于 RL 快速评估)。
  86. """
  87. self.model.eval()
  88. total_loss = 0.0
  89. all_outputs = []
  90. all_targets = []
  91. adj = adj.to(self.device)
  92. for batch_idx, (data, target) in enumerate(val_loader):
  93. if max_batches is not None and batch_idx >= max_batches:
  94. break
  95. data, target = data.to(self.device), target.to(self.device)
  96. data = data.unsqueeze(-1) # (batch_size, 145, 1)
  97. output = self.model(data, adj) # (batch_size, 145, 47)
  98. output = output.mean(dim=1) # (batch_size, 47)
  99. loss = self.criterion(output, target)
  100. total_loss += loss.item()
  101. all_outputs.append(output.cpu().numpy())
  102. all_targets.append(target.cpu().numpy())
  103. if len(all_outputs) == 0:
  104. return float('inf'), float('inf')
  105. all_outputs = np.vstack(all_outputs)
  106. all_targets = np.vstack(all_targets)
  107. mae = mean_absolute_error(all_targets, all_outputs)
  108. avg_loss = total_loss / (min(len(val_loader), max_batches) if max_batches else len(val_loader))
  109. self.val_losses.append(avg_loss)
  110. self.val_mae.append(mae)
  111. return avg_loss, mae
  112. def train(self, train_loader, val_loader, adj, epochs=None):
  113. adj = adj.to(self.device)
  114. epochs = epochs if epochs is not None else self.args.epochs
  115. self.logger.info(f"开始训练,共 {epochs} 个epoch")
  116. for epoch in tqdm(range(epochs)):
  117. train_loss, train_mae = self.train_epoch(train_loader, adj)
  118. val_loss, val_mae = self.validate(val_loader, adj)
  119. if self.scheduler is not None:
  120. self.scheduler.step(val_loss)
  121. # 保存最佳模型
  122. if val_loss < self.best_val_loss:
  123. self.best_val_loss = val_loss
  124. torch.save({
  125. 'epoch': epoch,
  126. 'model_state_dict': self.model.state_dict(),
  127. 'optimizer_state_dict': self.optimizer.state_dict(),
  128. 'loss': val_loss,
  129. 'args': self.args # 保存训练参数
  130. }, self.model_save_path)
  131. self.early_stop_counter = 0
  132. self.logger.info(f"已更新最佳模型到 {self.model_save_path}") # 日志
  133. else:
  134. self.early_stop_counter += 1
  135. if self.early_stop_counter >= self.args.patience:
  136. self.logger.info(f"早停机制触发,在第 {epoch+1} 轮停止训练")
  137. break
  138. if (epoch + 1) % 10 == 0:
  139. self.logger.info(f'Epoch {epoch+1}/{epochs}, '
  140. f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
  141. f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')
  142. self.plot_losses()
  143. self.plot_mae()
  144. # 加载最佳模型
  145. checkpoint = torch.load(self.model_save_path, map_location=self.device)
  146. self.model.load_state_dict(checkpoint['model_state_dict'])
  147. self.logger.info(f"加载最佳模型(第 {checkpoint['epoch']+1} 轮,验证损失: {checkpoint['loss']:.6f})")
  148. # 保存最终训练完成的模型(加载最佳模型后)
  149. torch.save({
  150. 'model_state_dict': self.model.state_dict(),
  151. 'optimizer_state_dict': self.optimizer.state_dict(),
  152. 'best_val_loss': self.best_val_loss,
  153. 'args': self.args
  154. }, self.final_model_path)
  155. self.logger.info(f"已保存最终模型到 {self.final_model_path}")
  156. return self.model
  157. @torch.no_grad()
  158. def test(self, test_loader, adj):
  159. self.model.eval()
  160. total_loss = 0.0
  161. all_outputs = []
  162. all_targets = []
  163. adj = adj.to(self.device)
  164. for data, target in test_loader:
  165. data, target = data.to(self.device), target.to(self.device)
  166. data = data.unsqueeze(-1)
  167. output = self.model(data, adj)
  168. output = output.mean(dim=1)
  169. loss = self.criterion(output, target)
  170. total_loss += loss.item()
  171. all_outputs.append(output.cpu().numpy())
  172. all_targets.append(target.cpu().numpy())
  173. all_outputs = np.vstack(all_outputs)
  174. all_targets = np.vstack(all_targets)
  175. mse = total_loss / len(test_loader)
  176. mae = mean_absolute_error(all_targets, all_outputs)
  177. rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))
  178. all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
  179. all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
  180. original_mse = mean_squared_error(all_targets_original, all_outputs_original)
  181. original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
  182. original_rmse = np.sqrt(original_mse)
  183. self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
  184. self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
  185. self.plot_predictions(all_outputs_original, all_targets_original)
  186. return {
  187. 'normalized_mse': mse,
  188. 'normalized_mae': mae,
  189. 'normalized_rmse': rmse,
  190. 'original_mse': original_mse,
  191. 'original_mae': original_mae,
  192. 'original_rmse': original_rmse,
  193. 'predictions': all_outputs_original,
  194. 'targets': all_targets_original
  195. }
  196. def plot_losses(self):
  197. plt.figure(figsize=(10, 6))
  198. plt.plot(self.train_losses, label='Train Loss')
  199. plt.plot(self.val_losses, label='Validation Loss')
  200. plt.xlabel('Epoch')
  201. plt.ylabel('MSE Loss')
  202. plt.title('Training and Validation Loss')
  203. plt.legend()
  204. plt.savefig('plots/loss_curve.png')
  205. plt.close()
  206. def plot_mae(self):
  207. plt.figure(figsize=(10, 6))
  208. plt.plot(self.train_mae, label='Train MAE')
  209. plt.plot(self.val_mae, label='Validation MAE')
  210. plt.xlabel('Epoch')
  211. plt.ylabel('MAE')
  212. plt.title('Training and Validation MAE')
  213. plt.legend()
  214. plt.savefig('plots/mae_curve.png')
  215. plt.close()
  216. def plot_predictions(self, predictions, targets):
  217. num_plots = min(3, self.args.num_targets)
  218. plt.figure(figsize=(15, 5*num_plots))
  219. for i in range(num_plots):
  220. plt.subplot(num_plots, 1, i+1)
  221. plt.plot(targets[:100, i], label='True Value')
  222. plt.plot(predictions[:100, i], label='Predicted Value')
  223. plt.xlabel('Time Step')
  224. plt.ylabel(f'Target {i+1}')
  225. plt.title(f'Prediction vs True Value for Target {i+1}')
  226. plt.legend()
  227. plt.tight_layout()
  228. plt.savefig('plots/prediction_examples.png')
  229. plt.close()