# pth2onnx.py - export a PyTorch .pth checkpoint to ONNX and verify the result.

  1. import torch
  2. import onnx
  3. import torch.onnx
  4. import onnxruntime as ort
  5. import numpy as np
  6. from torchvision.models import resnet50, shufflenet_v2_x1_0, shufflenet_v2_x2_0, squeezenet1_0
  7. from torch import nn
  8. # from simple_model import SimpleModel
  9. if __name__ == '__main__':
  10. # 载入模型框架
  11. # model = SimpleModel()
  12. # model = resnet50(pretrained=False)
  13. model = shufflenet_v2_x1_0()
  14. # model = shufflenet_v2_x2_0()
  15. # model = squeezenet1_0()
  16. model_name = "shufflenet"
  17. if model_name == "squeezenet":
  18. # 获取SqueezeNet的最后一个卷积层的输入通道数
  19. final_conv_in_channels = model.classifier[1].in_channels
  20. # 替换classifier为新的Sequential,将输出改为2类
  21. model.classifier = nn.Sequential(
  22. nn.Dropout(p=0.5),
  23. nn.Conv2d(final_conv_in_channels, 2, kernel_size=(1, 1)),
  24. nn.ReLU(inplace=True),
  25. nn.AdaptiveAvgPool2d((1, 1))
  26. )
  27. if model_name == "shufflenet":
  28. model.fc = nn.Linear(int(model.fc.in_features), 2, bias=True)
  29. model.load_state_dict(torch.load(rf'./{model_name}.pth')) # xxx.pth表示.pth文件, 这一步载入模型权重
  30. print("加载模型成功")
  31. model.eval() # 设置模型为推理模式
  32. example_input = torch.randn(1, 3, 256, 256) # [1,3,224,224]分别对应[B,C,H,W]
  33. # print(model)
  34. torch.onnx.export(model,
  35. example_input,
  36. f"{model_name}.onnx",
  37. opset_version=13,
  38. export_params=True,
  39. do_constant_folding=True,
  40. ) # xxx.onnx表示.onnx文件, 这一步导出为onnx模型, 并不做任何算子融合操作。
  41. # 验证模型
  42. onnx_model = onnx.load(f"{model_name}.onnx") # 使用不同变量名
  43. onnx.checker.check_model(onnx_model) # 验证模型完整性
  44. # 使用ONNX Runtime进行推理
  45. ort_session = ort.InferenceSession(f"{model_name}.onnx")
  46. ort_inputs = {ort_session.get_inputs()[0].name: example_input.detach().numpy()}
  47. ort_outs = ort_session.run(None, ort_inputs)
  48. # 与PyTorch原始输出对比
  49. with torch.no_grad():
  50. torch_out = model(example_input)
  51. # 检查最大误差
  52. print("输出差异最大为:", np.max(np.abs(torch_out.numpy() - ort_outs[0])))
  53. #mean_mlir = [0.485×255, 0.456×255, 0.406×255] = [123.675, 116.28, 103.53]
  54. #scale_mlir = [0.229*255, 0.224*255, 0.225*255] = [58.395, 57.12, 57.375]