from typing import Optional

import torch
import torch.nn as nn
from torchvision.models import (
    resnet18,
    resnet50,
    shufflenet_v2_x1_0,
    shufflenet_v2_x2_0,
    squeezenet1_0,
)


def load_model(
    name: str,
    num_classes: int,
    device: torch.device,
    imagenet: Optional[bool] = None,
    weights_path: Optional[str] = None,
) -> nn.Module:
    """Build a torchvision backbone and adapt its head to ``num_classes``.

    Args:
        name: Architecture to build. Accepted values are ``'resnet18'``,
            ``'resnet50'``, ``'squeezenet'``, ``'shufflenet'`` /
            ``'shufflenet-x1'`` and ``'shufflenet-x2'``.
        num_classes: Number of outputs of the new classification layer.
        device: Device the finished model is moved to.
        imagenet: When truthy, the backbone is built with ``pretrained=True``
            and every backbone parameter is frozen before the new head is
            attached, so only the new head remains trainable. Ignored for
            ``'shufflenet-x2'``, which is always built with
            ``pretrained=False`` and left fully trainable.
        weights_path: Optional checkpoint path. When given, it is read with
            ``torch.load`` and applied with ``load_state_dict`` after the
            head has been replaced.

    Returns:
        The model, moved to ``device``.

    Raises:
        ValueError: If ``name`` is not one of the accepted values, or the
            built model exposes none of ``fc`` / ``classifier`` / ``head``.
    """
    pretrained = bool(imagenet)

    # Build the backbone.
    if name == 'resnet18':
        model = resnet18(pretrained=pretrained)
    elif name == 'resnet50':
        model = resnet50(pretrained=pretrained)
    elif name == 'squeezenet':
        model = squeezenet1_0(pretrained=pretrained)
    elif name == 'shufflenet' or name == 'shufflenet-x1':
        model = shufflenet_v2_x1_0(pretrained=pretrained)
    elif name == 'shufflenet-x2':
        # This variant is always built without pretrained weights, so
        # freezing is disabled and every weight is trained from scratch.
        model = shufflenet_v2_x2_0(pretrained=False)
        imagenet = False
        print('shufflenet-x2无预训练权重,重新训练所有权重')
    else:
        raise ValueError(f"Invalid model name: {name}")

    # With a pretrained backbone, freeze the feature extractor. The head
    # created below is a fresh module, so its parameters stay trainable.
    if imagenet:
        for param in model.parameters():
            param.requires_grad = False

    # Replace the final classification layer to fit the new task.
    print(model)
    print(f"正在将模型{name}的分类层替换为{num_classes}个类别")
    if hasattr(model, 'fc'):
        # Models whose head is a single ``fc`` linear layer.
        model.fc = nn.Linear(int(model.fc.in_features), num_classes, bias=True)
    elif hasattr(model, 'classifier'):
        if name == 'squeezenet':
            # SqueezeNet classifies with a 1x1 convolution; reuse the input
            # channel count of the existing final conv layer.
            final_conv_in_channels = model.classifier[1].in_channels
            # Rebuild the classifier so it emits ``num_classes`` channels.
            model.classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Conv2d(final_conv_in_channels, num_classes, kernel_size=(1, 1)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d((1, 1)),
            )
            # Keep the bookkeeping attribute in sync with the new head:
            # some torchvision releases reshape the forward output using
            # ``self.num_classes``.
            model.num_classes = num_classes
        else:
            # Generic case: assumes ``classifier`` is a single linear layer
            # exposing ``in_features``.
            model.classifier = nn.Linear(
                int(model.classifier.in_features), num_classes, bias=True
            )
    elif hasattr(model, 'head'):
        # Models that expose their classification layer as ``head``.
        model.head = nn.Linear(int(model.head.in_features), num_classes, bias=True)
    else:
        raise ValueError(f"Model {name} does not have recognizable classifier layer")

    print(f'模型{name}结构已经加载,移动到设备{device}')
    if weights_path:
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

    # Move the model to the requested device (GPU/CPU).
    model = model.to(device)
    return model