train.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # 微调pytorch的预训练模型,在自己的数据上训练,完成分类任务。
  2. import time
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from fontTools.misc.timeTools import epoch_diff
  7. from torch.utils.data import DataLoader
  8. import torchvision.transforms as transforms
  9. from torchvision.datasets import ImageFolder
  10. from torchvision.models import resnet18, ResNet18_Weights,resnet50,ResNet50_Weights, squeezenet1_0, SqueezeNet1_0_Weights,\
  11. shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights, swin_v2_s, Swin_V2_S_Weights, swin_v2_b, Swin_V2_B_Weights
  12. import matplotlib.pyplot as plt
  13. import numpy as np
  14. from torch.utils.tensorboard import SummaryWriter # 添加 TensorBoard 支持
  15. from datetime import datetime
  16. import os
  17. from dotenv import load_dotenv
  18. os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('CUDA_VISIBLE_DEVICES', '0')
  19. class Trainer:
  20. def __init__(self, batch_size, train_dir, val_dir, name, checkpoint):
  21. # 定义一些参数
  22. self.name = name # 采用的模型名称
  23. self.img_size = int(os.getenv('IMG_INPUT_SIZE', 224)) # 输入图片尺寸
  24. self.batch_size = batch_size # 批次大小
  25. self.cls_map = {"0": "non-muddy", "1":"muddy"} # 类别名称映射词典
  26. self.imagenet = True # 是否使用ImageNet预训练权重
  27. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备
  28. self.checkpoint = checkpoint
  29. self.__global_step = 0
  30. self.workers = int(os.getenv('WORKERS', 0))
  31. # 创建日志目录
  32. timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间戳
  33. log_dir = f'runs/turbidity_{self.name}_{timestamp}'# 按照模型的名称和时间戳创建日志目录
  34. self.writer = SummaryWriter(log_dir) # 创建 TensorBoard writer
  35. # 定义数据增强和预处理层
  36. self.train_transforms = transforms.Compose([
  37. transforms.Resize((self.img_size, self.img_size)), # 调整图像大小为256x256 (ResNet输入尺寸)
  38. transforms.RandomHorizontalFlip(p=0.3), # 随机水平翻转,增加数据多样性
  39. transforms.RandomVerticalFlip(p=0.3), # 随机垂直翻转,增加数据多样性
  40. transforms.RandomGrayscale(p=0.25), # 随机灰度化,增加数据多样性
  41. transforms.RandomRotation(10), # 随机旋转±10度
  42. transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0), # 颜色抖动
  43. transforms.ToTensor(), # 转换为tensor并归一化到[0,1]
  44. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
  45. ])
  46. # 测试集基础变换
  47. self.val_transforms = transforms.Compose([
  48. transforms.Resize((self.img_size, self.img_size)),
  49. transforms.ToTensor(),
  50. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  51. ])
  52. # 创建数据集对象
  53. self.train_dataset = ImageFolder(root=train_dir, transform=self.train_transforms)
  54. self.val_dataset = ImageFolder(root=val_dir, transform=self.val_transforms)
  55. # 创建数据加载器 (Windows环境下设置num_workers=0避免多进程问题)
  56. self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers) # 可迭代对象,返回(inputs_tensor, label)
  57. self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.workers)
  58. # 获取类别数量
  59. self.num_classes = len(self.train_dataset.classes)
  60. print(f"自动发现 {self.num_classes} 个类别:")
  61. # 打印类别和名称
  62. for cls in self.train_dataset.classes:
  63. print(f"{cls}: {self.cls_map.get(cls,'None')}")
  64. # 创建模型
  65. self.model = None
  66. self.model = self.__load_model()
  67. # 定义损失函数
  68. self.loss = nn.CrossEntropyLoss() # 多分类常用的交叉熵损失
  69. # 定义优化器
  70. # 只更新requires_grad=True的参数
  71. self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
  72. # 基于验证损失动态调整,更智能
  73. self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  74. self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7,cooldown=2
  75. )
  76. def __load_model(self):
  77. """加载模型结构"""
  78. # 加载模型
  79. if self.name == 'resnet50':
  80. self.weights = ResNet50_Weights.IMAGENET1K_V2 if self.imagenet else None
  81. self.model = resnet50(weights=self.weights)
  82. elif self.name == 'squeezenet':
  83. self.weights = SqueezeNet1_0_Weights.IMAGENET1K_V1 if self.imagenet else None
  84. self.model = squeezenet1_0(weights=self.weights)
  85. elif self.name == 'shufflenet':
  86. self.weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 if self.imagenet else None
  87. self.model = shufflenet_v2_x1_0(weights=self.weights)
  88. elif self.name == 'swin_v2_s':
  89. self.weights = Swin_V2_S_Weights.IMAGENET1K_V1 if self.imagenet else None
  90. self.model = swin_v2_s(weights=self.weights)
  91. elif self.name == 'swin_v2_b':
  92. self.weights = Swin_V2_B_Weights.IMAGENET1K_V1 if self.imagenet else None
  93. self.model = swin_v2_b(weights=self.weights)
  94. else:
  95. raise ValueError(f"Invalid model name: {self.name}")
  96. # 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
  97. if self.imagenet:
  98. for param in self.model.parameters():
  99. param.requires_grad = False
  100. # 替换最后的分类层以适应新的分类任务
  101. if hasattr(self.model, 'fc'):
  102. # ResNet系列模型
  103. self.model.fc = nn.Linear(int(self.model.fc.in_features), self.num_classes, bias=False)
  104. # self.model.fc = nn.Sequential(
  105. # nn.Linear(int(self.model.fc.in_features), int(self.model.fc.in_features) // 2, bias=False),
  106. # nn.ReLU(inplace=True),
  107. # nn.Dropout(0.5),
  108. # nn.Linear(int(self.model.fc.in_features) // 2, self.num_classes, bias=False)
  109. # )
  110. elif hasattr(self.model, 'classifier'):
  111. # Swin Transformer等模型
  112. self.model.classifier = nn.Linear(int(self.model.classifier.in_features), self.num_classes, bias=False)
  113. # self.model.classifier = nn.Sequential(
  114. # nn.Linear(int(self.model.classifier.in_features), int(self.model.classifier.in_features) // 2,
  115. # bias=True),
  116. # nn.ReLU(inplace=True),
  117. # nn.Dropout(0.5),
  118. # nn.Linear(int(self.model.classifier.in_features) // 2, self.num_classes, bias=False)
  119. # )
  120. elif hasattr(self.model, 'head'):
  121. # Swin Transformer使用head层
  122. self.model.head = nn.Linear(int(self.model.head.in_features), self.num_classes, bias=False)
  123. # in_features = self.model.head.in_features
  124. # self.model.head = nn.Sequential(
  125. # nn.Linear(int(in_features), int(in_features) // 2, bias=True),
  126. # nn.ReLU(inplace=True),
  127. # nn.Dropout(0.5),
  128. # nn.Linear(int(in_features) // 2, self.num_classes, bias=False)
  129. # )
  130. else:
  131. raise ValueError(f"Model {self.name} does not have recognizable classifier layer")
  132. print(self.model)
  133. print(f'模型{self.name}结构已经加载,移动到设备{self.device}')
  134. # 将模型移动到GPU/cpu
  135. self.model = self.model.to(self.device)
  136. return self.model
  137. def train_step(self):
  138. """
  139. 单轮训练函数
  140. Args:
  141. Returns:
  142. average_loss: 平均损失
  143. accuracy: 准确率
  144. """
  145. self.model.train() # 设置模型为训练模式(启用dropout/batchnorm等)
  146. epoch_loss = 0.0 # epoch损失
  147. correct_predictions = 0.0 # 预测正确的样本数
  148. total_samples = 0.0 # 已经训练的总样本
  149. # 遍历训练数据
  150. for inputs, labels in self.train_loader:
  151. # 将数据移到指定设备上
  152. inputs = inputs.to(self.device) # b c h w
  153. labels = labels.to(self.device) # b,
  154. # 清零梯度缓存
  155. self.optimizer.zero_grad()
  156. # 前向传播
  157. outputs = self.model(inputs) # b, 2
  158. loss = self.loss(outputs, labels) # 标量
  159. # 反向传播
  160. loss.backward()
  161. # 更新参数
  162. self.optimizer.step()
  163. # 统计信息
  164. batch_loss = loss.item() * inputs.size(0) # 批损失
  165. self.writer.add_scalar('Batch_Loss/Train', batch_loss, self.__global_step)
  166. print(f'Training | Batch Loss: {batch_loss:.4f}\t', end='\r',flush=True)
  167. epoch_loss += batch_loss # epoch损失
  168. _, predicted = torch.max(outputs.data, 1) # predicted b,存储的是预测的类别,最大值的位置
  169. total_samples += labels.size(0) # 统计总样本数量
  170. correct_predictions += (predicted == labels).sum().item() # 统计预测正确的样本数
  171. epoch_loss = epoch_loss / len(self.train_loader.dataset)
  172. epoch_acc = correct_predictions / total_samples
  173. return epoch_loss, epoch_acc
  174. def val_step(self):
  175. """
  176. 验证模型性能
  177. Args:
  178. Returns:
  179. average_loss: 平均损失
  180. accuracy: 准确率
  181. """
  182. self.model.eval() # 设置模型为评估模式(关闭dropout/batchnorm等)
  183. epoch_loss = 0.0
  184. correct_predictions = 0.
  185. total_samples = 0.
  186. # 不计算梯度,提高推理速度
  187. with torch.no_grad():
  188. for inputs, labels in self.val_loader:
  189. inputs = inputs.to(self.device)
  190. labels = labels.to(self.device)
  191. outputs = self.model(inputs)
  192. loss = self.loss(outputs, labels)
  193. epoch_loss += loss.item() * inputs.size(0)
  194. _, predicted = torch.max(outputs.data, 1)
  195. total_samples += labels.size(0)
  196. correct_predictions += (predicted == labels).sum().item()
  197. epoch_loss = epoch_loss / len(self.val_loader.dataset) # 平均损失
  198. epoch_acc = correct_predictions / total_samples
  199. return epoch_loss, epoch_acc
  200. def train_and_validate(self, num_epochs=25):
  201. """
  202. 训练和验证
  203. Args:
  204. num_epochs: 训练轮数
  205. Returns:
  206. train_losses: 每轮训练损失
  207. train_accuracies: 每轮训练准确率
  208. val_losses: 每轮验证损失
  209. val_accuracies: 每轮验证准确率
  210. """
  211. best_val_acc = 0.0
  212. best_val_loss = float('inf')
  213. print("开始训练...")
  214. for epoch in range(num_epochs):
  215. print(f'Epoch {epoch + 1}/{num_epochs}')
  216. print('-' * 20)
  217. # 单步训练
  218. train_loss, train_acc = self.train_step()
  219. print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
  220. # 验证阶段
  221. val_loss, val_acc = self.val_step()
  222. print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
  223. # 学习率调度
  224. self.scheduler.step(val_loss)
  225. # 记录指标到 TensorBoard
  226. self.writer.add_scalar('Loss/Train', train_loss, epoch)
  227. self.writer.add_scalar('Loss/Validation', val_loss, epoch)
  228. self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
  229. self.writer.add_scalar('Accuracy/Validation', val_acc, epoch)
  230. self.writer.add_scalar('Learning Rate', self.scheduler.get_last_lr()[0], epoch)
  231. # 保存最佳模型 (基于验证准确率)
  232. if val_acc > best_val_acc:
  233. best_val_acc = val_acc
  234. torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
  235. print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
  236. # 保存最低验证损失模型
  237. if val_loss < best_val_loss:
  238. best_val_loss = val_loss
  239. torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
  240. print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
  241. # 关闭 TensorBoard writer
  242. self.writer.close()
  243. print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}, 最低验证损失: {best_val_loss:.4f}")
  244. return 1
  245. if __name__ == '__main__':
  246. # 开始训练
  247. import argparse
  248. parser = argparse.ArgumentParser('预训练模型调参')
  249. parser.add_argument('--train_dir',default='./label_data/train',help='help')
  250. parser.add_argument('--val_dir', default='./label_data/test',help='help')
  251. parser.add_argument('--model', default='shufflenet',help='help')
  252. args = parser.parse_args()
  253. num_epochs = 100
  254. trainer = Trainer(batch_size=128,
  255. train_dir=args.train_dir,
  256. val_dir=args.val_dir,
  257. name=args.model,
  258. checkpoint=False)
  259. trainer.train_and_validate(num_epochs)