import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from sklearn.metrics import mean_absolute_error, mean_squared_error
import logging


class DataTrainer:
    """Train / validate / test harness for a graph-style regression model.

    The wrapped model is called as ``model(data, adj)`` and, per the shape
    comments in the original code, is assumed to map an input of shape
    ``(batch, 145, 1)`` to ``(batch, 145, 47)``; node outputs are mean-pooled
    over dim 1 before the MSE loss is applied.
    NOTE(review): the 145/47 shapes are taken from the original inline
    comments — confirm against the actual model and loaders.
    """

    def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
        """Set up optimizer, loss, logging, history buffers and output dirs.

        Args:
            model: torch.nn.Module to train.
            args: namespace providing device, lr, weight_decay, grad_clip,
                epochs, patience and num_targets.
            preprocessor: object exposing ``inverse_transform_targets`` used
                to report test metrics on the original scale.
            optimizer: optional optimizer; defaults to Adam(lr, weight_decay).
            scheduler: optional LR scheduler stepped with the validation loss.
            logger: optional logging.Logger; a console logger is created
                when omitted.
        """
        self.model = model
        self.args = args
        self.preprocessor = preprocessor
        self.device = args.device
        self.logger = logger if logger is not None else self._default_logger()
        self.criterion = nn.MSELoss()
        self.optimizer = optimizer if optimizer is not None else optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        self.scheduler = scheduler

        # Per-epoch history consumed by the plotting helpers below.
        self.train_losses = []
        self.val_losses = []
        self.train_mae = []
        self.val_mae = []

        # Early-stopping / checkpointing state.
        self.best_val_loss = float('inf')
        self.early_stop_counter = 0
        self.model_save_path = os.path.join('models', 'best_model.pth')
        self.final_model_path = os.path.join('models', 'final_model.pth')

        # exist_ok avoids the check-then-create race of the original code.
        os.makedirs('models', exist_ok=True)
        os.makedirs('plots', exist_ok=True)

    def _default_logger(self):
        """Return a console INFO logger, adding the handler only once."""
        logger = logging.getLogger('DefaultLogger')
        logger.setLevel(logging.INFO)
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        # Guard against duplicate handlers when the class is instantiated twice.
        if not logger.handlers:
            logger.addHandler(console_handler)
        return logger

    @torch.enable_grad()
    def train_epoch(self, train_loader, adj, max_batches=None):
        """Run one training epoch.

        ``max_batches`` caps the number of batches used this epoch (for fast
        RL-style evaluation).

        Returns:
            (avg_loss, mae); ``(inf, inf)`` when no batch was processed
            (empty loader or ``max_batches == 0``).
        """
        self.model.train()
        total_loss = 0.0
        n_batches = 0  # batches actually processed; used as the loss divisor
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)
        for batch_idx, (data, target) in enumerate(train_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
            self.optimizer.zero_grad()
            output = self.model(data, adj)  # (batch_size, 145, 47)
            output = output.mean(dim=1)  # pool over nodes -> (batch_size, 47)
            loss = self.criterion(output, target)
            loss.backward()
            if self.args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
            all_outputs.append(output.detach().cpu().numpy())
            all_targets.append(target.detach().cpu().numpy())
        if n_batches == 0:
            # Protection for tiny datasets / tiny batch budgets.
            return float('inf'), float('inf')
        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mae = mean_absolute_error(all_targets, all_outputs)
        # Divide by the batch count actually processed. The original
        # `min(len(loader), max_batches) if max_batches else len(loader)`
        # picked the wrong divisor when max_batches == 0 (falsy).
        avg_loss = total_loss / n_batches
        self.train_losses.append(avg_loss)
        self.train_mae.append(mae)
        return avg_loss, mae

    @torch.no_grad()
    def validate(self, val_loader, adj, max_batches=None):
        """Evaluate on the validation loader.

        ``max_batches`` caps the number of validation batches (for fast
        RL-style evaluation).

        Returns:
            (avg_loss, mae); ``(inf, inf)`` when no batch was processed.
        """
        self.model.eval()
        total_loss = 0.0
        n_batches = 0  # batches actually processed; used as the loss divisor
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)
        for batch_idx, (data, target) in enumerate(val_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
            output = self.model(data, adj)  # (batch_size, 145, 47)
            output = output.mean(dim=1)  # (batch_size, 47)
            loss = self.criterion(output, target)
            total_loss += loss.item()
            n_batches += 1
            all_outputs.append(output.cpu().numpy())
            all_targets.append(target.cpu().numpy())
        if n_batches == 0:
            return float('inf'), float('inf')
        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mae = mean_absolute_error(all_targets, all_outputs)
        # Same divisor fix as train_epoch: divide by batches processed.
        avg_loss = total_loss / n_batches
        self.val_losses.append(avg_loss)
        self.val_mae.append(mae)
        return avg_loss, mae

    def train(self, train_loader, val_loader, adj, epochs=None):
        """Full training loop with best-checkpoint saving and early stopping.

        Saves the best model (by validation loss) to ``models/best_model.pth``,
        reloads it after training, then saves a final checkpoint to
        ``models/final_model.pth``. Returns the trained model.
        """
        adj = adj.to(self.device)
        epochs = epochs if epochs is not None else self.args.epochs
        self.logger.info(f"开始训练,共 {epochs} 个epoch")
        for epoch in tqdm(range(epochs)):
            train_loss, train_mae = self.train_epoch(train_loader, adj)
            val_loss, val_mae = self.validate(val_loader, adj)
            if self.scheduler is not None:
                # Stepped with the metric — assumes a ReduceLROnPlateau-style
                # scheduler; TODO confirm against the caller.
                self.scheduler.step(val_loss)
            # Checkpoint the best model and reset the early-stop counter.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': val_loss,
                    'args': self.args  # persist the training hyper-parameters
                }, self.model_save_path)
                self.early_stop_counter = 0
                self.logger.info(f"已更新最佳模型到 {self.model_save_path}")
            else:
                self.early_stop_counter += 1
                if self.early_stop_counter >= self.args.patience:
                    self.logger.info(f"早停机制触发,在第 {epoch+1} 轮停止训练")
                    break
            # Periodic progress log.
            if (epoch + 1) % 10 == 0:
                self.logger.info(f'Epoch {epoch+1}/{epochs}, '
                                 f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
                                 f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')
        self.plot_losses()
        self.plot_mae()
        # Reload the best checkpoint. Guarded: with epochs == 0 no checkpoint
        # was ever written and the original code crashed here.
        if os.path.exists(self.model_save_path):
            checkpoint = torch.load(self.model_save_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.logger.info(f"加载最佳模型(第 {checkpoint['epoch']+1} 轮,验证损失: {checkpoint['loss']:.6f})")
        # Save the final model (after restoring the best weights).
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'args': self.args
        }, self.final_model_path)
        self.logger.info(f"已保存最终模型到 {self.final_model_path}")
        return self.model

    @torch.no_grad()
    def test(self, test_loader, adj):
        """Evaluate on the test loader.

        Reports MSE/MAE/RMSE both on the normalized scale and on the original
        scale (via ``preprocessor.inverse_transform_targets``), plots example
        predictions, and returns a dict of metrics plus the original-scale
        predictions and targets.
        """
        self.model.eval()
        total_loss = 0.0
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)
        for data, target in test_loader:
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)
            output = self.model(data, adj)
            output = output.mean(dim=1)
            loss = self.criterion(output, target)
            total_loss += loss.item()
            all_outputs.append(output.cpu().numpy())
            all_targets.append(target.cpu().numpy())
        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mse = total_loss / len(test_loader)
        mae = mean_absolute_error(all_targets, all_outputs)
        rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))
        # Metrics on the original (un-normalized) scale.
        all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
        all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
        original_mse = mean_squared_error(all_targets_original, all_outputs_original)
        original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
        original_rmse = np.sqrt(original_mse)
        self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
        self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
        self.plot_predictions(all_outputs_original, all_targets_original)
        return {
            'normalized_mse': mse,
            'normalized_mae': mae,
            'normalized_rmse': rmse,
            'original_mse': original_mse,
            'original_mae': original_mae,
            'original_rmse': original_rmse,
            'predictions': all_outputs_original,
            'targets': all_targets_original
        }

    def plot_losses(self):
        """Save the train/validation MSE curves to plots/loss_curve.png."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.savefig('plots/loss_curve.png')
        plt.close()

    def plot_mae(self):
        """Save the train/validation MAE curves to plots/mae_curve.png."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_mae, label='Train MAE')
        plt.plot(self.val_mae, label='Validation MAE')
        plt.xlabel('Epoch')
        plt.ylabel('MAE')
        plt.title('Training and Validation MAE')
        plt.legend()
        plt.savefig('plots/mae_curve.png')
        plt.close()

    def plot_predictions(self, predictions, targets):
        """Plot predicted vs. true values for up to 3 targets (first 100 steps)
        and save to plots/prediction_examples.png."""
        num_plots = min(3, self.args.num_targets)
        plt.figure(figsize=(15, 5 * num_plots))
        for i in range(num_plots):
            plt.subplot(num_plots, 1, i + 1)
            plt.plot(targets[:100, i], label='True Value')
            plt.plot(predictions[:100, i], label='Predicted Value')
            plt.xlabel('Time Step')
            plt.ylabel(f'Target {i+1}')
            plt.title(f'Prediction vs True Value for Target {i+1}')
            plt.legend()
        plt.tight_layout()
        plt.savefig('plots/prediction_examples.png')
        plt.close()