|
|
@@ -1,262 +1,262 @@
|
|
|
-import torch
|
|
|
-import torch.nn as nn
|
|
|
-import torch.optim as optim
|
|
|
-import numpy as np
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-from tqdm import tqdm
|
|
|
-import os
|
|
|
-from sklearn.metrics import mean_absolute_error, mean_squared_error
|
|
|
-import logging
|
|
|
-
|
|
|
-class DataTrainer:
|
|
|
- def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
|
|
|
- self.model = model
|
|
|
- self.args = args
|
|
|
- self.preprocessor = preprocessor
|
|
|
- self.device = args.device
|
|
|
- self.logger = logger if logger is not None else self._default_logger()
|
|
|
-
|
|
|
- self.criterion = nn.MSELoss()
|
|
|
- self.optimizer = optimizer if optimizer is not None else optim.Adam(
|
|
|
- model.parameters(),
|
|
|
- lr=args.lr,
|
|
|
- weight_decay=args.weight_decay
|
|
|
- )
|
|
|
- self.scheduler = scheduler
|
|
|
-
|
|
|
- self.train_losses = []
|
|
|
- self.val_losses = []
|
|
|
- self.train_mae = []
|
|
|
- self.val_mae = []
|
|
|
- self.best_val_loss = float('inf')
|
|
|
- self.early_stop_counter = 0
|
|
|
- self.model_save_path = os.path.join('models', 'best_model.pth')
|
|
|
- self.final_model_path = os.path.join('models', 'final_model.pth')
|
|
|
-
|
|
|
- if not os.path.exists('models'):
|
|
|
- os.makedirs('models')
|
|
|
- if not os.path.exists('plots'):
|
|
|
- os.makedirs('plots')
|
|
|
-
|
|
|
- def _default_logger(self):
|
|
|
- logger = logging.getLogger('DefaultLogger')
|
|
|
- logger.setLevel(logging.INFO)
|
|
|
- console_handler = logging.StreamHandler()
|
|
|
- console_handler.setLevel(logging.INFO)
|
|
|
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
- console_handler.setFormatter(formatter)
|
|
|
- if not logger.handlers:
|
|
|
- logger.addHandler(console_handler)
|
|
|
- return logger
|
|
|
-
|
|
|
- @torch.enable_grad()
|
|
|
- def train_epoch(self, train_loader, adj, max_batches=None):
|
|
|
- """
|
|
|
- 支持通过 max_batches 限制本 epoch 使用的批次数(用于 RL 快速评估)。
|
|
|
- """
|
|
|
- self.model.train()
|
|
|
- total_loss = 0.0
|
|
|
- all_outputs = []
|
|
|
- all_targets = []
|
|
|
- adj = adj.to(self.device)
|
|
|
-
|
|
|
- for batch_idx, (data, target) in enumerate(train_loader):
|
|
|
- if max_batches is not None and batch_idx >= max_batches:
|
|
|
- break
|
|
|
- data, target = data.to(self.device), target.to(self.device)
|
|
|
- data = data.unsqueeze(-1) # (batch_size, 145, 1)
|
|
|
-
|
|
|
- self.optimizer.zero_grad()
|
|
|
- output = self.model(data, adj) # (batch_size, 145, 47)
|
|
|
- output = output.mean(dim=1) # (batch_size, 47)
|
|
|
-
|
|
|
- loss = self.criterion(output, target)
|
|
|
- loss.backward()
|
|
|
- if self.args.grad_clip > 0:
|
|
|
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
|
|
|
- self.optimizer.step()
|
|
|
-
|
|
|
- total_loss += loss.item()
|
|
|
- all_outputs.append(output.detach().cpu().numpy())
|
|
|
- all_targets.append(target.detach().cpu().numpy())
|
|
|
-
|
|
|
- if len(all_outputs) == 0:
|
|
|
- # 在极小数据/极小 batch 情况下的保护
|
|
|
- return float('inf'), float('inf')
|
|
|
-
|
|
|
- all_outputs = np.vstack(all_outputs)
|
|
|
- all_targets = np.vstack(all_targets)
|
|
|
- mae = mean_absolute_error(all_targets, all_outputs)
|
|
|
-
|
|
|
- avg_loss = total_loss / (min(len(train_loader), max_batches) if max_batches else len(train_loader))
|
|
|
- self.train_losses.append(avg_loss)
|
|
|
- self.train_mae.append(mae)
|
|
|
- return avg_loss, mae
|
|
|
-
|
|
|
- @torch.no_grad()
|
|
|
- def validate(self, val_loader, adj, max_batches=None):
|
|
|
- """
|
|
|
- 支持通过 max_batches 限制验证批次数(用于 RL 快速评估)。
|
|
|
- """
|
|
|
- self.model.eval()
|
|
|
- total_loss = 0.0
|
|
|
- all_outputs = []
|
|
|
- all_targets = []
|
|
|
- adj = adj.to(self.device)
|
|
|
-
|
|
|
- for batch_idx, (data, target) in enumerate(val_loader):
|
|
|
- if max_batches is not None and batch_idx >= max_batches:
|
|
|
- break
|
|
|
- data, target = data.to(self.device), target.to(self.device)
|
|
|
- data = data.unsqueeze(-1) # (batch_size, 145, 1)
|
|
|
- output = self.model(data, adj) # (batch_size, 145, 47)
|
|
|
- output = output.mean(dim=1) # (batch_size, 47)
|
|
|
- loss = self.criterion(output, target)
|
|
|
- total_loss += loss.item()
|
|
|
- all_outputs.append(output.cpu().numpy())
|
|
|
- all_targets.append(target.cpu().numpy())
|
|
|
-
|
|
|
- if len(all_outputs) == 0:
|
|
|
- return float('inf'), float('inf')
|
|
|
-
|
|
|
- all_outputs = np.vstack(all_outputs)
|
|
|
- all_targets = np.vstack(all_targets)
|
|
|
- mae = mean_absolute_error(all_targets, all_outputs)
|
|
|
- avg_loss = total_loss / (min(len(val_loader), max_batches) if max_batches else len(val_loader))
|
|
|
- self.val_losses.append(avg_loss)
|
|
|
- self.val_mae.append(mae)
|
|
|
- return avg_loss, mae
|
|
|
-
|
|
|
- def train(self, train_loader, val_loader, adj, epochs=None):
|
|
|
- adj = adj.to(self.device)
|
|
|
- epochs = epochs if epochs is not None else self.args.epochs
|
|
|
- self.logger.info(f"开始训练,共 {epochs} 个epoch")
|
|
|
-
|
|
|
- for epoch in tqdm(range(epochs)):
|
|
|
- train_loss, train_mae = self.train_epoch(train_loader, adj)
|
|
|
- val_loss, val_mae = self.validate(val_loader, adj)
|
|
|
- if self.scheduler is not None:
|
|
|
- self.scheduler.step(val_loss)
|
|
|
-
|
|
|
- # 保存最佳模型
|
|
|
- if val_loss < self.best_val_loss:
|
|
|
- self.best_val_loss = val_loss
|
|
|
- torch.save({
|
|
|
- 'epoch': epoch,
|
|
|
- 'model_state_dict': self.model.state_dict(),
|
|
|
- 'optimizer_state_dict': self.optimizer.state_dict(),
|
|
|
- 'loss': val_loss,
|
|
|
- 'args': self.args # 保存训练参数
|
|
|
- }, self.model_save_path)
|
|
|
- self.early_stop_counter = 0
|
|
|
- self.logger.info(f"已更新最佳模型到 {self.model_save_path}") # 日志
|
|
|
- else:
|
|
|
- self.early_stop_counter += 1
|
|
|
- if self.early_stop_counter >= self.args.patience:
|
|
|
- self.logger.info(f"早停机制触发,在第 {epoch+1} 轮停止训练")
|
|
|
- break
|
|
|
-
|
|
|
- if (epoch + 1) % 10 == 0:
|
|
|
- self.logger.info(f'Epoch {epoch+1}/{epochs}, '
|
|
|
- f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
|
|
|
- f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')
|
|
|
-
|
|
|
- self.plot_losses()
|
|
|
- self.plot_mae()
|
|
|
-
|
|
|
- # 加载最佳模型
|
|
|
- checkpoint = torch.load(self.model_save_path, map_location=self.device)
|
|
|
- self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
- self.logger.info(f"加载最佳模型(第 {checkpoint['epoch']+1} 轮,验证损失: {checkpoint['loss']:.6f})")
|
|
|
-
|
|
|
- # 保存最终训练完成的模型(加载最佳模型后)
|
|
|
- torch.save({
|
|
|
- 'model_state_dict': self.model.state_dict(),
|
|
|
- 'optimizer_state_dict': self.optimizer.state_dict(),
|
|
|
- 'best_val_loss': self.best_val_loss,
|
|
|
- 'args': self.args
|
|
|
- }, self.final_model_path)
|
|
|
- self.logger.info(f"已保存最终模型到 {self.final_model_path}")
|
|
|
-
|
|
|
- return self.model
|
|
|
-
|
|
|
- @torch.no_grad()
|
|
|
- def test(self, test_loader, adj):
|
|
|
- self.model.eval()
|
|
|
- total_loss = 0.0
|
|
|
- all_outputs = []
|
|
|
- all_targets = []
|
|
|
- adj = adj.to(self.device)
|
|
|
-
|
|
|
- for data, target in test_loader:
|
|
|
- data, target = data.to(self.device), target.to(self.device)
|
|
|
- data = data.unsqueeze(-1)
|
|
|
- output = self.model(data, adj)
|
|
|
- output = output.mean(dim=1)
|
|
|
- loss = self.criterion(output, target)
|
|
|
- total_loss += loss.item()
|
|
|
- all_outputs.append(output.cpu().numpy())
|
|
|
- all_targets.append(target.cpu().numpy())
|
|
|
-
|
|
|
- all_outputs = np.vstack(all_outputs)
|
|
|
- all_targets = np.vstack(all_targets)
|
|
|
- mse = total_loss / len(test_loader)
|
|
|
- mae = mean_absolute_error(all_targets, all_outputs)
|
|
|
- rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))
|
|
|
-
|
|
|
- all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
|
|
|
- all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
|
|
|
- original_mse = mean_squared_error(all_targets_original, all_outputs_original)
|
|
|
- original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
|
|
|
- original_rmse = np.sqrt(original_mse)
|
|
|
-
|
|
|
- self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
|
|
|
- self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
|
|
|
- self.plot_predictions(all_outputs_original, all_targets_original)
|
|
|
- return {
|
|
|
- 'normalized_mse': mse,
|
|
|
- 'normalized_mae': mae,
|
|
|
- 'normalized_rmse': rmse,
|
|
|
- 'original_mse': original_mse,
|
|
|
- 'original_mae': original_mae,
|
|
|
- 'original_rmse': original_rmse,
|
|
|
- 'predictions': all_outputs_original,
|
|
|
- 'targets': all_targets_original
|
|
|
- }
|
|
|
-
|
|
|
- def plot_losses(self):
|
|
|
- plt.figure(figsize=(10, 6))
|
|
|
- plt.plot(self.train_losses, label='Train Loss')
|
|
|
- plt.plot(self.val_losses, label='Validation Loss')
|
|
|
- plt.xlabel('Epoch')
|
|
|
- plt.ylabel('MSE Loss')
|
|
|
- plt.title('Training and Validation Loss')
|
|
|
- plt.legend()
|
|
|
- plt.savefig('plots/loss_curve.png')
|
|
|
- plt.close()
|
|
|
-
|
|
|
- def plot_mae(self):
|
|
|
- plt.figure(figsize=(10, 6))
|
|
|
- plt.plot(self.train_mae, label='Train MAE')
|
|
|
- plt.plot(self.val_mae, label='Validation MAE')
|
|
|
- plt.xlabel('Epoch')
|
|
|
- plt.ylabel('MAE')
|
|
|
- plt.title('Training and Validation MAE')
|
|
|
- plt.legend()
|
|
|
- plt.savefig('plots/mae_curve.png')
|
|
|
- plt.close()
|
|
|
-
|
|
|
- def plot_predictions(self, predictions, targets):
|
|
|
- num_plots = min(3, self.args.num_targets)
|
|
|
- plt.figure(figsize=(15, 5*num_plots))
|
|
|
- for i in range(num_plots):
|
|
|
- plt.subplot(num_plots, 1, i+1)
|
|
|
- plt.plot(targets[:100, i], label='True Value')
|
|
|
- plt.plot(predictions[:100, i], label='Predicted Value')
|
|
|
- plt.xlabel('Time Step')
|
|
|
- plt.ylabel(f'Target {i+1}')
|
|
|
- plt.title(f'Prediction vs True Value for Target {i+1}')
|
|
|
- plt.legend()
|
|
|
- plt.tight_layout()
|
|
|
- plt.savefig('plots/prediction_examples.png')
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch.optim as optim
|
|
|
+import numpy as np
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+from tqdm import tqdm
|
|
|
+import os
|
|
|
+from sklearn.metrics import mean_absolute_error, mean_squared_error
|
|
|
+import logging
|
|
|
+
|
|
|
+class DataTrainer:
|
|
|
+ def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
|
|
|
+ self.model = model
|
|
|
+ self.args = args
|
|
|
+ self.preprocessor = preprocessor
|
|
|
+ self.device = args.device
|
|
|
+ self.logger = logger if logger is not None else self._default_logger()
|
|
|
+
|
|
|
+ self.criterion = nn.MSELoss()
|
|
|
+ self.optimizer = optimizer if optimizer is not None else optim.Adam(
|
|
|
+ model.parameters(),
|
|
|
+ lr=args.lr,
|
|
|
+ weight_decay=args.weight_decay
|
|
|
+ )
|
|
|
+ self.scheduler = scheduler
|
|
|
+
|
|
|
+ self.train_losses = []
|
|
|
+ self.val_losses = []
|
|
|
+ self.train_mae = []
|
|
|
+ self.val_mae = []
|
|
|
+ self.best_val_loss = float('inf')
|
|
|
+ self.early_stop_counter = 0
|
|
|
+ self.model_save_path = os.path.join('models', 'best_model.pth')
|
|
|
+ self.final_model_path = os.path.join('models', 'final_model.pth')
|
|
|
+
|
|
|
+ if not os.path.exists('models'):
|
|
|
+ os.makedirs('models')
|
|
|
+ if not os.path.exists('plots'):
|
|
|
+ os.makedirs('plots')
|
|
|
+
|
|
|
+ def _default_logger(self):
|
|
|
+ logger = logging.getLogger('DefaultLogger')
|
|
|
+ logger.setLevel(logging.INFO)
|
|
|
+ console_handler = logging.StreamHandler()
|
|
|
+ console_handler.setLevel(logging.INFO)
|
|
|
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
+ console_handler.setFormatter(formatter)
|
|
|
+ if not logger.handlers:
|
|
|
+ logger.addHandler(console_handler)
|
|
|
+ return logger
|
|
|
+
|
|
|
+ @torch.enable_grad()
|
|
|
+ def train_epoch(self, train_loader, adj, max_batches=None):
|
|
|
+ """
|
|
|
+ 支持通过 max_batches 限制本 epoch 使用的批次数(用于 RL 快速评估)。
|
|
|
+ """
|
|
|
+ self.model.train()
|
|
|
+ total_loss = 0.0
|
|
|
+ all_outputs = []
|
|
|
+ all_targets = []
|
|
|
+ adj = adj.to(self.device)
|
|
|
+
|
|
|
+ for batch_idx, (data, target) in enumerate(train_loader):
|
|
|
+ if max_batches is not None and batch_idx >= max_batches:
|
|
|
+ break
|
|
|
+ data, target = data.to(self.device), target.to(self.device)
|
|
|
+ data = data.unsqueeze(-1) # (batch_size, 145, 1)
|
|
|
+
|
|
|
+ self.optimizer.zero_grad()
|
|
|
+ output = self.model(data, adj) # (batch_size, 145, 47)
|
|
|
+ output = output.mean(dim=1) # (batch_size, 47)
|
|
|
+
|
|
|
+ loss = self.criterion(output, target)
|
|
|
+ loss.backward()
|
|
|
+ if self.args.grad_clip > 0:
|
|
|
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
|
|
|
+ self.optimizer.step()
|
|
|
+
|
|
|
+ total_loss += loss.item()
|
|
|
+ all_outputs.append(output.detach().cpu().numpy())
|
|
|
+ all_targets.append(target.detach().cpu().numpy())
|
|
|
+
|
|
|
+ if len(all_outputs) == 0:
|
|
|
+ # 在极小数据/极小 batch 情况下的保护
|
|
|
+ return float('inf'), float('inf')
|
|
|
+
|
|
|
+ all_outputs = np.vstack(all_outputs)
|
|
|
+ all_targets = np.vstack(all_targets)
|
|
|
+ mae = mean_absolute_error(all_targets, all_outputs)
|
|
|
+
|
|
|
+ avg_loss = total_loss / (min(len(train_loader), max_batches) if max_batches else len(train_loader))
|
|
|
+ self.train_losses.append(avg_loss)
|
|
|
+ self.train_mae.append(mae)
|
|
|
+ return avg_loss, mae
|
|
|
+
|
|
|
+ @torch.no_grad()
|
|
|
+ def validate(self, val_loader, adj, max_batches=None):
|
|
|
+ """
|
|
|
+ 支持通过 max_batches 限制验证批次数(用于 RL 快速评估)。
|
|
|
+ """
|
|
|
+ self.model.eval()
|
|
|
+ total_loss = 0.0
|
|
|
+ all_outputs = []
|
|
|
+ all_targets = []
|
|
|
+ adj = adj.to(self.device)
|
|
|
+
|
|
|
+ for batch_idx, (data, target) in enumerate(val_loader):
|
|
|
+ if max_batches is not None and batch_idx >= max_batches:
|
|
|
+ break
|
|
|
+ data, target = data.to(self.device), target.to(self.device)
|
|
|
+ data = data.unsqueeze(-1) # (batch_size, 145, 1)
|
|
|
+ output = self.model(data, adj) # (batch_size, 145, 47)
|
|
|
+ output = output.mean(dim=1) # (batch_size, 47)
|
|
|
+ loss = self.criterion(output, target)
|
|
|
+ total_loss += loss.item()
|
|
|
+ all_outputs.append(output.cpu().numpy())
|
|
|
+ all_targets.append(target.cpu().numpy())
|
|
|
+
|
|
|
+ if len(all_outputs) == 0:
|
|
|
+ return float('inf'), float('inf')
|
|
|
+
|
|
|
+ all_outputs = np.vstack(all_outputs)
|
|
|
+ all_targets = np.vstack(all_targets)
|
|
|
+ mae = mean_absolute_error(all_targets, all_outputs)
|
|
|
+ avg_loss = total_loss / (min(len(val_loader), max_batches) if max_batches else len(val_loader))
|
|
|
+ self.val_losses.append(avg_loss)
|
|
|
+ self.val_mae.append(mae)
|
|
|
+ return avg_loss, mae
|
|
|
+
|
|
|
+ def train(self, train_loader, val_loader, adj, epochs=None):
|
|
|
+ adj = adj.to(self.device)
|
|
|
+ epochs = epochs if epochs is not None else self.args.epochs
|
|
|
+ self.logger.info(f"开始训练,共 {epochs} 个epoch")
|
|
|
+
|
|
|
+ for epoch in tqdm(range(epochs)):
|
|
|
+ train_loss, train_mae = self.train_epoch(train_loader, adj)
|
|
|
+ val_loss, val_mae = self.validate(val_loader, adj)
|
|
|
+ if self.scheduler is not None:
|
|
|
+ self.scheduler.step(val_loss)
|
|
|
+
|
|
|
+ # 保存最佳模型
|
|
|
+ if val_loss < self.best_val_loss:
|
|
|
+ self.best_val_loss = val_loss
|
|
|
+ torch.save({
|
|
|
+ 'epoch': epoch,
|
|
|
+ 'model_state_dict': self.model.state_dict(),
|
|
|
+ 'optimizer_state_dict': self.optimizer.state_dict(),
|
|
|
+ 'loss': val_loss,
|
|
|
+ 'args': self.args # 保存训练参数
|
|
|
+ }, self.model_save_path)
|
|
|
+ self.early_stop_counter = 0
|
|
|
+ self.logger.info(f"已更新最佳模型到 {self.model_save_path}") # 日志
|
|
|
+ else:
|
|
|
+ self.early_stop_counter += 1
|
|
|
+ if self.early_stop_counter >= self.args.patience:
|
|
|
+ self.logger.info(f"早停机制触发,在第 {epoch+1} 轮停止训练")
|
|
|
+ break
|
|
|
+
|
|
|
+ if (epoch + 1) % 10 == 0:
|
|
|
+ self.logger.info(f'Epoch {epoch+1}/{epochs}, '
|
|
|
+ f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
|
|
|
+ f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')
|
|
|
+
|
|
|
+ self.plot_losses()
|
|
|
+ self.plot_mae()
|
|
|
+
|
|
|
+ # 加载最佳模型
|
|
|
+ checkpoint = torch.load(self.model_save_path, map_location=self.device)
|
|
|
+ self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
+ self.logger.info(f"加载最佳模型(第 {checkpoint['epoch']+1} 轮,验证损失: {checkpoint['loss']:.6f})")
|
|
|
+
|
|
|
+ # 保存最终训练完成的模型(加载最佳模型后)
|
|
|
+ torch.save({
|
|
|
+ 'model_state_dict': self.model.state_dict(),
|
|
|
+ 'optimizer_state_dict': self.optimizer.state_dict(),
|
|
|
+ 'best_val_loss': self.best_val_loss,
|
|
|
+ 'args': self.args
|
|
|
+ }, self.final_model_path)
|
|
|
+ self.logger.info(f"已保存最终模型到 {self.final_model_path}")
|
|
|
+
|
|
|
+ return self.model
|
|
|
+
|
|
|
+ @torch.no_grad()
|
|
|
+ def test(self, test_loader, adj):
|
|
|
+ self.model.eval()
|
|
|
+ total_loss = 0.0
|
|
|
+ all_outputs = []
|
|
|
+ all_targets = []
|
|
|
+ adj = adj.to(self.device)
|
|
|
+
|
|
|
+ for data, target in test_loader:
|
|
|
+ data, target = data.to(self.device), target.to(self.device)
|
|
|
+ data = data.unsqueeze(-1)
|
|
|
+ output = self.model(data, adj)
|
|
|
+ output = output.mean(dim=1)
|
|
|
+ loss = self.criterion(output, target)
|
|
|
+ total_loss += loss.item()
|
|
|
+ all_outputs.append(output.cpu().numpy())
|
|
|
+ all_targets.append(target.cpu().numpy())
|
|
|
+
|
|
|
+ all_outputs = np.vstack(all_outputs)
|
|
|
+ all_targets = np.vstack(all_targets)
|
|
|
+ mse = total_loss / len(test_loader)
|
|
|
+ mae = mean_absolute_error(all_targets, all_outputs)
|
|
|
+ rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))
|
|
|
+
|
|
|
+ all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
|
|
|
+ all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
|
|
|
+ original_mse = mean_squared_error(all_targets_original, all_outputs_original)
|
|
|
+ original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
|
|
|
+ original_rmse = np.sqrt(original_mse)
|
|
|
+
|
|
|
+ self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
|
|
|
+ self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
|
|
|
+ self.plot_predictions(all_outputs_original, all_targets_original)
|
|
|
+ return {
|
|
|
+ 'normalized_mse': mse,
|
|
|
+ 'normalized_mae': mae,
|
|
|
+ 'normalized_rmse': rmse,
|
|
|
+ 'original_mse': original_mse,
|
|
|
+ 'original_mae': original_mae,
|
|
|
+ 'original_rmse': original_rmse,
|
|
|
+ 'predictions': all_outputs_original,
|
|
|
+ 'targets': all_targets_original
|
|
|
+ }
|
|
|
+
|
|
|
+ def plot_losses(self):
|
|
|
+ plt.figure(figsize=(10, 6))
|
|
|
+ plt.plot(self.train_losses, label='Train Loss')
|
|
|
+ plt.plot(self.val_losses, label='Validation Loss')
|
|
|
+ plt.xlabel('Epoch')
|
|
|
+ plt.ylabel('MSE Loss')
|
|
|
+ plt.title('Training and Validation Loss')
|
|
|
+ plt.legend()
|
|
|
+ plt.savefig('plots/loss_curve.png')
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+ def plot_mae(self):
|
|
|
+ plt.figure(figsize=(10, 6))
|
|
|
+ plt.plot(self.train_mae, label='Train MAE')
|
|
|
+ plt.plot(self.val_mae, label='Validation MAE')
|
|
|
+ plt.xlabel('Epoch')
|
|
|
+ plt.ylabel('MAE')
|
|
|
+ plt.title('Training and Validation MAE')
|
|
|
+ plt.legend()
|
|
|
+ plt.savefig('plots/mae_curve.png')
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+ def plot_predictions(self, predictions, targets):
|
|
|
+ num_plots = min(3, self.args.num_targets)
|
|
|
+ plt.figure(figsize=(15, 5*num_plots))
|
|
|
+ for i in range(num_plots):
|
|
|
+ plt.subplot(num_plots, 1, i+1)
|
|
|
+ plt.plot(targets[:100, i], label='True Value')
|
|
|
+ plt.plot(predictions[:100, i], label='Predicted Value')
|
|
|
+ plt.xlabel('Time Step')
|
|
|
+ plt.ylabel(f'Target {i+1}')
|
|
|
+ plt.title(f'Prediction vs True Value for Target {i+1}')
|
|
|
+ plt.legend()
|
|
|
+ plt.tight_layout()
|
|
|
+ plt.savefig('plots/prediction_examples.png')
|
|
|
plt.close()
|