# model_zoon.py (2.6 KB)
from typing import Optional

import torch
import torch.nn as nn
from torchvision.models import (
    resnet18,
    resnet50,
    shufflenet_v2_x1_0,
    shufflenet_v2_x2_0,
    squeezenet1_0,
)
  4. def load_model(name:str, num_classes:int, device:torch.device,imagenet:bool=None, weights_path:str=None):
  5. """加载模型结构"""
  6. # 加载模型
  7. pretrained = True if imagenet else False
  8. if name == 'resnet50':
  9. model = resnet50(pretrained=pretrained)
  10. elif name == 'squeezenet':
  11. model = squeezenet1_0(pretrained=pretrained)
  12. elif name == 'shufflenet' or name == 'shufflenet-x1':
  13. model = shufflenet_v2_x1_0(pretrained=pretrained)
  14. elif name == 'shufflenet-x2':
  15. model = shufflenet_v2_x2_0(pretrained=False)
  16. imagenet = False
  17. print('shufflenet-x2无预训练权重,重新训练所有权重')
  18. else:
  19. raise ValueError(f"Invalid model name: {name}")
  20. # 如果采用预训练的神经网络,就需要冻结特征提取层,只训练最后几层
  21. if imagenet:
  22. for param in model.parameters():
  23. param.requires_grad = False
  24. # 替换最后的分类层以适应新的分类任务
  25. print(model)
  26. print(f"正在将模型{name}的分类层替换为{num_classes}个类别")
  27. if hasattr(model, 'fc'):
  28. # ResNet系列模型
  29. model.fc = nn.Linear(int(model.fc.in_features), num_classes, bias=True)
  30. elif hasattr(model, 'classifier'):
  31. # SqueezeNet、ShuffleNet系列模型
  32. if name == 'squeezenet':
  33. # 获取SqueezeNet的最后一个卷积层的输入通道数
  34. final_conv_in_channels = model.classifier[1].in_channels
  35. # 替换classifier为新的Sequential,将输出改为2类
  36. model.classifier = nn.Sequential(
  37. nn.Dropout(p=0.5),
  38. nn.Conv2d(final_conv_in_channels, num_classes, kernel_size=(1, 1)),
  39. nn.ReLU(inplace=True),
  40. nn.AdaptiveAvgPool2d((1, 1))
  41. )
  42. else:
  43. # Swin Transformer等模型
  44. model.classifier = nn.Linear(int(model.classifier.in_features), num_classes, bias=True)
  45. elif hasattr(model, 'head'):
  46. # Swin Transformer使用head层
  47. model.head = nn.Linear(int(model.head.in_features), num_classes, bias=True)
  48. else:
  49. raise ValueError(f"Model {name} does not have recognizable classifier layer")
  50. print(f'模型{name}结构已经加载,移动到设备{device}')
  51. if weights_path:
  52. model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
  53. # 将模型移动到GPU/cpu
  54. model = model.to(device)
  55. return model