# data_trainer.py
import copy

import torch
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score
from datetime import datetime, timedelta
from sklearn.preprocessing import MinMaxScaler

from config import config


class Trainer:
    def __init__(self, model, data):
        self.device = torch.device(f"cuda:{config.DEVICE_ID}" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)  # keep the model on the same device as the batches
        self.data = data
        self.best_val_loss = float('inf')
        self.best_model_state = None

    def train_full_model(self, train_loader, val_loader, optimizer, criterion, scheduler):
        counter = 0  # epochs without sufficient validation improvement
        for epoch in range(config.EPOCHS):
            self.model.train()
            running_loss = 0.0

            for inputs, targets in train_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            train_loss = running_loss / len(train_loader)
            val_loss = self.validate_full(val_loader, criterion) if val_loader else 0.0
            print(f'Epoch {epoch+1}/{config.EPOCHS}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')

            if val_loader:
                if val_loss < (self.best_val_loss - config.MIN_DELTA):
                    self.best_val_loss = val_loss
                    counter = 0
                    # state_dict() returns references to live tensors, so keep a deep copy
                    self.best_model_state = copy.deepcopy(self.model.state_dict())
                else:
                    counter += 1
                    if counter >= config.PATIENCE:
                        print("Early stopping triggered")
                        break

            scheduler.step()
            torch.cuda.empty_cache()

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
        return self.model

    def validate_full(self, val_loader, criterion):
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                total_loss += loss.item()
        return total_loss / len(val_loader)

    def save_model(self):
        torch.save(self.model.state_dict(), config.MODEL_PATH)
        print(f"Model saved to: {config.MODEL_PATH}")

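    # Loading the checkpoint back is the mirror of save_model(); a minimal
    # sketch, assuming the saved architecture (here `MyModel`, a placeholder
    # name) can be re-instantiated with the same hyperparameters:
    #
    #     model = MyModel(...)  # hypothetical constructor
    #     model.load_state_dict(torch.load(config.MODEL_PATH, map_location="cpu"))
    #     model.eval()
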
    def evaluate_model(self, test_loader):
        self.model.eval()
        scaler = joblib.load(config.SCALER_PATH)
        predictions, true_values = [], []

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(self.device)
                predictions.append(self.model(inputs).cpu().numpy())
                true_values.append(targets.cpu().numpy())

        predictions = np.concatenate(predictions, axis=0)
        true_values = np.concatenate(true_values, axis=0)

        predictions = predictions.reshape(-1, config.LABELS_NUM)
        true_values = true_values.reshape(-1, config.LABELS_NUM)

        # The saved scaler was fit on all features; rebuild one restricted to
        # the last LABELS_NUM columns (the targets) so only those columns are
        # inverse-transformed.
        column_scaler = MinMaxScaler(feature_range=(0, 1))
        column_scaler.min_ = scaler.min_[-config.LABELS_NUM:]
        column_scaler.scale_ = scaler.scale_[-config.LABELS_NUM:]

        true_values = column_scaler.inverse_transform(true_values)
        predictions = column_scaler.inverse_transform(predictions)

        start_datetime = datetime.strptime(config.TEST_START_DATE, "%Y-%m-%d")
        time_interval = timedelta(minutes=(4 * config.RESOLUTION / 60))
        date_times = [start_datetime + i * time_interval for i in range(len(predictions) // config.OUTPUT_SIZE)]

        # Repeat each timestamp so the time axis lines up with the flattened outputs
        date_times = np.repeat(date_times, config.OUTPUT_SIZE)

        results = pd.DataFrame({'date': date_times[:len(predictions)]})
        metrics_details = []

        for i, col_name in enumerate(config.TARGET_COLUMNS):
            results[f'{col_name}_True'] = true_values[:, i]
            results[f'{col_name}_Predicted'] = predictions[:, i]

            var_true = true_values[:, i]
            var_pred = predictions[:, i]

            # Exclude zero ground-truth values so MAPE stays well-defined
            mask = var_true != 0
            if mask.sum() > 0:
                r2 = r2_score(var_true[mask], var_pred[mask])
                rmse = np.sqrt(np.mean((var_true[mask] - var_pred[mask]) ** 2))
                mape = np.mean(np.abs((var_true[mask] - var_pred[mask]) / np.abs(var_true[mask]))) * 100
                metrics_details.append(f"{col_name}: R2={r2:.4f}, RMSE={rmse:.4f}, MAPE={mape:.4f}%")
            else:
                metrics_details.append(f"{col_name}: no valid data")

        results.to_csv(config.OUTPUT_CSV_PATH, index=False)
        with open(config.OUTPUT_CSV_PATH.replace('.csv', '_metrics.txt'), 'w') as f:
            f.write('\n'.join(metrics_details))
        return metrics_details
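

# A minimal end-to-end usage sketch, not part of the original project: it
# stands in a plain nn.Linear model and random tensors for the real model and
# data pipeline, and assumes config exposes the fields used above (EPOCHS,
# MIN_DELTA, PATIENCE, DEVICE_ID, MODEL_PATH). All shapes and hyperparameters
# below are illustrative.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Synthetic stand-in data: 256 samples, 16 features, 4 targets.
    X = torch.randn(256, 16)
    y = torch.randn(256, 4)
    train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32)
    val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

    model = nn.Linear(16, 4)  # placeholder for the real forecasting model
    trainer = Trainer(model, data=None)  # Trainer moves the model to its device

    # Build the optimizer after the model is on its device.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    trainer.train_full_model(train_loader, val_loader, optimizer, criterion, scheduler)
    trainer.save_model()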