pth2onnx.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import torch
  2. import torch.onnx
  3. from torchvision.models import resnet50, ResNet50_Weights
  4. from torch import nn
  5. if __name__ == '__main__':
  6. input = torch.randn(1, 3, 256, 256) # [1,3,224,224]分别对应[B,C,H,W]
  7. # 载入模型框架
  8. model = resnet50()
  9. # model.fc = nn.Sequential(
  10. # nn.Linear(int(model.fc.in_features), int(model.fc.in_features) // 2, bias=True),
  11. # nn.ReLU(inplace=True),
  12. # nn.Dropout(0.5),
  13. # nn.Linear(int(model.fc.in_features) // 2, 2, bias=False)
  14. # )
  15. # model.load_state_dict(torch.load("resnet50_best_model_acc.pth")) # xxx.pth表示.pth文件, 这一步载入模型权重
  16. model.load_state_dict(torch.load(r'D:\code\water_turbidity_det\resnet50-11ad3fa6.pth')) # xxx.pth表示.pth文件, 这一步载入模型权重
  17. model.eval() # 设置模型为推理模式
  18. # print(model)
  19. # model = torch.jit.script(model) # 先转换为TorchScript
  20. torch.onnx.export(model,
  21. input,
  22. "resnet50_best_model_acc.onnx",
  23. training=torch.onnx.TrainingMode.EVAL,
  24. opset_version=18,
  25. export_params=True,
  26. do_constant_folding=True,
  27. input_names=['input'],
  28. output_names=['output']
  29. ) # xxx.onnx表示.onnx文件, 这一步导出为onnx模型, 并不做任何算子融合操作。
  30. # 验证模型
  31. import onnx
  32. model = onnx.load("resnet50_best_model_acc.onnx")
  33. onnx.checker.check_model(model) # 验证模型完整性
  34. #mean_mlir = [0.485×255, 0.456×255, 0.406×255] = [123.675, 116.28, 103.53]
  35. #scale_mlir = [0.229*255, 0.224*255, 0.225*255] = [58.395, 57.12, 57.375]