Export from PyTorch
In general, the procedure for model export is straightforward thanks to PyTorch's built-in torch.onnx module.
The code itself is simple. First we import torch and build a test model.
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Conv2d(in_channels=3, out_channels=1,
                               kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        out = self.layer(x)
        out = nn.functional.interpolate(out, scale_factor=2,
                                        mode='bilinear', align_corners=True)
        out = torch.nn.functional.softmax(out, dim=1)
        return out
The code for the export itself is then:
model = Model()
model.eval()

random_input = torch.randn(1, 3, 64, 64, dtype=torch.float32)
# you can add however many inputs your model or task requires
input_names = ["image"]
output_names = ["pred"]

torch.onnx.export(model, random_input, './model.onnx', verbose=False,
                  input_names=input_names, output_names=output_names,
                  opset_version=11)
It is important to make sure that the number of elements in input_names matches the number of input arguments of your model's forward method, and that the number of elements in output_names matches the number of values the forward method returns.
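For example, a hypothetical model whose forward takes two inputs and returns two outputs would be exported along these lines (the model and all names here are purely illustrative):
import torch
import torch.nn as nn

class TwoHeadModel(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, image, mask):  # two input arguments...
        features = self.conv(image)
        return features, features * mask  # ...and two return values

model = TwoHeadModel()
model.eval()
image = torch.randn(1, 3, 64, 64)
mask = torch.randn(1, 8, 64, 64)

# two input names and two output names, matching forward's signature
torch.onnx.export(model, (image, mask), './two_head.onnx',
                  input_names=["image", "mask"],
                  output_names=["features", "masked"],
                  opset_version=11)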
Make sure that input shapes are correct, since you won't be able to change them when you import the model into Lens Studio; the studio reads them from your .onnx file.
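As a quick sanity check, you can inspect the shapes baked into the exported file with the onnx package (a minimal sketch, assuming the model was saved as ./model.onnx):
import onnx

model_proto = onnx.load('./model.onnx')
for inp in model_proto.graph.input:
    # collect the fixed dimensions stored in the graph's input signature
    dims = [d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)  # e.g. image [1, 3, 64, 64]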
You can download this guide as an .ipynb here.
Loading Models into Lens Studio
The process for importing your model into Lens Studio is again straightforward. You just need to add an ML Component, and it will prompt you to select a file containing your model. Select the .onnx file you exported previously; if everything is fine, the studio will prompt you to set your model's input and output scale and bias.
In the upper left corner of this prompt you will see a compatibility table of the Ops in your model against the available inference frameworks. These frameworks are used to accelerate model inference when the device running your Lens can utilize them and all of the Ops are implemented in the given framework. If your model has unsupported Ops, you can hover over the warning icon to see the name of the layer that has the issue.
Take a look at the Compatibility guide to learn more about supported Layers.
One of the most useful apps for debugging ONNX files is Netron. You can open your .onnx file in Netron and find the Op that's causing the issue by name. You can then see whether you can reimplement your model to avoid the unsupported operations.
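If you prefer to inspect the graph programmatically, the onnx package can list every Op type the model uses (a sketch, again assuming ./model.onnx; the 'Resize' filter below is just an example):
import onnx
from collections import Counter

model_proto = onnx.load('./model.onnx')

# count how often each Op type appears in the graph
op_counts = Counter(node.op_type for node in model_proto.graph.node)
for op, count in sorted(op_counts.items()):
    print(op, count)

# locate a specific problematic Op by type and print its name
for node in model_proto.graph.node:
    if node.op_type == 'Resize':
        print(node.name, node.op_type)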
Remember that Lens Studio textures have values in the range [0, 255], so if your model was trained on whitened input with values in [-1, 1], you need to set the scale and bias parameters to reflect that.
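For example, assuming the transform applied to the texture is value * scale + bias, mapping [0, 255] to [-1, 1] works out to:
# assumption: the component computes value * scale + bias
scale = 2.0 / 255.0  # so that 255 * scale + bias = 1
bias = -1.0          # so that   0 * scale + bias = -1

print(0 * scale + bias)    # -1.0
print(255 * scale + bias)  #  1.0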
Check out the ML Component guide for more information!
Common issues
One of the major issues is bilinear interpolation. There is a discrepancy between PyTorch and mobile inference frameworks in how the edges of an interpolated image are handled when align_corners is set to False, so you need to make sure your model uses align_corners=True everywhere it uses bilinear interpolation. You should also use opset_version=11 if you have align_corners=True in your model, since the default 9th opset doesn't have this parameter in its op definition. It might be even better to use nearest neighbor interpolation or transposed convolution instead of bilinear interpolation, since the iOS inference accelerator doesn't support align_corners=True.
x = random_input
# bad: align_corners defaults to False
nn.functional.interpolate(x, scale_factor=2, mode='bilinear')
nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
nn.Upsample(scale_factor=2, mode='bilinear')
# better
nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# best: nearest neighbor interpolation, or a strided transposed
# convolution (learnable 2x upsampling, must be defined in __init__)
nn.ConvTranspose2d(3, 3, kernel_size=2, stride=2)
nn.functional.interpolate(x, scale_factor=2, mode='nearest')
nn.Upsample(scale_factor=2, mode='nearest')
Pay attention to the .view() and .reshape() operations in your network. The main issue with these is that while PyTorch uses the NCHW format for its 4-dimensional tensors, different mobile inference frameworks may use either NCHW or NHWC. If possible, replace them with operations that preserve tensor dimensions.
# bad: NCHW dim order is assumed
features = features.view(batch_size, num_channels, height, width)
result = conv(features)
# ok: dim order is irrelevant
flattened_features = features.view(batch_size, -1)
result = flattened_features.sum(dim=1)
Some PyTorch versions have issues with different ONNX opsets. If you encounter issues exporting a model with interpolation or a softmax layer with an explicit dim parameter, try updating PyTorch to the latest available version and set the opset_version=11 parameter in your torch.onnx.export call.
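After exporting, it can also help to validate the file and compare its outputs against PyTorch. A minimal sketch, assuming the onnx and onnxruntime packages are installed and reusing model and random_input from the export example above:
import numpy as np
import onnx
import onnxruntime as ort

# structural validation of the exported graph
onnx.checker.check_model(onnx.load('./model.onnx'))

# numerical comparison against the original PyTorch model
session = ort.InferenceSession('./model.onnx',
                               providers=['CPUExecutionProvider'])
onnx_out = session.run(None, {'image': random_input.numpy()})[0]
torch_out = model(random_input).detach().numpy()
print(np.allclose(onnx_out, torch_out, atol=1e-5))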