| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- import torch
- from torchvision.models import resnet18,resnet50, squeezenet1_0,shufflenet_v2_x1_0,shufflenet_v2_x2_0
- import torch.nn as nn
def load_model(name: str, num_classes: int, device: torch.device, imagenet: bool = None, weights_path: str = None):
    """Build a torchvision classifier whose output layer has ``num_classes`` units.

    The function constructs the requested architecture, optionally freezes
    its parameters, replaces the final classification layer, optionally loads
    a saved state dict, and moves the result to ``device``.

    Args:
        name: Architecture to build. One of ``'resnet18'``, ``'resnet50'``,
            ``'squeezenet'``, ``'shufflenet'`` / ``'shufflenet-x1'`` or
            ``'shufflenet-x2'``.
        num_classes: Number of outputs of the new classification layer.
        device: Device the finished model is moved to.
        imagenet: When truthy, torchvision's pretrained weights are requested
            and every existing parameter is frozen, so only the freshly
            created classification layer remains trainable. Ignored for
            ``'shufflenet-x2'``, which is always built without pretrained
            weights.
        weights_path: Optional path to a saved state dict. It is loaded onto
            the CPU *after* the classification layer has been replaced, so it
            must match the modified architecture.

    Returns:
        The constructed model, moved to ``device``.

    Raises:
        ValueError: If ``name`` is not a supported architecture, or the built
            model exposes none of the ``fc`` / ``classifier`` / ``head``
            attributes.
    """
    # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; the legacy keyword is kept so the call matches
    # whatever torchvision version this file targets -- confirm before upgrading.
    pretrained = bool(imagenet)

    # --- Build the backbone -------------------------------------------------
    if name == 'resnet18':
        model = resnet18(pretrained=pretrained)
    elif name == 'resnet50':
        model = resnet50(pretrained=pretrained)
    elif name == 'squeezenet':
        model = squeezenet1_0(pretrained=pretrained)
    elif name == 'shufflenet' or name == 'shufflenet-x1':
        model = shufflenet_v2_x1_0(pretrained=pretrained)
    elif name == 'shufflenet-x2':
        # Always built from scratch; clearing ``imagenet`` also disables the
        # freezing step below so that all weights get trained.
        model = shufflenet_v2_x2_0(pretrained=False)
        imagenet = False
        print('shufflenet-x2无预训练权重,重新训练所有权重')
    else:
        raise ValueError(f"Invalid model name: {name}")

    # --- Freeze pretrained parameters ---------------------------------------
    # With a pretrained network only the layer created below is trained:
    # modules created after this loop default to requires_grad=True.
    if imagenet:
        for param in model.parameters():
            param.requires_grad = False

    # --- Replace the classification layer -----------------------------------
    print(model)
    print(f"正在将模型{name}的分类层替换为{num_classes}个类别")
    if hasattr(model, 'fc'):
        # Models that end in a single linear layer named ``fc``
        # (the ResNet and ShuffleNetV2 builders used above).
        model.fc = nn.Linear(int(model.fc.in_features), num_classes, bias=True)
    elif hasattr(model, 'classifier'):
        if name == 'squeezenet':
            # SqueezeNet classifies with a 1x1 convolution at index 1 of its
            # ``classifier``; reuse that layer's input channel count.
            final_conv_in_channels = model.classifier[1].in_channels
            # Rebuild the classifier so the convolution emits ``num_classes``
            # channels.
            model.classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Conv2d(final_conv_in_channels, num_classes, kernel_size=(1, 1)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d((1, 1))
            )
            # Keep the model's own bookkeeping in sync: some torchvision
            # releases reshape the output in ``forward`` using
            # ``self.num_classes``, which would otherwise still hold the
            # original class count. Harmless on releases that flatten instead.
            model.num_classes = num_classes
        else:
            # Generic fallback for a model whose ``classifier`` is a single
            # linear layer. Not reached by the architectures built above.
            model.classifier = nn.Linear(int(model.classifier.in_features), num_classes, bias=True)
    elif hasattr(model, 'head'):
        # Generic fallback for models exposing a linear ``head``
        # (e.g. Swin Transformer). Not reached by the architectures above.
        model.head = nn.Linear(int(model.head.in_features), num_classes, bias=True)
    else:
        raise ValueError(f"Model {name} does not have recognizable classifier layer")

    print(f'模型{name}结构已经加载,移动到设备{device}')

    # --- Optionally restore saved weights -----------------------------------
    if weights_path:
        # NOTE(review): torch.load unpickles the file; only pass checkpoints
        # from trusted sources (or use ``weights_only=True`` where the
        # installed torch supports it).
        model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

    # Move the model to the requested GPU/CPU device.
    model = model.to(device)
    return model
|