How to Convert a PyTorch Model to ONNX in 5 Minutes

Deep learning frameworks have been springing up like mushrooms after the rain. The emergence of so many frameworks raises several technical issues. For instance, switching from one framework to another is a challenge because they use different operations and different data types. Moreover, deploying models to production or saving models can be awkward in almost all frameworks. ONNX solves these challenges by providing a standard for both the operations and the data types. ONNX also has an inference engine, available as the `onnxruntime` Python package, that runs inference on `onnx` models. You’ll need to install it because we’ll use it later to run inference with the exported `onnx` model.

In this article, you will learn about ONNX and how to convert a ResNet-50 model to ONNX. Let’s start with an overview of ONNX. 

An open standard for ML interoperability

ONNX, short for Open Neural Network Exchange, is an open standard that enables developers to port machine learning models from different frameworks to ONNX. This interoperability allows developers to easily move between various machine learning frameworks. ONNX supports all the popular machine learning frameworks, including Keras, TensorFlow, Scikit-learn, PyTorch, and XGBoost. ONNX also enables vendors of hardware products aimed at accelerating machine learning to target a single ONNX graph representation.

ONNX prevents developers from getting locked into any particular machine learning framework by providing tools that make it easy to move from one to the other. ONNX does this by:

  • Defining an extensible computation graph. Different frameworks historically used different graph representations; ONNX provides a standard graph representation for all of them. The ONNX graph represents the model through various computational nodes and can be visualized using tools such as Netron.
  • Creating standard data types. Each node in a graph usually has a certain data type. To provide interoperability between frameworks, ONNX defines standard data types including int8, int16, and float16, just to name a few.
  • Providing built-in operators. These operators map framework-specific operations to their ONNX equivalents. If you are converting a PyTorch model to ONNX, each PyTorch operator is mapped to its corresponding operator in ONNX. For example, a PyTorch sigmoid operation is converted to the corresponding Sigmoid operator in ONNX.
  • Providing a single file format. Each machine learning library has its own file format. For instance, Keras models can be saved with the `h5` extension, PyTorch models as `pt`, and scikit-learn models as pickle files. ONNX provides a single standard for saving and exporting model files: the `onnx` file extension.

ONNX also makes it easier to optimize machine learning models using ONNX-compatible runtimes and tools that can improve the model’s performance across different hardware.

Now that you understand what ONNX is, let’s take a look at how to convert a PyTorch model to ONNX.

Converting the model to ONNX

Converting deep learning models from PyTorch to ONNX is quite straightforward. Let’s start by loading the pre-trained ResNet-50 model. 

import torch
import torchvision.models as models
model = models.resnet50(pretrained=True)

The model conversion process requires the following:

  • The model is in inference mode. This is because some operations such as batch normalization and dropout behave differently during inference and training.
  • Dummy input in the shape the model expects. For ResNet-50 this is [batch_size, channels, image_size, image_size], indicating the batch size, the number of image channels, and the image dimensions. For example, on ImageNet, channels is 3 and image_size is 224.
  • The input and output names that you would like to use for the exported model.

Let’s start by ensuring that the model is in inference mode.

model.eval()

Next, we create the dummy input variable.

dummy_input = torch.randn(1, 3, 224, 224)

Let’s also define the input and output names.

input_names = ["actual_input"]
output_names = ["output"]

The next step is to use the `torch.onnx.export` function to convert the model to ONNX. The function expects:

  • The model
  • The dummy input
  • The name of the exported file
  • The input names
  • The output names
  • `export_params`, which determines whether the trained parameter weights are stored in the model file

torch.onnx.export(model, 
                  dummy_input,
                  "resnet50.onnx",
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names,
                  export_params=True,
                  )

That’s it, folks. You just converted the PyTorch model to ONNX!

Assuming you would like to use the model for inference, we can create an inference session using the `onnxruntime` Python package and use it to make predictions. Here’s how it’s done.

import onnxruntime as onnxrt

def to_numpy(tensor):
    # Detach from the autograd graph and move to CPU before converting
    return tensor.detach().cpu().numpy()

onnx_session = onnxrt.InferenceSession("resnet50.onnx",
                                       providers=["CPUExecutionProvider"])
# `img` is a preprocessed image tensor of shape [1, 3, 224, 224]
onnx_inputs = {onnx_session.get_inputs()[0].name: to_numpy(img)}
onnx_output = onnx_session.run(None, onnx_inputs)
img_label = onnx_output[0]

Final thoughts

In this article, you learned about ONNX and saw how easy it is to convert a PyTorch model to ONNX. If you’re using this model in a production environment, your next step will be to optimize the model to reduce its latency and increase its throughput. Sign up for free to Deci’s deep learning platform to try out the optimization. 
