Procházet zdrojové kódy

模型训练文件

zhanghao před 4 měsíci
rodič
revize
bc2f24b413
1 změněný soubor, 261 přidání a 261 odebrání
  1. 261 261
      models/causal-inference/data_trainer.py

+ 261 - 261
models/causal-inference/data_trainer.py

@@ -1,262 +1,262 @@
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import numpy as np
-import matplotlib.pyplot as plt
-from tqdm import tqdm
-import os
-from sklearn.metrics import mean_absolute_error, mean_squared_error
-import logging
-
-class DataTrainer:
-    def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
-        self.model = model
-        self.args = args
-        self.preprocessor = preprocessor
-        self.device = args.device
-        self.logger = logger if logger is not None else self._default_logger()
-        
-        self.criterion = nn.MSELoss()
-        self.optimizer = optimizer if optimizer is not None else optim.Adam(
-            model.parameters(),
-            lr=args.lr,
-            weight_decay=args.weight_decay
-        )
-        self.scheduler = scheduler
-        
-        self.train_losses = []
-        self.val_losses = []
-        self.train_mae = []
-        self.val_mae = []
-        self.best_val_loss = float('inf')
-        self.early_stop_counter = 0
-        self.model_save_path = os.path.join('models', 'best_model.pth')
-        self.final_model_path = os.path.join('models', 'final_model.pth')
-        
-        if not os.path.exists('models'):
-            os.makedirs('models')
-        if not os.path.exists('plots'):
-            os.makedirs('plots')
-        
-    def _default_logger(self):
-        logger = logging.getLogger('DefaultLogger')
-        logger.setLevel(logging.INFO)
-        console_handler = logging.StreamHandler()
-        console_handler.setLevel(logging.INFO)
-        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-        console_handler.setFormatter(formatter)
-        if not logger.handlers:
-            logger.addHandler(console_handler)
-        return logger
-    
-    @torch.enable_grad()
-    def train_epoch(self, train_loader, adj, max_batches=None):
-        """
-        支持通过 max_batches 限制本 epoch 使用的批次数(用于 RL 快速评估)。
-        """
-        self.model.train()
-        total_loss = 0.0
-        all_outputs = []
-        all_targets = []
-        adj = adj.to(self.device)
-
-        for batch_idx, (data, target) in enumerate(train_loader):
-            if max_batches is not None and batch_idx >= max_batches:
-                break
-            data, target = data.to(self.device), target.to(self.device)
-            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
-            
-            self.optimizer.zero_grad()
-            output = self.model(data, adj)  # (batch_size, 145, 47)
-            output = output.mean(dim=1)     # (batch_size, 47)
-            
-            loss = self.criterion(output, target)
-            loss.backward()
-            if self.args.grad_clip > 0:
-                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
-            self.optimizer.step()
-            
-            total_loss += loss.item()
-            all_outputs.append(output.detach().cpu().numpy())
-            all_targets.append(target.detach().cpu().numpy())
-        
-        if len(all_outputs) == 0:
-            # 在极小数据/极小 batch 情况下的保护
-            return float('inf'), float('inf')
-        
-        all_outputs = np.vstack(all_outputs)
-        all_targets = np.vstack(all_targets)
-        mae = mean_absolute_error(all_targets, all_outputs)
-        
-        avg_loss = total_loss / (min(len(train_loader), max_batches) if max_batches else len(train_loader))
-        self.train_losses.append(avg_loss)
-        self.train_mae.append(mae)
-        return avg_loss, mae
-    
-    @torch.no_grad()
-    def validate(self, val_loader, adj, max_batches=None):
-        """
-        支持通过 max_batches 限制验证批次数(用于 RL 快速评估)。
-        """
-        self.model.eval()
-        total_loss = 0.0
-        all_outputs = []
-        all_targets = []
-        adj = adj.to(self.device)
-        
-        for batch_idx, (data, target) in enumerate(val_loader):
-            if max_batches is not None and batch_idx >= max_batches:
-                break
-            data, target = data.to(self.device), target.to(self.device)
-            data = data.unsqueeze(-1)  # (batch_size, 145, 1)
-            output = self.model(data, adj)  # (batch_size, 145, 47)
-            output = output.mean(dim=1)     # (batch_size, 47)
-            loss = self.criterion(output, target)
-            total_loss += loss.item()
-            all_outputs.append(output.cpu().numpy())
-            all_targets.append(target.cpu().numpy())
-        
-        if len(all_outputs) == 0:
-            return float('inf'), float('inf')
-        
-        all_outputs = np.vstack(all_outputs)
-        all_targets = np.vstack(all_targets)
-        mae = mean_absolute_error(all_targets, all_outputs)
-        avg_loss = total_loss / (min(len(val_loader), max_batches) if max_batches else len(val_loader))
-        self.val_losses.append(avg_loss)
-        self.val_mae.append(mae)
-        return avg_loss, mae
-    
-    def train(self, train_loader, val_loader, adj, epochs=None):
-        adj = adj.to(self.device)
-        epochs = epochs if epochs is not None else self.args.epochs
-        self.logger.info(f"开始训练,共 {epochs} 个epoch")
-        
-        for epoch in tqdm(range(epochs)):
-            train_loss, train_mae = self.train_epoch(train_loader, adj)
-            val_loss, val_mae = self.validate(val_loader, adj)
-            if self.scheduler is not None:
-                self.scheduler.step(val_loss)
-            
-            # 保存最佳模型
-            if val_loss < self.best_val_loss:
-                self.best_val_loss = val_loss
-                torch.save({
-                    'epoch': epoch,
-                    'model_state_dict': self.model.state_dict(),
-                    'optimizer_state_dict': self.optimizer.state_dict(),
-                    'loss': val_loss,
-                    'args': self.args  # 保存训练参数
-                }, self.model_save_path)
-                self.early_stop_counter = 0
-                self.logger.info(f"已更新最佳模型到 {self.model_save_path}")  # 日志
-            else:
-                self.early_stop_counter += 1
-                if self.early_stop_counter >= self.args.patience:
-                    self.logger.info(f"早停机制触发,在第 {epoch+1} 轮停止训练")
-                    break
-            
-            if (epoch + 1) % 10 == 0:
-                self.logger.info(f'Epoch {epoch+1}/{epochs}, '
-                                 f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
-                                 f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')
-        
-        self.plot_losses()
-        self.plot_mae()
-        
-        # 加载最佳模型
-        checkpoint = torch.load(self.model_save_path, map_location=self.device)
-        self.model.load_state_dict(checkpoint['model_state_dict'])
-        self.logger.info(f"加载最佳模型(第 {checkpoint['epoch']+1} 轮,验证损失: {checkpoint['loss']:.6f})")
-        
-        # 保存最终训练完成的模型(加载最佳模型后)
-        torch.save({
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'best_val_loss': self.best_val_loss,
-            'args': self.args
-        }, self.final_model_path)
-        self.logger.info(f"已保存最终模型到 {self.final_model_path}")
-        
-        return self.model
-    
-    @torch.no_grad()
-    def test(self, test_loader, adj):
-        self.model.eval()
-        total_loss = 0.0
-        all_outputs = []
-        all_targets = []
-        adj = adj.to(self.device)
-        
-        for data, target in test_loader:
-            data, target = data.to(self.device), target.to(self.device)
-            data = data.unsqueeze(-1)
-            output = self.model(data, adj)
-            output = output.mean(dim=1)
-            loss = self.criterion(output, target)
-            total_loss += loss.item()
-            all_outputs.append(output.cpu().numpy())
-            all_targets.append(target.cpu().numpy())
-        
-        all_outputs = np.vstack(all_outputs)
-        all_targets = np.vstack(all_targets)
-        mse = total_loss / len(test_loader)
-        mae = mean_absolute_error(all_targets, all_outputs)
-        rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))
-        
-        all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
-        all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
-        original_mse = mean_squared_error(all_targets_original, all_outputs_original)
-        original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
-        original_rmse = np.sqrt(original_mse)
-        
-        self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
-        self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
-        self.plot_predictions(all_outputs_original, all_targets_original)
-        return {
-            'normalized_mse': mse,
-            'normalized_mae': mae,
-            'normalized_rmse': rmse,
-            'original_mse': original_mse,
-            'original_mae': original_mae,
-            'original_rmse': original_rmse,
-            'predictions': all_outputs_original,
-            'targets': all_targets_original
-        }
-    
-    def plot_losses(self):
-        plt.figure(figsize=(10, 6))
-        plt.plot(self.train_losses, label='Train Loss')
-        plt.plot(self.val_losses, label='Validation Loss')
-        plt.xlabel('Epoch')
-        plt.ylabel('MSE Loss')
-        plt.title('Training and Validation Loss')
-        plt.legend()
-        plt.savefig('plots/loss_curve.png')
-        plt.close()
-    
-    def plot_mae(self):
-        plt.figure(figsize=(10, 6))
-        plt.plot(self.train_mae, label='Train MAE')
-        plt.plot(self.val_mae, label='Validation MAE')
-        plt.xlabel('Epoch')
-        plt.ylabel('MAE')
-        plt.title('Training and Validation MAE')
-        plt.legend()
-        plt.savefig('plots/mae_curve.png')
-        plt.close()
-    
-    def plot_predictions(self, predictions, targets):
-        num_plots = min(3, self.args.num_targets)
-        plt.figure(figsize=(15, 5*num_plots))
-        for i in range(num_plots):
-            plt.subplot(num_plots, 1, i+1)
-            plt.plot(targets[:100, i], label='True Value')
-            plt.plot(predictions[:100, i], label='Predicted Value')
-            plt.xlabel('Time Step')
-            plt.ylabel(f'Target {i+1}')
-            plt.title(f'Prediction vs True Value for Target {i+1}')
-            plt.legend()
-        plt.tight_layout()
-        plt.savefig('plots/prediction_examples.png')
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import os
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+import logging
+
class DataTrainer:
    """Training harness for a graph regression model.

    Wraps the train/validate/test loops, best-model checkpointing with
    early stopping, and diagnostic plotting (loss/MAE curves and prediction
    examples saved under ``plots/``).
    """

    def __init__(self, model, args, preprocessor, optimizer=None, scheduler=None, logger=None):
        """
        Args:
            model: nn.Module called as ``model(data, adj)``; expected to return
                a (batch, num_nodes, num_targets) tensor that is averaged over
                the node dimension before the loss is computed.
            args: namespace providing ``device``, ``lr``, ``weight_decay``,
                ``grad_clip``, ``epochs``, ``patience`` and ``num_targets``.
            preprocessor: exposes ``inverse_transform_targets`` to map
                normalized targets/predictions back to the original scale.
            optimizer: optional; defaults to Adam(lr=args.lr, weight_decay=args.weight_decay).
            scheduler: optional LR scheduler; stepped with the validation loss.
            logger: optional logging.Logger; a console logger is used if None.
        """
        self.model = model
        self.args = args
        self.preprocessor = preprocessor
        self.device = args.device
        self.logger = logger if logger is not None else self._default_logger()

        self.criterion = nn.MSELoss()
        self.optimizer = optimizer if optimizer is not None else optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )
        self.scheduler = scheduler

        # Per-epoch history consumed by the plotting helpers.
        self.train_losses = []
        self.val_losses = []
        self.train_mae = []
        self.val_mae = []
        self.best_val_loss = float('inf')
        self.early_stop_counter = 0
        self.model_save_path = os.path.join('models', 'best_model.pth')
        self.final_model_path = os.path.join('models', 'final_model.pth')

        # exist_ok=True avoids the check-then-create race of the original
        # ``if not os.path.exists(...)`` pattern.
        os.makedirs('models', exist_ok=True)
        os.makedirs('plots', exist_ok=True)

    def _default_logger(self):
        """Build a console INFO logger; a handler is added only once."""
        logger = logging.getLogger('DefaultLogger')
        logger.setLevel(logging.INFO)
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        if not logger.handlers:
            logger.addHandler(console_handler)
        return logger

    @torch.enable_grad()
    def train_epoch(self, train_loader, adj, max_batches=None):
        """Run one optimization epoch.

        ``max_batches`` optionally caps the number of batches used in this
        epoch (for fast RL-style evaluation). Returns ``(avg_loss, mae)``;
        ``(inf, inf)`` when no batch was processed.
        """
        self.model.train()
        total_loss = 0.0
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)

        for batch_idx, (data, target) in enumerate(train_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            data, target = data.to(self.device), target.to(self.device)
            # Add a trailing feature dim; assumes input is (batch, num_nodes) — TODO confirm.
            data = data.unsqueeze(-1)

            self.optimizer.zero_grad()
            output = self.model(data, adj)  # (batch_size, num_nodes, num_targets)
            output = output.mean(dim=1)     # average over nodes -> (batch_size, num_targets)

            loss = self.criterion(output, target)
            loss.backward()
            if self.args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            total_loss += loss.item()
            all_outputs.append(output.detach().cpu().numpy())
            all_targets.append(target.detach().cpu().numpy())

        if len(all_outputs) == 0:
            # Guard for tiny datasets / tiny batch limits.
            return float('inf'), float('inf')

        # One entry was appended per processed batch, so this is the exact
        # batch count even when max_batches truncates the loader.
        num_batches = len(all_outputs)
        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mae = mean_absolute_error(all_targets, all_outputs)

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)
        self.train_mae.append(mae)
        return avg_loss, mae

    @torch.no_grad()
    def validate(self, val_loader, adj, max_batches=None):
        """Evaluate on the validation set.

        ``max_batches`` optionally caps the number of batches (for fast
        RL-style evaluation). Returns ``(avg_loss, mae)``; ``(inf, inf)``
        when no batch was processed.
        """
        self.model.eval()
        total_loss = 0.0
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)

        for batch_idx, (data, target) in enumerate(val_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)       # add feature dim, mirrors train_epoch
            output = self.model(data, adj)  # (batch_size, num_nodes, num_targets)
            output = output.mean(dim=1)     # (batch_size, num_targets)
            loss = self.criterion(output, target)
            total_loss += loss.item()
            all_outputs.append(output.cpu().numpy())
            all_targets.append(target.cpu().numpy())

        if len(all_outputs) == 0:
            return float('inf'), float('inf')

        num_batches = len(all_outputs)  # exact count of processed batches
        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mae = mean_absolute_error(all_targets, all_outputs)
        avg_loss = total_loss / num_batches
        self.val_losses.append(avg_loss)
        self.val_mae.append(mae)
        return avg_loss, mae

    def train(self, train_loader, val_loader, adj, epochs=None):
        """Full training loop with checkpointing and early stopping.

        Saves the best model (lowest validation loss) to
        ``self.model_save_path``, reloads it after training, persists it as
        the final model and returns the model.
        """
        adj = adj.to(self.device)
        epochs = epochs if epochs is not None else self.args.epochs
        self.logger.info(f"开始训练,共 {epochs} 个epoch")

        for epoch in tqdm(range(epochs)):
            train_loss, train_mae = self.train_epoch(train_loader, adj)
            val_loss, val_mae = self.validate(val_loader, adj)
            if self.scheduler is not None:
                # Assumes a ReduceLROnPlateau-style scheduler stepped on the
                # validation metric — TODO confirm against the caller.
                self.scheduler.step(val_loss)

            # Checkpoint on improvement; otherwise count towards early stopping.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': val_loss,
                    'args': self.args  # keep the training configuration with the weights
                }, self.model_save_path)
                self.early_stop_counter = 0
                self.logger.info(f"已更新最佳模型到 {self.model_save_path}")
            else:
                self.early_stop_counter += 1
                if self.early_stop_counter >= self.args.patience:
                    self.logger.info(f"早停机制触发,在第 {epoch+1} 轮停止训练")
                    break

            if (epoch + 1) % 10 == 0:
                self.logger.info(f'Epoch {epoch+1}/{epochs}, '
                                 f'Train Loss: {train_loss:.6f}, Train MAE: {train_mae:.6f}, '
                                 f'Val Loss: {val_loss:.6f}, Val MAE: {val_mae:.6f}')

        self.plot_losses()
        self.plot_mae()

        # Reload the best checkpoint. Guarded: if validation never improved
        # (e.g. the loss was NaN in every epoch) no checkpoint file exists and
        # the original unconditional load would crash with FileNotFoundError.
        if os.path.exists(self.model_save_path):
            # weights_only=False: the checkpoint stores ``args`` (a Namespace),
            # which torch>=2.6 refuses to unpickle under its new default.
            checkpoint = torch.load(self.model_save_path, map_location=self.device,
                                    weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.logger.info(f"加载最佳模型(第 {checkpoint['epoch']+1} 轮,验证损失: {checkpoint['loss']:.6f})")

        # Persist the final model (best weights, if a checkpoint was loaded).
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'args': self.args
        }, self.final_model_path)
        self.logger.info(f"已保存最终模型到 {self.final_model_path}")

        return self.model

    @torch.no_grad()
    def test(self, test_loader, adj):
        """Evaluate on the test set.

        Reports MSE/MAE/RMSE on the normalized scale and — via the
        preprocessor's inverse transform — on the original scale, writes the
        prediction-example plot, and returns a dict of metrics plus the
        original-scale predictions and targets.
        """
        self.model.eval()
        total_loss = 0.0
        all_outputs = []
        all_targets = []
        adj = adj.to(self.device)

        for data, target in test_loader:
            data, target = data.to(self.device), target.to(self.device)
            data = data.unsqueeze(-1)
            output = self.model(data, adj)
            output = output.mean(dim=1)
            loss = self.criterion(output, target)
            total_loss += loss.item()
            all_outputs.append(output.cpu().numpy())
            all_targets.append(target.cpu().numpy())

        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        mse = total_loss / len(test_loader)
        mae = mean_absolute_error(all_targets, all_outputs)
        rmse = np.sqrt(mean_squared_error(all_targets, all_outputs))

        # Metrics on the original (de-normalized) scale.
        all_outputs_original = self.preprocessor.inverse_transform_targets(all_outputs)
        all_targets_original = self.preprocessor.inverse_transform_targets(all_targets)
        original_mse = mean_squared_error(all_targets_original, all_outputs_original)
        original_mae = mean_absolute_error(all_targets_original, all_outputs_original)
        original_rmse = np.sqrt(original_mse)

        self.logger.info(f'Test Loss (normalized): MSE={mse:.6f}, MAE={mae:.6f}, RMSE={rmse:.6f}')
        self.logger.info(f'Test Loss (original scale): MSE={original_mse:.6f}, MAE={original_mae:.6f}, RMSE={original_rmse:.6f}')
        self.plot_predictions(all_outputs_original, all_targets_original)
        return {
            'normalized_mse': mse,
            'normalized_mae': mae,
            'normalized_rmse': rmse,
            'original_mse': original_mse,
            'original_mae': original_mae,
            'original_rmse': original_rmse,
            'predictions': all_outputs_original,
            'targets': all_targets_original
        }

    def plot_losses(self):
        """Save the train/validation loss curves to plots/loss_curve.png."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.savefig('plots/loss_curve.png')
        plt.close()

    def plot_mae(self):
        """Save the train/validation MAE curves to plots/mae_curve.png."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_mae, label='Train MAE')
        plt.plot(self.val_mae, label='Validation MAE')
        plt.xlabel('Epoch')
        plt.ylabel('MAE')
        plt.title('Training and Validation MAE')
        plt.legend()
        plt.savefig('plots/mae_curve.png')
        plt.close()

    def plot_predictions(self, predictions, targets):
        """Plot the first 100 steps of up to three targets (predicted vs. true)."""
        num_plots = min(3, self.args.num_targets)
        plt.figure(figsize=(15, 5*num_plots))
        for i in range(num_plots):
            plt.subplot(num_plots, 1, i+1)
            plt.plot(targets[:100, i], label='True Value')
            plt.plot(predictions[:100, i], label='Predicted Value')
            plt.xlabel('Time Step')
            plt.ylabel(f'Target {i+1}')
            plt.title(f'Prediction vs True Value for Target {i+1}')
            plt.legend()
        plt.tight_layout()
        plt.savefig('plots/prediction_examples.png')
        plt.close()