# train.py
  1. # 微调pytorch的预训练模型,在自己的数据上训练,完成分类任务。
  2. import time
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from torch.utils.data import DataLoader
  7. import torchvision.transforms as transforms
  8. from torchvision.datasets import ImageFolder
  9. from torchvision.models import resnet18, ResNet18_Weights,resnet50,ResNet50_Weights, squeezenet1_0, SqueezeNet1_0_Weights,\
  10. shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights, swin_v2_s, Swin_V2_S_Weights, swin_v2_b, Swin_V2_B_Weights
  11. import matplotlib.pyplot as plt
  12. import numpy as np
  13. from torch.utils.tensorboard import SummaryWriter # 添加 TensorBoard 支持
  14. from datetime import datetime
  15. import os
  16. os.environ['CUDA_VISIBLE_DEVICES'] = '1'
  17. class Trainer:
  18. def __init__(self, batch_size, train_dir, val_dir, name, checkpoint):
  19. # 初始化 TensorBoard writer
  20. self.name = name
  21. self.checkpoint = checkpoint
  22. # 获取当前时间戳
  23. timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
  24. log_dir = f'runs/turbidity_{self.name}_{timestamp}'
  25. self.writer = SummaryWriter(log_dir)
  26. # 定义数据增强和处理
  27. self.train_transforms = transforms.Compose([
  28. transforms.Resize((256, 256)), # 调整图像大小为256x256 (ResNet输入尺寸)
  29. transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转,增加数据多样性
  30. transforms.RandomRotation(10), # 随机旋转±10度
  31. transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0), # 颜色抖动
  32. transforms.ToTensor(), # 转换为tensor并归一化到[0,1]
  33. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
  34. ])
  35. # 测试集基础变换
  36. self.val_transforms = transforms.Compose([
  37. transforms.Resize((256, 256)),
  38. transforms.ToTensor(),
  39. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  40. ])
  41. # 创建数据集对象
  42. self.train_dataset = ImageFolder(root=train_dir, transform=self.train_transforms)
  43. self.val_dataset = ImageFolder(root=val_dir, transform=self.val_transforms)
  44. # 创建数据加载器 (Windows环境下设置num_workers=0避免多进程问题)
  45. self.batch_size = batch_size
  46. self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=10)
  47. self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, num_workers=10)
  48. # 获取类别数量
  49. self.num_classes = len(self.train_dataset.classes)
  50. print(f"发现 {self.num_classes} 个类别: {self.train_dataset.classes}")
  51. # 加载模型
  52. if name == 'resnet50':
  53. self.weights = ResNet50_Weights.IMAGENET1K_V2
  54. self.model = resnet50(weights=self.weights)
  55. elif name == 'squeezenet':
  56. self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1
  57. self.model = squeezenet1_0(weights=self.weights)
  58. elif name == 'shufflenet':
  59. self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
  60. self.model = shufflenet_v2_x1_0(weights=self.weights)
  61. elif name == 'swin_v2_s':
  62. self.weights = Swin_V2_S_Weights.IMAGENET1K_V1
  63. self.model = swin_v2_s(weights=self.weights)
  64. elif name == 'swin_v2_b':
  65. self.weights = Swin_V2_B_Weights.IMAGENET1K_V1
  66. self.model = swin_v2_b(weights=self.weights)
  67. else:
  68. raise ValueError(f"Invalid model name: {name}")
  69. print(self.model)
  70. # 冻结特征提取层,只训练最后几层,
  71. for param in self.model.parameters():
  72. param.requires_grad = False
  73. # 替换最后的分类层以适应新的分类任务
  74. if hasattr(self.model, 'fc'):
  75. # ResNet系列模型
  76. self.model.fc = nn.Sequential(
  77. nn.Linear(int(self.model.fc.in_features), int(self.model.fc.in_features) // 2, bias=True),
  78. nn.ReLU(inplace=True),
  79. nn.Dropout(0.5),
  80. nn.Linear(int(self.model.fc.in_features) // 2, self.num_classes, bias=False)
  81. )
  82. elif hasattr(self.model, 'classifier'):
  83. # Swin Transformer等模型
  84. self.model.classifier = nn.Sequential(
  85. nn.Linear(int(self.model.classifier.in_features), int(self.model.classifier.in_features) // 2,
  86. bias=True),
  87. nn.ReLU(inplace=True),
  88. nn.Dropout(0.5),
  89. nn.Linear(int(self.model.classifier.in_features) // 2, self.num_classes, bias=False)
  90. )
  91. elif hasattr(self.model, 'head'):
  92. # Swin Transformer使用head层
  93. in_features = self.model.head.in_features
  94. self.model.head = nn.Sequential(
  95. nn.Linear(int(in_features), int(in_features) // 2, bias=True),
  96. nn.ReLU(inplace=True),
  97. nn.Dropout(0.5),
  98. nn.Linear(int(in_features) // 2, self.num_classes, bias=False)
  99. )
  100. else:
  101. raise ValueError(f"Model {name} does not have recognizable classifier layer")
  102. # 将模型移动到GPU
  103. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  104. self.model = self.model.to(self.device)
  105. # 定义损失函数
  106. self.loss = nn.CrossEntropyLoss() # 多分类常用的交叉熵损失
  107. # 定义优化器
  108. # 只更新requires_grad=True的参数
  109. self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
  110. # 基于验证损失动态调整,更智能
  111. self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  112. self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7
  113. )
  114. def train_model(self):
  115. """
  116. 单轮训练函数
  117. Args:
  118. Returns:
  119. average_loss: 平均损失
  120. accuracy: 准确率
  121. """
  122. self.model.train() # 设置模型为训练模式(启用dropout/batchnorm等)
  123. running_loss = 0.0
  124. correct_predictions = 0
  125. total_samples = 0
  126. # 遍历训练数据
  127. for inputs, labels in self.train_loader:
  128. # 将数据移到指定设备上
  129. inputs = inputs.to(self.device)
  130. labels = labels.to(self.device)
  131. # 清零梯度缓存
  132. self.optimizer.zero_grad()
  133. # 前向传播
  134. outputs = self.model(inputs)
  135. loss = self.loss(outputs, labels)
  136. # 反向传播
  137. loss.backward()
  138. # 更新参数
  139. self.optimizer.step()
  140. # 统计信息
  141. running_loss += loss.item() * inputs.size(0)
  142. _, predicted = torch.max(outputs.data, 1)
  143. total_samples += labels.size(0)
  144. correct_predictions += (predicted == labels).sum().item()
  145. epoch_loss = running_loss / len(self.train_loader.dataset)
  146. epoch_acc = correct_predictions / total_samples
  147. return epoch_loss, epoch_acc
  148. def validate_model(self):
  149. """
  150. 验证模型性能
  151. Args:
  152. Returns:
  153. average_loss: 平均损失
  154. accuracy: 准确率
  155. """
  156. self.model.eval() # 设置模型为评估模式(关闭dropout/batchnorm等)
  157. running_loss = 0.0
  158. correct_predictions = 0
  159. total_samples = 0
  160. # 不计算梯度,提高推理速度
  161. with torch.no_grad():
  162. for inputs, labels in self.val_loader:
  163. inputs = inputs.to(self.device)
  164. labels = labels.to(self.device)
  165. outputs = self.model(inputs)
  166. loss = self.loss(outputs, labels)
  167. running_loss += loss.item() * inputs.size(0)
  168. _, predicted = torch.max(outputs.data, 1)
  169. total_samples += labels.size(0)
  170. correct_predictions += (predicted == labels).sum().item()
  171. epoch_loss = running_loss / len(self.val_loader.dataset)
  172. epoch_acc = correct_predictions / total_samples
  173. return epoch_loss, epoch_acc
  174. def train_and_validate(self, num_epochs=25):
  175. """
  176. 训练和验证
  177. Args:
  178. num_epochs: 训练轮数
  179. Returns:
  180. train_losses: 每轮训练损失
  181. train_accuracies: 每轮训练准确率
  182. val_losses: 每轮验证损失
  183. val_accuracies: 每轮验证准确率
  184. """
  185. # 存储训练过程中的指标
  186. train_losses = []
  187. train_accuracies = []
  188. val_losses = []
  189. val_accuracies = []
  190. best_val_acc = 0.0
  191. best_val_loss = float('inf')
  192. print("开始训练...")
  193. for epoch in range(num_epochs):
  194. print(f'Epoch {epoch + 1}/{num_epochs}')
  195. print('-' * 20)
  196. # 训练阶段
  197. train_loss, train_acc = self.train_model()
  198. print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
  199. # 验证阶段
  200. val_loss, val_acc = self.validate_model()
  201. print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
  202. # 学习率调度
  203. self.scheduler.step(val_loss)
  204. # 记录指标到 TensorBoard
  205. self.writer.add_scalar('Loss/Train', train_loss, epoch)
  206. self.writer.add_scalar('Loss/Validation', val_loss, epoch)
  207. self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
  208. self.writer.add_scalar('Accuracy/Validation', val_acc, epoch)
  209. self.writer.add_scalar('Learning Rate', self.scheduler.get_last_lr()[0], epoch)
  210. # 记录指标
  211. train_losses.append(train_loss)
  212. train_accuracies.append(train_acc)
  213. val_losses.append(val_loss)
  214. val_accuracies.append(val_acc)
  215. # 保存最佳模型 (基于验证准确率)
  216. if val_acc > best_val_acc:
  217. best_val_acc = val_acc
  218. torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
  219. print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
  220. # 保存最低验证损失模型
  221. if val_loss < best_val_loss:
  222. best_val_loss = val_loss
  223. torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
  224. print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
  225. # 关闭 TensorBoard writer
  226. self.writer.close()
  227. print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}, 最低验证损失: {best_val_loss:.4f}")
  228. return train_losses, train_accuracies, val_losses, val_accuracies
  229. if __name__ == '__main__':
  230. # 开始训练
  231. import argparse
  232. parser = argparse.ArgumentParser('预训练模型调参')
  233. parser.add_argument('--train_dir',default='./label_data/train',help='help')
  234. parser.add_argument('--val_dir', default='./label_data/test',help='help')
  235. parser.add_argument('--model', default='resnet18',help='help')
  236. args = parser.parse_args()
  237. num_epochs = 100
  238. trainer = Trainer(batch_size=64,
  239. train_dir=args.train_dir,
  240. val_dir=args.val_dir,
  241. name=args.model,
  242. checkpoint=False)
  243. train_losses, train_accuracies, val_losses, val_accuracies = trainer.train_and_validate(num_epochs)