# train.py
  1. # 微调pytorch的预训练模型,在自己的数据上训练,完成分类任务。
  2. import time
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from torch.utils.data import DataLoader
  7. import torchvision.transforms as transforms
  8. from torchvision.datasets import ImageFolder
  9. from model.model_zoon import load_model
  10. from torch.utils.tensorboard import SummaryWriter # 添加 TensorBoard 支持
  11. from datetime import datetime
  12. import os
  13. from dotenv import load_dotenv
  14. load_dotenv()
  15. os.environ['CUDA_VISIBLE_DEVICES'] = os.getenv('CUDA_VISIBLE_DEVICES', '0')
  16. # 读取并打印.env文件中的变量
  17. def print_env_variables():
  18. print("从.env文件加载的变量:")
  19. env_vars = {
  20. 'PATCH_WIDTH': os.getenv('PATCH_WIDTH'),
  21. 'PATCH_HEIGHT': os.getenv('PATCH_HEIGHT'),
  22. 'CONFIDENCE_THRESHOLD': os.getenv('CONFIDENCE_THRESHOLD'),
  23. 'IMG_INPUT_SIZE': os.getenv('IMG_INPUT_SIZE'),
  24. 'WORKERS': os.getenv('WORKERS'),
  25. 'CUDA_VISIBLE_DEVICES': os.getenv('CUDA_VISIBLE_DEVICES')
  26. }
  27. for var, value in env_vars.items():
  28. print(f"{var}: {value}")
  29. class Trainer:
  30. def __init__(self, batch_size, train_dir, val_dir, name, checkpoint:bool=False):
  31. # 定义一些参数
  32. self.name = name # 采用的模型名称
  33. self.img_size = int(os.getenv('IMG_INPUT_SIZE', 224)) # 输入图片尺寸
  34. self.batch_size = batch_size # 批次大小
  35. self.cls_map = {"0": "non-muddy", "1":"muddy"} # 类别名称映射词典
  36. self.imagenet = os.getenv('PRETRAINED', True) # 是否使用ImageNet预训练权重
  37. # 训练设备 - 优先使用GPU,如果不可用则使用CPU
  38. if torch.cuda.is_available():
  39. try:
  40. # 尝试进行简单的CUDA操作以确认CUDA功能正常
  41. _ = torch.zeros(1).cuda()
  42. self.device = torch.device("cuda")
  43. print("成功检测到CUDA设备,使用GPU进行训练")
  44. except Exception as e:
  45. print(f"CUDA设备存在问题: {e},回退到CPU")
  46. self.device = torch.device("cpu")
  47. else:
  48. self.device = torch.device("cpu")
  49. print("CUDA不可用,使用CPU进行训练")
  50. self.__global_step = 0
  51. self.best_val_acc = 0.0
  52. self.epoch = 0
  53. self.best_val_loss = float('inf')
  54. self.checkpoint_root_path = './checkpoints'
  55. self.best_acc_model_path = os.path.join(self.checkpoint_root_path , f'{self.name}_best_model_acc.pth')
  56. self.best_loss_model_path = os.path.join(self.checkpoint_root_path , f'{self.name}_best_model_loss.pth')
  57. self.latest_checkpoint_path = os.path.join(self.checkpoint_root_path , f'{self.name}_latest_checkpoint.pth')
  58. self.workers = int(os.getenv('WORKERS', 0))
  59. # 创建日志目录
  60. timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间戳
  61. log_dir = f'runs/turbidity_{self.name}_{timestamp}'# 按照模型的名称和时间戳创建日志目录
  62. self.writer = SummaryWriter(log_dir) # 创建 TensorBoard writer
  63. # 定义数据增强和预处理层
  64. self.train_transforms = transforms.Compose([
  65. transforms.Resize((self.img_size, self.img_size)), # 调整图像大小为256x256 (ResNet输入尺寸)
  66. transforms.RandomHorizontalFlip(p=0.3), # 随机水平翻转,增加数据多样性
  67. transforms.RandomVerticalFlip(p=0.3), # 随机垂直翻转,增加数据多样性
  68. transforms.RandomGrayscale(p=0.25), # 随机灰度化,增加数据多样性
  69. transforms.RandomRotation(10), # 随机旋转±10度
  70. transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0), # 颜色抖动
  71. transforms.ToTensor(), # 转换为tensor并归一化到[0,1]
  72. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
  73. ])
  74. # 测试集基础变换
  75. self.val_transforms = transforms.Compose([
  76. transforms.Resize((self.img_size, self.img_size)),
  77. transforms.ToTensor(),
  78. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  79. ])
  80. # 创建数据集对象
  81. self.train_dataset = ImageFolder(root=train_dir, transform=self.train_transforms)
  82. print(f"训练样本数量: {len(self.train_dataset)}")
  83. self.val_dataset = ImageFolder(root=val_dir, transform=self.val_transforms)
  84. print(f"验证样本数量: {len(self.val_dataset)}")
  85. # 创建数据加载器 (Windows环境下设置num_workers=0避免多进程问题)
  86. self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers) # 可迭代对象,返回(inputs_tensor, label)
  87. self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.workers)
  88. # 获取类别数量
  89. self.num_classes = len(self.train_dataset.classes)
  90. print(f"自动发现 {self.num_classes} 个类别:")
  91. # 打印类别和名称
  92. for cls in self.train_dataset.classes:
  93. print(f"{cls}: {self.cls_map.get(cls,'None')}")
  94. # 打印训练集和测试集图像数量
  95. print(f"训练集图像数量: {len(self.train_dataset)}")
  96. print(f"验证集图像数量: {len(self.val_dataset)}")
  97. # 创建模型
  98. self.model = load_model(name=self.name, imagenet=self.imagenet, num_classes=self.num_classes, device=self.device)
  99. # 定义损失函数
  100. self.loss = nn.CrossEntropyLoss() # 多分类常用的交叉熵损失
  101. # 定义优化器
  102. # 只更新requires_grad=True的参数
  103. self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
  104. # 基于验证损失动态调整,更智能
  105. self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  106. self.optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7,cooldown=2
  107. )
  108. # 加载检查点
  109. if checkpoint and os.path.exists(self.latest_checkpoint_path):
  110. self.load_checkpoint()
  111. def save_checkpoint(self):
  112. if not os.path.exists(self.checkpoint_root_path ):
  113. os.makedirs(self.checkpoint_root_path )
  114. checkpoint = {
  115. 'epoch': self.epoch,
  116. 'model_state_dict': self.model.state_dict(),
  117. 'optimizer_state_dict': self.optimizer.state_dict(),
  118. 'scheduler_state_dict': self.scheduler.state_dict(),
  119. 'best_val_acc': self.best_val_acc,
  120. 'best_val_loss': self.best_val_loss
  121. }
  122. torch.save(checkpoint, self.latest_checkpoint_path)
  123. print(f"已保存检查点到 {self.latest_checkpoint_path}")
  124. def load_checkpoint(self, from_where='latest'):
  125. checkpoint = torch.load(self.latest_checkpoint_path)
  126. self.model.load_state_dict(checkpoint['model_state_dict'])
  127. self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  128. self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
  129. self.best_val_acc = checkpoint['best_val_acc']
  130. self.best_val_loss = checkpoint['best_val_loss']
  131. self.epoch = checkpoint['epoch'] + 1
  132. print(f"从 {self.latest_checkpoint_path} 加载检查点")
  133. def train_step(self):
  134. """
  135. 单轮训练函数
  136. Args:
  137. Returns:
  138. average_loss: 平均损失
  139. accuracy: 准确率
  140. """
  141. self.model.train() # 设置模型为训练模式(启用dropout/batchnorm等)
  142. epoch_loss = 0.0 # epoch损失
  143. correct_predictions = 0.0 # 预测正确的样本数
  144. total_samples = 0.0 # 已经训练的总样本
  145. # 遍历训练数据
  146. for inputs, labels in self.train_loader:
  147. # 将数据移到指定设备上
  148. inputs = inputs.to(self.device) # b c h w
  149. labels = labels.to(self.device) # b,
  150. # 清零梯度缓存
  151. self.optimizer.zero_grad()
  152. # 前向传播
  153. outputs = self.model(inputs) # b, 2
  154. loss = self.loss(outputs, labels) # 标量
  155. # 反向传播
  156. loss.backward()
  157. # 更新参数
  158. self.optimizer.step()
  159. # 统计信息
  160. batch_loss = loss.item() * inputs.size(0) # 批损失
  161. self.writer.add_scalar('Batch_Loss/Train', batch_loss, self.__global_step)
  162. print(f'Training | Batch Loss: {batch_loss:.4f}\t', end='\r',flush=True)
  163. epoch_loss += batch_loss # epoch损失
  164. _, predicted = torch.max(outputs.data, 1) # predicted b,存储的是预测的类别,最大值的位置
  165. total_samples += labels.size(0) # 统计总样本数量
  166. correct_predictions += (predicted == labels).sum().item() # 统计预测正确的样本数
  167. epoch_loss = epoch_loss / len(self.train_loader.dataset)
  168. epoch_acc = correct_predictions / total_samples
  169. return epoch_loss, epoch_acc
  170. def val_step(self):
  171. """
  172. 验证模型性能
  173. Args:
  174. Returns:
  175. average_loss: 平均损失
  176. accuracy: 准确率
  177. """
  178. self.model.eval() # 设置模型为评估模式(关闭dropout/batchnorm等)
  179. epoch_loss = 0.0
  180. correct_predictions = 0.
  181. total_samples = 0.
  182. # 不计算梯度,提高推理速度
  183. with torch.no_grad():
  184. for inputs, labels in self.val_loader:
  185. inputs = inputs.to(self.device)
  186. labels = labels.to(self.device)
  187. outputs = self.model(inputs)
  188. loss = self.loss(outputs, labels)
  189. epoch_loss += loss.item() * inputs.size(0)
  190. _, predicted = torch.max(outputs.data, 1)
  191. total_samples += labels.size(0)
  192. correct_predictions += (predicted == labels).sum().item()
  193. epoch_loss = epoch_loss / len(self.val_loader.dataset) # 平均损失
  194. epoch_acc = correct_predictions / total_samples
  195. return epoch_loss, epoch_acc
  196. def train_and_validate(self, num_epochs=25):
  197. """
  198. 训练和验证
  199. Args:
  200. num_epochs: 训练轮数
  201. Returns:
  202. train_losses: 每轮训练损失
  203. train_accuracies: 每轮训练准确率
  204. val_losses: 每轮验证损失
  205. val_accuracies: 每轮验证准确率
  206. """
  207. # 在你的代码中调用
  208. print_env_variables()
  209. print("开始训练...")
  210. for epoch in range(self.epoch, num_epochs):
  211. print(f'Epoch {epoch + 1}/{num_epochs}')
  212. print('-' * 20)
  213. # 单步训练
  214. train_loss, train_acc = self.train_step()
  215. print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
  216. # 验证阶段
  217. val_loss, val_acc = self.val_step()
  218. print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
  219. # 学习率调度
  220. self.scheduler.step(val_loss)
  221. # 记录指标到 TensorBoard
  222. self.writer.add_scalar('Loss/Train', train_loss, epoch)
  223. self.writer.add_scalar('Loss/Validation', val_loss, epoch)
  224. self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
  225. self.writer.add_scalar('Accuracy/Validation', val_acc, epoch)
  226. self.writer.add_scalar('Learning Rate', self.optimizer.param_groups[0]['lr'], epoch)
  227. # 保存最佳模型 (基于验证准确率)
  228. if val_acc > self.best_val_acc:
  229. best_val_acc = val_acc
  230. torch.save(self.model.state_dict(), f'{self.name}_best_model_acc.pth')
  231. print(f"保存了新的最佳准确率模型,验证准确率: {best_val_acc:.4f}")
  232. # 保存最低验证损失模型
  233. if val_loss < self.best_val_loss:
  234. best_val_loss = val_loss
  235. torch.save(self.model.state_dict(), f'{self.name}_best_model_loss.pth')
  236. print(f"保存了新的最低损失模型,验证损失: {best_val_loss:.4f}")
  237. self.save_checkpoint()
  238. self.epoch += 1
  239. # 关闭 TensorBoard writer
  240. self.writer.close()
  241. print(f"训练完成! 最佳验证准确率: {self.best_val_acc:.4f}, 最低验证损失: {self.best_val_loss:.4f}")
  242. return 1
  243. if __name__ == '__main__':
  244. # 开始训练
  245. import argparse
  246. parser = argparse.ArgumentParser('预训练模型调参')
  247. parser.add_argument('--train_dir',default='./label_data/train',help='help')
  248. parser.add_argument('--val_dir', default='./label_data/test',help='help')
  249. parser.add_argument('--model', default='squeezenet',help='help')
  250. parser.add_argument('--resume', action='store_true',help='是否恢复继续训练')
  251. args = parser.parse_args()
  252. num_epochs = 100
  253. trainer = Trainer(batch_size=int(os.getenv('BATCH_SIZE', 32)),
  254. train_dir=args.train_dir,
  255. val_dir=args.val_dir,
  256. name=args.model,
  257. checkpoint=args.resume)
  258. trainer.train_and_validate(num_epochs)