# data_trainer.py
import copy
from datetime import datetime, timedelta

import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from sklearn.preprocessing import MinMaxScaler
  9. class Trainer:
  10. def __init__(self, model, args, data):
  11. self.args = args
  12. self.model = model
  13. self.data = data
  14. self.patience = args.patience
  15. self.min_delta = args.min_delta
  16. self.counter = 0
  17. self.early_stop = False
  18. self.best_val_loss = float('inf')
  19. self.best_model_state = None
  20. self.best_epoch = 0
  21. def train_full_model(self, train_loader, val_loader, optimizer, criterion, scheduler):
  22. self.counter = 0
  23. self.best_val_loss = float('inf')
  24. self.early_stop = False
  25. self.best_model_state = None
  26. self.best_epoch = 0
  27. max_epochs = self.args.epochs
  28. for epoch in range(max_epochs):
  29. self.model.train()
  30. running_loss = 0.0
  31. for inputs, targets in train_loader:
  32. inputs = inputs.to(self.args.device)
  33. targets = targets.to(self.args.device)
  34. optimizer.zero_grad()
  35. outputs = self.model(inputs)
  36. loss = criterion(outputs, targets)
  37. loss.backward()
  38. optimizer.step()
  39. running_loss += loss.item()
  40. train_loss = running_loss / len(train_loader)
  41. val_loss = self.validate_full(val_loader, criterion) if val_loader else 0.0
  42. print(f'Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')
  43. if val_loader:
  44. if val_loss < (self.best_val_loss - self.min_delta):
  45. self.best_val_loss = val_loss
  46. self.counter = 0
  47. self.best_model_state = self.model.state_dict()
  48. self.best_epoch = epoch
  49. else:
  50. self.counter += 1
  51. if self.counter >= self.patience:
  52. self.early_stop = True
  53. print(f"早停触发")
  54. scheduler.step()
  55. torch.cuda.empty_cache()
  56. if self.early_stop:
  57. break
  58. if self.best_model_state is not None:
  59. self.model.load_state_dict(self.best_model_state)
  60. print(f"最佳迭代: {self.best_epoch+1}, 最佳验证损失: {self.best_val_loss:.6f}")
  61. return self.model
  62. def validate_full(self, val_loader, criterion):
  63. self.model.eval()
  64. total_loss = 0.0
  65. with torch.no_grad():
  66. for inputs, targets in val_loader:
  67. inputs = inputs.to(self.args.device)
  68. targets = targets.to(self.args.device)
  69. outputs = self.model(inputs)
  70. loss = criterion(outputs, targets)
  71. total_loss += loss.item()
  72. return total_loss / len(val_loader)
  73. def save_model(self):
  74. torch.save(self.model.state_dict(), self.args.model_path)
  75. print(f"模型已保存到:{self.args.model_path}")
  76. def evaluate_model(self, test_loader, criterion):
  77. self.model.eval()
  78. scaler = joblib.load(self.args.scaler_path)
  79. predictions = []
  80. true_values = []
  81. with torch.no_grad():
  82. for inputs, targets in test_loader:
  83. inputs = inputs.to(self.args.device)
  84. targets = targets.to(self.args.device)
  85. outputs = self.model(inputs)
  86. predictions.append(outputs.cpu().numpy())
  87. true_values.append(targets.cpu().numpy())
  88. predictions = np.concatenate(predictions, axis=0)
  89. true_values = np.concatenate(true_values, axis=0)
  90. # 重塑
  91. reshaped_predictions = predictions.reshape(predictions.shape[0], self.args.output_size, self.args.labels_num)
  92. predictions = reshaped_predictions.reshape(-1, self.args.labels_num)
  93. reshaped_true_values = true_values.reshape(true_values.shape[0], self.args.output_size, self.args.labels_num)
  94. true_values = reshaped_true_values.reshape(-1, self.args.labels_num)
  95. # 反归一化 (仅标签列)
  96. column_scaler = MinMaxScaler(feature_range=(0, 1))
  97. column_scaler.min_ = scaler.min_[-self.args.labels_num:]
  98. column_scaler.scale_ = scaler.scale_[-self.args.labels_num:]
  99. true_values = column_scaler.inverse_transform(true_values)
  100. predictions = column_scaler.inverse_transform(predictions)
  101. # 定义4个核心变量
  102. column_names = [
  103. "ns=3;s=UF_TMP", # SSD跨膜压差
  104. "ns=3;s=RO_CHA1YL_SSD", # SSD_PressCha1
  105. "ns=3;s=RO_CHA2YL_SSD", # SSD_PressCha2
  106. "ns=3;s=RO_ZCS_SSD", # SSD_Flow_zcs
  107. ]
  108. # 生成时间
  109. start_datetime = datetime.strptime(self.args.test_start_date, "%Y-%m-%d")
  110. time_interval = timedelta(minutes=(4 * self.args.resolution / 60))
  111. total_points = len(predictions)
  112. date_times = [start_datetime + i * time_interval for i in range(total_points)]
  113. results = pd.DataFrame({'date': date_times})
  114. metrics_details = []
  115. for i, col_name in enumerate(column_names):
  116. if i >= self.args.labels_num: break # 防止越界
  117. results[f'{col_name}_True'] = true_values[:, i]
  118. results[f'{col_name}_Predicted'] = predictions[:, i]
  119. var_true = true_values[:, i]
  120. var_pred = predictions[:, i]
  121. # 指标计算
  122. non_zero_mask = var_true != 0
  123. var_true_nonzero = var_true[non_zero_mask]
  124. var_pred_nonzero = var_pred[non_zero_mask]
  125. if len(var_true_nonzero) > 0:
  126. r2 = r2_score(var_true_nonzero, var_pred_nonzero)
  127. rmse = np.sqrt(np.mean((var_true_nonzero - var_pred_nonzero) ** 2))
  128. mape = np.mean(np.abs((var_true_nonzero - var_pred_nonzero) / np.abs(var_true_nonzero))) * 100
  129. metrics_details.append(f"{col_name}: R2={r2:.4f}, RMSE={rmse:.4f}, MAPE={mape:.4f}%")
  130. else:
  131. metrics_details.append(f"{col_name}: 无效数据")
  132. results.to_csv(self.args.output_csv_path, index=False)
  133. txt_path = self.args.output_csv_path.replace('.csv', '_metrics.txt')
  134. with open(txt_path, 'w') as f:
  135. f.write('\n'.join(metrics_details))
  136. return metrics_details