Version: 4.55.1

Export from PyTorch

In general, exporting a model is straightforward thanks to PyTorch's built-in ONNX support.

The code itself is simple. First we import torch and build a test model.

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Conv2d(in_channels=3, out_channels=1,
                               kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        out = self.layer(x)
        out = nn.functional.interpolate(out, scale_factor=2,
                                        mode='bilinear', align_corners=True)
        out = torch.nn.functional.softmax(out, dim=1)
        return out

The export itself then looks like this:

model = Model()
model.eval()
random_input = torch.randn(1, 3, 64, 64, dtype=torch.float32)
# you can add however many inputs your model or task requires

input_names = ["image"]
output_names = ["pred"]

torch.onnx.export(model, random_input, './model.onnx', verbose=False,
                  input_names=input_names, output_names=output_names,
                  opset_version=11)

Make sure that the number of elements in input_names matches the number of input arguments of your model’s forward method, and that the number of elements in output_names matches the number of values the forward method returns.
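The check described above can be automated before calling torch.onnx.export. The following is a minimal sketch using only the standard library: count_forward_inputs and check_input_names are hypothetical helper names, and TwoInputModel is a plain stand-in class so the snippet runs without PyTorch. An nn.Module exposes forward in the same way, so the helper should work on a real model too, although it only counts required positional arguments and ignores *args and defaults.

```python
import inspect


def count_forward_inputs(model):
    """Count the required positional arguments of model.forward (self excluded)."""
    params = inspect.signature(model.forward).parameters.values()
    positional = (inspect.Parameter.POSITIONAL_ONLY,
                  inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(
        1 for p in params
        if p.kind in positional and p.default is inspect.Parameter.empty
    )


def check_input_names(model, input_names):
    """Raise ValueError if input_names does not match forward's argument count."""
    expected = count_forward_inputs(model)
    if len(input_names) != expected:
        raise ValueError(
            f"forward takes {expected} input(s) but "
            f"{len(input_names)} name(s) were given: {input_names}"
        )


class TwoInputModel:
    """Hypothetical stand-in for an nn.Module with two inputs."""

    def forward(self, image, mask):
        return image, mask


# Passes silently: two names for two forward arguments.
check_input_names(TwoInputModel(), ["image", "mask"])
```

Calling check_input_names(TwoInputModel(), ["image"]) would raise a ValueError instead of letting a mismatched export go through.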

Make sure the input shapes are correct, since you won’t be able to change them after importing the model into Lens Studio: it reads them directly from your .onnx file.

You can download this guide as .ipynb here

Loading Models into Lens Studio

Importing your model into Lens Studio is again straightforward. Add an ML component and it will prompt you to select a file containing your model. Select the ONNX file you exported previously, and if everything is fine, Lens Studio will prompt you to set your model’s input and output scale and bias.

In the upper left corner of this prompt you will see a compatibility table of Ops in your model and different available inference frameworks. These are used to accelerate model inference if the device that’s running your lens has the ability to utilize these frameworks and if all of the Ops are implemented in a given inference framework. If your model has unsupported Ops, you can hover over the warning icon and it will provide you with a layer name that has this issue.

Take a look at the Compatibility guide to learn more about supported Layers

One of the most useful apps for debugging ONNX files is Netron.

You can open your .onnx file in Netron and find the Op that’s causing the issue by name.

You can then see if you can reimplement your model avoiding unsupported operations.

Remember that Lens Studio textures have values in the range [0, 255], so if your model was trained with whitened input with values in [-1, 1], you need to set the scale and bias parameters to reflect that.
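As a worked example of that mapping, here is a small sketch that derives the two numbers. It assumes the input transform is applied as value * scale + bias, which is an assumption about the convention and not something stated in this guide, so verify it against the ML Component guide. The function name scale_and_bias is hypothetical.

```python
def scale_and_bias(target_min, target_max, source_min=0.0, source_max=255.0):
    """Return (scale, bias) so that source * scale + bias lands in the target range.

    Assumption: the transform is applied as value * scale + bias.
    """
    scale = (target_max - target_min) / (source_max - source_min)
    bias = target_min - source_min * scale
    return scale, bias


# Texture values in [0, 255], model trained on inputs in [-1, 1].
scale, bias = scale_and_bias(-1.0, 1.0)
print(f"scale = {scale:.6f}, bias = {bias}")
```

Under that assumption, a black pixel (0) maps to -1 and a white pixel (255) maps to 1. A model trained on inputs in [0, 1] would instead use scale_and_bias(0.0, 1.0).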

Check out the ML Component guide for more information!

Common issues

One of the major issues is bilinear interpolation. PyTorch and mobile inference frameworks handle the edges of an interpolated image differently when align_corners is set to False, so make sure your model uses align_corners=True everywhere it uses bilinear interpolation. You should also set opset_version=11 if your model uses align_corners=True, since opset 9 (the default in older PyTorch versions) doesn’t have this parameter in its op definition. It may be even better to use nearest-neighbor interpolation or transposed convolution instead of bilinear interpolation, since the iOS inference accelerator doesn’t support align_corners=True.

x = random_input
# bad
nn.functional.interpolate(x, scale_factor=2, mode='bilinear')
nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
nn.Upsample(scale_factor=2, mode='bilinear')

# better
nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

# best
nn.ConvTranspose2d(3, 3, kernel_size=2, stride=2)  # learned 2x upsampling
nn.functional.interpolate(x, scale_factor=2, mode='nearest')
nn.Upsample(scale_factor=2, mode='nearest')
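To see why the edges are the sensitive part, it helps to look at where each output pixel samples from in the input. The sketch below is pure Python and evaluates the commonly documented 1D coordinate formulas for the two conventions. It illustrates the arithmetic only and is not the implementation of PyTorch or of any mobile framework; source_coords is a hypothetical helper name.

```python
def source_coords(in_size, out_size, align_corners):
    """Return the input coordinate sampled by each output index along one axis."""
    coords = []
    for i in range(out_size):
        if align_corners:
            # Corner pixels of input and output are aligned exactly.
            src = i * (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        else:
            # Pixel centers are aligned ("half-pixel" convention).
            src = (i + 0.5) * in_size / out_size - 0.5
        coords.append(src)
    return coords


aligned = source_coords(4, 8, align_corners=True)
half_pixel = source_coords(4, 8, align_corners=False)

print("align_corners=True :", [round(c, 3) for c in aligned])
print("align_corners=False:", [round(c, 3) for c in half_pixel])
```

With align_corners=True every coordinate stays inside the valid range [0, 3]. With align_corners=False the first and last coordinates fall outside it, so the result at the border depends on how a given framework clamps or pads, which is where implementations can disagree.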

Pay attention to the .view() and .reshape() operations in your network. The main issue is that while PyTorch uses the NCHW format for its 4-dimensional tensors, mobile inference frameworks can use either NCHW or NHWC. If possible, replace these operations with ones that preserve tensor dimensions.

# bad: NCHW dim order is assumed
features = features.view(batch_size, num_channels, height, width)
result = conv(features)
# ok: dim order is irrelevant
flattened_features = features.view(batch_size, -1)
result = flattened_features.sum(dim=1)
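The effect can be reproduced without any framework. The sketch below stores the same logical tensor in two flat buffers, one in NCHW order and one in NHWC order, and then splits each buffer into rows the way a reshape to (num_channels, -1) would. The encoding in value and the helper split_rows are hypothetical and exist only to make the channel of each element visible.

```python
C, H, W = 2, 2, 3


def value(c, h, w):
    """Encode the logical position so the channel is readable: c is the hundreds digit."""
    return 100 * c + 10 * h + w


# The same logical tensor laid out in memory in two different orders.
nchw_buffer = [value(c, h, w) for c in range(C) for h in range(H) for w in range(W)]
nhwc_buffer = [value(c, h, w) for h in range(H) for w in range(W) for c in range(C)]


def split_rows(buffer, n_rows):
    """Reinterpret a flat buffer as n_rows equal rows, like a reshape to (n_rows, -1)."""
    step = len(buffer) // n_rows
    return [buffer[i * step:(i + 1) * step] for i in range(n_rows)]


nchw_rows = split_rows(nchw_buffer, C)  # each row holds exactly one channel
nhwc_rows = split_rows(nhwc_buffer, C)  # each row mixes both channels

print("NCHW rows:", nchw_rows)
print("NHWC rows:", nhwc_rows)

# An order-independent reduction gives the same answer for both layouts.
print("sums:", sum(nchw_buffer), sum(nhwc_buffer))
```

The reshape only means "one row per channel" when the buffer really is in NCHW order, while the sum over all elements is unaffected by the layout, which matches the "bad" and "ok" cases above.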

Some PyTorch versions have issues with certain ONNX opsets. If you encounter problems exporting a model that uses interpolation or a softmax layer with an explicit dim parameter, try updating PyTorch to the latest available version and passing opset_version=11 to your torch.onnx.export call.
