# data_trainer.py
import copy
from datetime import datetime, timedelta

import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from sklearn.preprocessing import MinMaxScaler


class Trainer:
    """Train, validate, save and evaluate a multi-output PyTorch model.

    The trainer keeps early-stopping bookkeeping (patience counter, best
    validation loss, best weights) on the instance so that it can restore
    the best weights once training finishes.
    """

    def __init__(self, model, args, data):
        """Store the model/config and initialise early-stopping state.

        Args:
            model: The ``torch.nn.Module`` to train.
            args: Configuration object. This class reads ``patience``,
                ``min_delta``, ``epochs``, ``device``, ``model_path``,
                ``output_size``, ``labels_num``, ``test_start_date``,
                ``resolution`` and ``output_csv_path`` from it.
            data: Stored on the instance; not used by the methods here.
        """
        self.args = args
        self.model = model
        self.data = data

        # Early-stopping parameters and state.
        self.patience = args.patience
        self.min_delta = args.min_delta
        self.counter = 0
        self.early_stop = False
        self.best_val_loss = float('inf')
        self.best_model_state = None
        self.best_epoch = 0

    def train_full_model(self, train_loader, val_loader, optimizer, criterion, scheduler):
        """Jointly train the whole model with optional early stopping.

        Args:
            train_loader: Iterable of ``(inputs, targets)`` batches.
            val_loader: Iterable of validation batches, or a falsy value to
                disable validation and early stopping.
            optimizer: Optimizer stepped once per training batch.
            criterion: Loss callable ``criterion(outputs, targets)``.
            scheduler: Learning-rate scheduler stepped once per epoch.

        Returns:
            ``self.model``, loaded with the weights of the best validation
            epoch when validation was performed.
        """
        # Reset early-stopping state so the trainer can be reused.
        self.counter = 0
        self.best_val_loss = float('inf')
        self.early_stop = False
        self.best_model_state = None
        self.best_epoch = 0

        max_epochs = self.args.epochs

        for epoch in range(max_epochs):
            self.model.train()
            running_loss = 0.0

            for inputs, targets in train_loader:
                inputs = inputs.to(self.args.device)
                targets = targets.to(self.args.device)

                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            train_loss = running_loss / len(train_loader)
            val_loss = self.validate_full(val_loader, criterion) if val_loader else 0.0

            print(f'Epoch {epoch+1}/{max_epochs}, '
                  f'Train Loss: {train_loss:.6f}, '
                  f'Val Loss: {val_loss:.6f}, '
                  f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

            # Early stopping, driven by the overall validation loss.
            if val_loader:
                improved = val_loss < (self.best_val_loss - self.min_delta)
                if improved:
                    self.best_val_loss = val_loss
                    self.counter = 0
                    # state_dict() returns references to the live tensors,
                    # which later optimizer steps mutate in place. Take a
                    # deep copy so the snapshot really is the best epoch.
                    self.best_model_state = copy.deepcopy(self.model.state_dict())
                    self.best_epoch = epoch
                else:
                    self.counter += 1
                    if self.counter >= self.patience:
                        self.early_stop = True
                        print(f"早停触发")

            scheduler.step()
            torch.cuda.empty_cache()

            if self.early_stop:
                break

        # Restore the best weights seen during validation, if any.
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print(f"最佳迭代: {self.best_epoch+1}, 最佳验证损失: {self.best_val_loss:.6f}")

        return self.model

    def validate_full(self, val_loader, criterion):
        """Return the mean per-batch loss of the model over ``val_loader``.

        The model is switched to eval mode and gradients are disabled.
        """
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(self.args.device)
                targets = targets.to(self.args.device)

                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                total_loss += loss.item()

        return total_loss / len(val_loader)

    def save_model(self):
        """Save the model's state dict to ``self.args.model_path``."""
        torch.save(self.model.state_dict(), self.args.model_path)
        print(f"模型已保存到:{self.args.model_path}")

    @staticmethod
    def _metrics_txt_path(csv_path):
        """Derive the metrics text-file path from the CSV output path.

        Only a trailing ``.csv`` suffix is stripped, so directory names that
        contain ``.csv`` are left alone, and a path without the suffix still
        yields a distinct file instead of overwriting the CSV itself.
        """
        if csv_path.lower().endswith('.csv'):
            csv_path = csv_path[:-len('.csv')]
        return csv_path + '_metrics_results.txt'

    def evaluate_model(self, test_loader, criterion):
        """Predict on ``test_loader``, de-normalise, score and save results.

        Predictions and targets are inverse-transformed using the last
        ``labels_num`` columns of the scaler stored in ``scaler.pkl``. Per
        column R2, RMSE and MAPE are computed on samples whose true value is
        non-zero. The per-sample results are written to
        ``args.output_csv_path`` and a metrics summary to a sibling text file.

        Args:
            test_loader: Iterable of ``(inputs, targets)`` batches.
            criterion: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(r2_scores, rmse_scores, mape_scores)`` of dicts keyed by
            column name. Columns without any non-zero true value are omitted.
        """
        self.model.eval()

        # NOTE(review): joblib.load unpickles the file; only load scaler
        # files from a trusted location.
        scaler_path = 'scaler.pkl'
        scaler = joblib.load(scaler_path)

        predictions = []
        true_values = []
        device = self.args.device

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = self.model(inputs)

                predictions.append(outputs.cpu().numpy())
                true_values.append(targets.cpu().numpy())

        predictions = np.concatenate(predictions, axis=0)
        true_values = np.concatenate(true_values, axis=0)

        # Reshape to (samples * output_size, labels_num) so that every row
        # matches the column layout expected by the inverse transform.
        reshaped_predictions = predictions.reshape(
            predictions.shape[0], self.args.output_size, self.args.labels_num)
        predictions = reshaped_predictions.reshape(-1, self.args.labels_num)

        reshaped_true_values = true_values.reshape(
            true_values.shape[0], self.args.output_size, self.args.labels_num)
        true_values = reshaped_true_values.reshape(-1, self.args.labels_num)

        # De-normalise the label columns only: build a scaler that carries
        # just the parameters of the last ``labels_num`` fitted columns.
        column_scaler = MinMaxScaler(feature_range=(0, 1))
        column_scaler.min_ = scaler.min_[-self.args.labels_num:]
        column_scaler.scale_ = scaler.scale_[-self.args.labels_num:]

        true_values = column_scaler.inverse_transform(true_values)
        predictions = column_scaler.inverse_transform(predictions)

        # NOTE(review): the original comments mention 16 dependent variables
        # but only 8 names are listed; columns beyond these are not reported,
        # and labels_num < 8 would raise IndexError below. Confirm intent.
        column_names = [
            'C.M.RO1_DB@DPT_1', 'C.M.RO2_DB@DPT_1', 'C.M.RO3_DB@DPT_1', 'C.M.RO4_DB@DPT_1',
            'C.M.RO1_DB@DPT_2', 'C.M.RO2_DB@DPT_2', 'C.M.RO3_DB@DPT_2', 'C.M.RO4_DB@DPT_2',
        ]

        # Build the timestamp of every output row, starting at the test start
        # date. NOTE(review): the step is 4 * resolution / 60 minutes —
        # presumably resolution is in seconds; confirm against the data loader.
        start_datetime = datetime.strptime(self.args.test_start_date, "%Y-%m-%d")
        time_interval = timedelta(minutes=(4 * self.args.resolution / 60))
        total_points = len(predictions)
        date_times = [start_datetime + i * time_interval for i in range(total_points)]

        results = pd.DataFrame({'date': date_times})

        r2_scores = {}
        rmse_scores = {}
        mape_scores = {}
        metrics_details = []

        for i, col_name in enumerate(column_names):
            results[f'{col_name}_True'] = true_values[:, i]
            results[f'{col_name}_Predicted'] = predictions[:, i]

            var_true = true_values[:, i]
            var_pred = predictions[:, i]

            # Drop samples whose true value is zero to avoid dividing by
            # zero in MAPE; the same subset is used for R2 and RMSE.
            non_zero_mask = var_true != 0
            var_true_nonzero = var_true[non_zero_mask]
            var_pred_nonzero = var_pred[non_zero_mask]

            r2 = float('nan')
            rmse = float('nan')
            mape = float('nan')

            if len(var_true_nonzero) > 0:
                r2 = r2_score(var_true_nonzero, var_pred_nonzero)
                rmse = np.sqrt(np.mean((var_true_nonzero - var_pred_nonzero) ** 2))
                mape = np.mean(
                    np.abs((var_true_nonzero - var_pred_nonzero) / np.abs(var_true_nonzero))
                ) * 100

                r2_scores[col_name] = r2
                rmse_scores[col_name] = rmse
                mape_scores[col_name] = mape

                detail = f"{col_name}:\n R方 = {r2:.6f}\n RMSE = {rmse:.6f}\n MAPE = {mape:.6f}%"
                metrics_details.append(detail)
                print(f"{col_name} R方: {r2:.6f}")
            else:
                metrics_details.append(f"{col_name}: 没有有效数据用于计算指标")
                print(f"{col_name} 没有有效数据用于计算R方")

        # Average each metric over the columns that produced a finite value.
        valid_r2 = [score for score in r2_scores.values() if not np.isnan(score)]
        valid_rmse = [score for score in rmse_scores.values() if not np.isnan(score)]
        valid_mape = [score for score in mape_scores.values() if not np.isnan(score)]

        avg_r2 = np.mean(valid_r2) if valid_r2 else float('nan')
        avg_rmse = np.mean(valid_rmse) if valid_rmse else float('nan')
        avg_mape = np.mean(valid_mape) if valid_mape else float('nan')

        avg_detail = f"\n平均指标:\n R方 = {avg_r2:.6f}\n RMSE = {avg_rmse:.6f}\n MAPE = {avg_mape:.6f}%"
        if np.isnan(avg_r2):
            avg_detail = "\n平均指标: 没有有效的指标可用于计算平均值"

        metrics_details.append(avg_detail)
        print(avg_detail)

        # Persist the per-sample predictions.
        results.to_csv(self.args.output_csv_path, index=False)
        print(f"预测结果已保存到:{self.args.output_csv_path}")

        # Persist the metrics summary next to the CSV. The text contains
        # non-ASCII characters, so the encoding is fixed to UTF-8 rather
        # than left to the platform default.
        txt_path = self._metrics_txt_path(self.args.output_csv_path)
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write("各变量预测指标结果:\n")
            f.write("===================\n\n")
            for detail in metrics_details:
                f.write(detail + '\n')

        print(f"预测指标结果已保存到:{txt_path}")

        return r2_scores, rmse_scores, mape_scores