import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error
from tqdm import tqdm


class DataTrainer:
    def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
        self.model = model
        self.args = args
        self.preprocessor = preprocessor
        self.device = args.device
        self.logger = logger if logger is not None else self._default_logger()

        self.criterion = nn.MSELoss()
        self.optimizer = optimizer if optimizer is not None else optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )
        self.scheduler = scheduler

        self.train_losses = []
        self.val_losses = []
        self.train_mae = []
        self.val_mae = []
        self.best_val_loss = float('inf')
        self.early_stop_counter = 0
        self.model_save_path = os.path.join('models', 'best_model.pth')
        self.final_model_path = os.path.join('models', 'final_model.pth')

        os.makedirs('models', exist_ok=True)
        os.makedirs('plots', exist_ok=True)

    def _default_logger(self):
        # Fall back to a console logger when none is supplied by the caller.
        logger = logging.getLogger('DefaultLogger')
        logger.setLevel(logging.INFO)
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        if not logger.handlers:  # avoid duplicate handlers across instances
            logger.addHandler(console_handler)
        return logger

    @torch.enable_grad()
    def train_epoch(self, train_loader, adj, max_batches=None):
        """
        Run one training epoch. `max_batches` caps the number of batches
        used this epoch (for quick RL evaluation).
        """
        self.model.train()
        total_loss = 0.0
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)
        for batch_idx, (data, target) in enumerate(train_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)  # (batch_size, 145, 1)

            self.optimizer.zero_grad()
            output = self.model(data, adj)  # (batch_size, 145, 47)
            output = output.mean(dim=1)  # (batch_size, 47)

            loss = self.criterion(output, target)
            loss.backward()
            if self.args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            total_loss += loss.item()
            all_outputs.append(output.detach().cpu().numpy())
            all_targets.append(target.detach().cpu().numpy())

        if len(all_outputs) == 0:
            # Guard against degenerate cases (tiny dataset / zero usable batches).
            return float('inf'), float('inf')

        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mae = mean_absolute_error(all_targets, all_outputs)

        avg_loss = total_loss / (min(len(train_loader), max_batches) if max_batches else len(train_loader))
        self.train_losses.append(avg_loss)
        self.train_mae.append(mae)
        return avg_loss, mae

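    # Note: an outer RL/search loop can score a candidate configuration cheaply
    # by capping batches, e.g. (illustrative):
    #     quick_loss, quick_mae = trainer.train_epoch(train_loader, adj, max_batches=5)
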
    @torch.no_grad()
    def validate(self, val_loader, adj, max_batches=None):
        """
        Evaluate on the validation set. `max_batches` caps the number of
        validation batches (for quick RL evaluation).
        """
        self.model.eval()
        total_loss = 0.0
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)

        for batch_idx, (data, target) in enumerate(val_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
            output = self.model(data, adj)  # (batch_size, 145, 47)
            output = output.mean(dim=1)  # (batch_size, 47)
            loss = self.criterion(output, target)
            total_loss += loss.item()
            all_outputs.append(output.cpu().numpy())
            all_targets.append(target.cpu().numpy())

        if len(all_outputs) == 0:
            return float('inf'), float('inf')

        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mae = mean_absolute_error(all_targets, all_outputs)
        avg_loss = total_loss / (min(len(val_loader), max_batches) if max_batches else len(val_loader))
        self.val_losses.append(avg_loss)
        self.val_mae.append(mae)
        return avg_loss, mae

    def train(self, train_loader, val_loader, adj, epochs=None):
        adj = adj.to(self.device)
        epochs = epochs if epochs is not None else self.args.epochs
        self.logger.info(f"Starting training for {epochs} epochs")

        for epoch in tqdm(range(epochs)):
            train_loss, train_mae = self.train_epoch(train_loader, adj)
            val_loss, val_mae = self.validate(val_loader, adj)
            if self.scheduler is not None:
                self.scheduler.step(val_loss)  # metric-based step, e.g. ReduceLROnPlateau

            # Save the best model so far
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': val_loss,
                    'args': self.args  # save training arguments
                }, self.model_save_path)
                self.early_stop_counter = 0
                self.logger.info(f"Best model saved to {self.model_save_path}")
            else:
                self.early_stop_counter += 1
                if self.early_stop_counter >= self.args.patience:
                    self.logger.info(f"Early stopping triggered at epoch {epoch+1}")
                    break

            if (epoch + 1) % 10 == 0:
                self.logger.info(f'Epoch {epoch+1}/{epochs}, '
                                 f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
                                 f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')

        self.plot_losses()
        self.plot_mae()

        # Reload the best checkpoint; weights_only=False is needed on newer
        # PyTorch because the checkpoint stores non-tensor objects (args).
        checkpoint = torch.load(self.model_save_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.logger.info(f"Loaded best model (epoch {checkpoint['epoch']+1}, val loss: {checkpoint['loss']:.6f})")

        # Save the final model (after restoring the best weights)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'args': self.args
        }, self.final_model_path)
        self.logger.info(f"Final model saved to {self.final_model_path}")

        return self.model

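    # Example wiring (illustrative; not invoked by this module): build the
    # optimizer first, then a plateau scheduler matching `scheduler.step(val_loss)`:
    #     opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    #     sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)
    #     trainer = DataTrainer(model, args, preprocessor, optimizer=opt, scheduler=sch)
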
    @torch.no_grad()
    def test(self, test_loader, adj):
        self.model.eval()
        total_loss = 0.0
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)

        for data, target in test_loader:
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
            output = self.model(data, adj)  # (batch_size, 145, 47)
            output = output.mean(dim=1)  # (batch_size, 47)
            loss = self.criterion(output, target)
            total_loss += loss.item()
            all_outputs.append(output.cpu().numpy())
            all_targets.append(target.cpu().numpy())

        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)

        # Metrics on the normalized scale
        mse = total_loss / len(test_loader)
        mae = mean_absolute_error(all_targets, all_outputs)
        rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))

        # Metrics on the original scale, after inverting the target normalization
        all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
        all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
        original_mse = mean_squared_error(all_targets_original, all_outputs_original)
        original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
        original_rmse = np.sqrt(original_mse)

        self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
        self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
        self.plot_predictions(all_outputs_original, all_targets_original)
        return {
            'normalized_mse': mse,
            'normalized_mae': mae,
            'normalized_rmse': rmse,
            'original_mse': original_mse,
            'original_mae': original_mae,
            'original_rmse': original_rmse,
            'predictions': all_outputs_original,
            'targets': all_targets_original
        }

    def plot_losses(self):
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.savefig('plots/loss_curve.png')
        plt.close()

    def plot_mae(self):
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_mae, label='Train MAE')
        plt.plot(self.val_mae, label='Validation MAE')
        plt.xlabel('Epoch')
        plt.ylabel('MAE')
        plt.title('Training and Validation MAE')
        plt.legend()
        plt.savefig('plots/mae_curve.png')
        plt.close()

    def plot_predictions(self, predictions, targets):
        # Plot the first 100 test steps for up to three target dimensions.
        num_plots = min(3, self.args.num_targets)
        plt.figure(figsize=(15, 5 * num_plots))
        for i in range(num_plots):
            plt.subplot(num_plots, 1, i + 1)
            plt.plot(targets[:100, i], label='True Value')
            plt.plot(predictions[:100, i], label='Predicted Value')
            plt.xlabel('Time Step')
            plt.ylabel(f'Target {i+1}')
            plt.title(f'Prediction vs True Value for Target {i+1}')
            plt.legend()
        plt.tight_layout()
        plt.savefig('plots/prediction_examples.png')
        plt.close()
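

# Minimal usage sketch (illustrative only): the model, preprocessor, and data
# below are hypothetical stand-ins, not this project's real components; the
# shapes mirror the (batch, 145) -> (batch, 47) convention assumed above.
if __name__ == '__main__':
    from argparse import Namespace
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyModel(nn.Module):
        # Ignores adj; maps (B, 145, 1) -> (B, 145, 47) with one shared linear layer.
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 47)

        def forward(self, x, adj):
            return self.fc(x)

    class _IdentityPreprocessor:
        # Stand-in: assumes targets are already on their original scale.
        def inverse_transform_targets(self, y):
            return y

    args = Namespace(device='cpu', lr=1e-3, weight_decay=0.0, grad_clip=1.0,
                     epochs=3, patience=5, num_targets=47)
    loader = DataLoader(TensorDataset(torch.randn(64, 145), torch.randn(64, 47)),
                        batch_size=16)
    trainer = DataTrainer(_ToyModel(), args, _IdentityPreprocessor())
    trainer.train(loader, loader, torch.eye(145))
    results = trainer.test(loader, torch.eye(145))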