# Fine-tune a pretrained torchvision model on a custom image-folder dataset
# to solve a classification task.
import os
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # TensorBoard support
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights, \
    squeezenet1_0, SqueezeNet1_0_Weights, \
    shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights, \
    swin_v2_s, Swin_V2_S_Weights, swin_v2_b, Swin_V2_B_Weights

# NOTE(review): this pins the process to the GPU with index 1 and overrides any
# value already present in the environment. On a machine without that device
# torch.cuda.is_available() is expected to be False and training falls back to
# the CPU (see Trainer.__init__).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Supported model names -> (constructor, pretrained weights).
# 'resnet18' is listed because it is the default value of the --model option.
_PRETRAINED_MODELS = {
    'resnet18': (resnet18, ResNet18_Weights.IMAGENET1K_V1),
    'resnet50': (resnet50, ResNet50_Weights.IMAGENET1K_V2),
    'squeezenet': (squeezenet1_0, SqueezeNet1_0_Weights.IMAGENET1K_V1),
    'shufflenet': (shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1),
    'swin_v2_s': (swin_v2_s, Swin_V2_S_Weights.IMAGENET1K_V1),
    'swin_v2_b': (swin_v2_b, Swin_V2_B_Weights.IMAGENET1K_V1),
}


class Trainer:
    """Fine-tune the classification head of a pretrained torchvision model.

    The backbone is frozen; only the freshly created classification layers are
    trained. Metrics are written to TensorBoard and the best weights (by
    validation accuracy and by validation loss) are saved to the working
    directory.
    """

    def __init__(self, batch_size, train_dir, val_dir, name, checkpoint, num_workers=10):
        """Build the data pipeline, model, loss, optimizer and LR scheduler.

        Args:
            batch_size: Batch size used by both data loaders.
            train_dir: Root of the training set in ``ImageFolder`` layout.
            val_dir: Root of the validation set in ``ImageFolder`` layout.
            name: Model name; must be a key of ``_PRETRAINED_MODELS``.
            checkpoint: Stored on the instance; not used anywhere in this class.
            num_workers: Worker processes per data loader (default 10, the
                value that was previously hard-coded).

        Raises:
            ValueError: If ``name`` is unknown, or the model exposes no
                classifier layer this class knows how to replace.
        """
        # Validate the model name first so that a typo does not leave an empty
        # TensorBoard run directory behind or waste time scanning the datasets.
        if name not in _PRETRAINED_MODELS:
            raise ValueError(f"Invalid model name: {name}")

        self.name = name
        self.checkpoint = checkpoint

        # TensorBoard writer; the timestamp keeps runs in separate directories.
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        log_dir = f'runs/turbidity_{self.name}_{timestamp}'
        self.writer = SummaryWriter(log_dir)

        # Training-time augmentation and preprocessing.
        self.train_transforms = transforms.Compose([
            transforms.Resize((256, 256)),  # resize every image to 256x256
            transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
            transforms.RandomRotation(10),  # random rotation within +/-10 degrees
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0),  # brightness/contrast jitter only
            transforms.ToTensor(),  # convert to tensor scaled to [0, 1]
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
        ])

        # Validation uses the same preprocessing without augmentation.
        self.val_transforms = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Datasets.
        self.train_dataset = ImageFolder(root=train_dir, transform=self.train_transforms)
        self.val_dataset = ImageFolder(root=val_dir, transform=self.val_transforms)

        # Data loaders. Pass num_workers=0 if multiprocessing loaders cause
        # trouble on the host platform (e.g. Windows).
        self.batch_size = batch_size
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=num_workers)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size,
                                     shuffle=False, num_workers=num_workers)

        # Number of classes is taken from the training folder structure.
        self.num_classes = len(self.train_dataset.classes)
        print(f"发现 {self.num_classes} 个类别: {self.train_dataset.classes}")

        # Load the pretrained model.
        builder, self.weights = _PRETRAINED_MODELS[name]
        self.model = builder(weights=self.weights)
        print(self.model)

        # Freeze the pretrained feature extractor ...
        for param in self.model.parameters():
            param.requires_grad = False

        # ... and replace the classification layer; new layers are trainable.
        self._replace_classifier()

        # Move the model to the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # Cross-entropy loss for multi-class classification.
        self.loss = nn.CrossEntropyLoss()

        # Hand only the trainable (requires_grad=True) parameters to Adam.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(trainable_params, lr=1e-3, weight_decay=1e-4)

        # Halve the learning rate when the validation loss stops improving.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            min_lr=1e-7
        )

    def _build_head(self, in_features):
        """Return the two-layer MLP head mapping ``in_features`` to class logits."""
        in_features = int(in_features)
        hidden = in_features // 2
        return nn.Sequential(
            nn.Linear(in_features, hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden, self.num_classes, bias=False)
        )

    def _replace_classifier(self):
        """Swap the model's final classification layer for one with ``num_classes`` outputs."""
        model = self.model
        if hasattr(model, 'fc'):  # ResNet / ShuffleNet
            model.fc = self._build_head(model.fc.in_features)
            return
        if hasattr(model, 'classifier'):
            classifier = model.classifier
            if isinstance(classifier, nn.Linear):
                model.classifier = self._build_head(classifier.in_features)
                return
            if isinstance(classifier, nn.Sequential):
                # SqueezeNet: the classifier is a Sequential that operates on a
                # 4-D feature map and has no `in_features`, so an MLP head does
                # not fit. Replace its last convolution instead.
                conv_positions = [i for i, layer in enumerate(classifier)
                                  if isinstance(layer, nn.Conv2d)]
                if conv_positions:
                    idx = conv_positions[-1]
                    old_conv = classifier[idx]
                    classifier[idx] = nn.Conv2d(old_conv.in_channels, self.num_classes,
                                                kernel_size=old_conv.kernel_size)
                    if hasattr(model, 'num_classes'):
                        model.num_classes = self.num_classes
                    return
        elif hasattr(model, 'head'):  # Swin Transformer
            model.head = self._build_head(model.head.in_features)
            return
        raise ValueError(f"Model {self.name} does not have recognizable classifier layer")

    def _run_epoch(self, loader, training):
        """Run one pass over ``loader`` and return ``(average_loss, accuracy)``.

        When ``training`` is true, gradients are computed and the optimizer is
        stepped after every batch; otherwise the pass is evaluation only.
        """
        # NOTE(review): train mode also applies to normalization layers inside
        # the frozen backbone, so their running statistics keep updating.
        self.model.train(training)

        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        with torch.set_grad_enabled(training):
            for inputs, labels in loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                if training:
                    self.optimizer.zero_grad()

                outputs = self.model(inputs)
                loss = self.loss(outputs, labels)

                if training:
                    loss.backward()
                    self.optimizer.step()

                # The loss is a batch mean; weight it by the batch size so the
                # epoch average is exact even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                predicted = outputs.argmax(dim=1)
                total_samples += labels.size(0)
                correct_predictions += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(loader.dataset)
        epoch_acc = correct_predictions / total_samples
        return epoch_loss, epoch_acc

    def train_model(self):
        """Train for one epoch.

        Returns:
            Tuple ``(average_loss, accuracy)`` over the training set.
        """
        return self._run_epoch(self.train_loader, training=True)

    def validate_model(self):
        """Evaluate on the validation set without computing gradients.

        Returns:
            Tuple ``(average_loss, accuracy)`` over the validation set.
        """
        return self._run_epoch(self.val_loader, training=False)

    def train_and_validate(self, num_epochs=25):
        """Run the full training loop.

        After every epoch the metrics are logged to TensorBoard, the scheduler
        is stepped with the validation loss, and the model weights are saved
        whenever validation accuracy or validation loss improves.

        Args:
            num_epochs: Number of epochs to train.

        Returns:
            Tuple ``(train_losses, train_accuracies, val_losses,
            val_accuracies)``, each a list with one entry per epoch.
        """
        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []

        best_val_acc = 0.0
        best_val_loss = float('inf')

        print("开始训练...")

        try:
            for epoch in range(num_epochs):
                print(f'Epoch {epoch + 1}/{num_epochs}')
                print('-' * 20)

                # Training phase.
                train_loss, train_acc = self.train_model()
                print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')

                # Validation phase.
                val_loss, val_acc = self.validate_model()
                print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

                # Learning-rate scheduling driven by the validation loss.
                self.scheduler.step(val_loss)

                # Log metrics to TensorBoard. The learning rate is read from the
                # optimizer because ReduceLROnPlateau.get_last_lr() does not
                # exist on older PyTorch releases.
                current_lr = self.optimizer.param_groups[0]['lr']
                self.writer.add_scalar('Loss/Train', train_loss, epoch)
                self.writer.add_scalar('Loss/Validation', val_loss, epoch)
                self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
                self.writer.add_scalar('Accuracy/Validation', val_acc, epoch)
                self.writer.add_scalar('Learning Rate', current_lr, epoch)

                train_losses.append(train_loss)
                train_accuracies.append(train_acc)
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)

                # Save the weights with the best validation accuracy so far.
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
                    print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")

                # Save the weights with the lowest validation loss so far.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
                    print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
        finally:
            # Close the writer even if an epoch fails so pending events are
            # written out.
            self.writer.close()

        print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}, 最低验证损失: {best_val_loss:.4f}")

        return train_losses, train_accuracies, val_losses, val_accuracies


if __name__ == '__main__':
    # Command-line entry point: parse the options and start training.
    import argparse

    parser = argparse.ArgumentParser('预训练模型调参')
    parser.add_argument('--train_dir', default='./label_data/train',
                        help='training set root (one sub-directory per class)')
    parser.add_argument('--val_dir', default='./label_data/test',
                        help='validation set root (one sub-directory per class)')
    parser.add_argument('--model', default='resnet18', choices=sorted(_PRETRAINED_MODELS),
                        help='pretrained model to fine-tune')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='batch size for both data loaders')
    args = parser.parse_args()

    num_epochs = args.epochs
    trainer = Trainer(batch_size=args.batch_size,
                      train_dir=args.train_dir,
                      val_dir=args.val_dir,
                      name=args.model,
                      checkpoint=False)
    train_losses, train_accuracies, val_losses, val_accuracies = trainer.train_and_validate(num_epochs)