torch2trt is a PyTorch to TensorRT converter which utilizes the TensorRT Python API. The converter is
- Easy to use - Convert modules with a single function call, torch2trt
- Easy to extend - Write your own layer converter in Python and register it with @tensorrt_converter
If you find an issue, please let us know!
Please note, this converter has limited coverage of TensorRT / PyTorch. We created it primarily to easily optimize the models used in the JetBot project. If you find the converter helpful with other models, please let us know.
Below are some usage examples; for more, check out the notebooks.
```python
import torch
from torch2trt import torch2trt
from torchvision.models.alexnet import alexnet

# create some regular pytorch model...
model = alexnet(pretrained=True).eval().cuda()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# convert to TensorRT feeding sample data as input
model_trt = torch2trt(model, [x])
```
We can execute the returned TRTModule just like the original PyTorch model
```python
y = model(x)
y_trt = model_trt(x)

# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))
```
Save and load
We can save the model as a state_dict.
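For example, saving to the same file that the load example below expects (model_trt is the converted module from above):

```python
# save the converted model's state_dict so it can be reloaded later
torch.save(model_trt.state_dict(), 'alexnet_trt.pth')
```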
We can load the saved model into a TRTModule
```python
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('alexnet_trt.pth'))
```
We tested the converter against these models using the test.sh script. You can reproduce the results by running that script yourself.
The results below show the throughput in FPS. You can find the raw output, which includes latency, in the benchmarks folder.
| Model | Nano (PyTorch) | Nano (TensorRT) | Xavier (PyTorch) | Xavier (TensorRT) |
|-------|----------------|-----------------|------------------|-------------------|
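If you want a rough sanity check of throughput on your own device, a minimal timing sketch along the following lines can be used with the model, model_trt, and x from the usage example above. This is not the benchmark script itself, and the numbers it reports will not match the table exactly.

```python
import time
import torch

def measure_fps(module, x, iterations=100):
    # warm up so one-time initialization does not skew the timing
    for _ in range(10):
        module(x)
    torch.cuda.synchronize()

    # time repeated forward passes and report frames per second
    start = time.time()
    for _ in range(iterations):
        module(x)
    torch.cuda.synchronize()
    return iterations / (time.time() - start)

print('PyTorch FPS:', measure_fps(model, x))
print('TensorRT FPS:', measure_fps(model_trt, x))
```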
Option 1 - Without plugins
To install without compiling plugins, call the following
```bash
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
sudo python setup.py install
```
Option 2 - With plugins (experimental)
To install with plugins to support some operations in PyTorch that are not natively supported with TensorRT, call the following
Please note, this currently only includes the interpolate plugin. This plugin requires PyTorch 1.3+ for serialization.
```bash
sudo apt-get install libprotobuf* protobuf-compiler ninja-build
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
sudo python setup.py install --plugins
```
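After installing with --plugins, a quick way to check that interpolate is handled is to convert a tiny model that calls it. This is only a sketch: the module, input shape, and interpolation mode below are illustrative, and the modes the plugin accepts may vary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch2trt import torch2trt

# minimal module whose forward pass exercises interpolate
class Upsample2x(nn.Module):
    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

model = Upsample2x().eval().cuda()
x = torch.ones((1, 3, 112, 112)).cuda()

# conversion should route the interpolate call through the plugin
model_trt = torch2trt(model, [x])
print(torch.max(torch.abs(model(x) - model_trt(x))))
```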
How does it work?
This converter works by attaching conversion functions (like convert_ReLU) to the original PyTorch functional calls (like torch.nn.ReLU.forward). The sample input data is passed through the network, just as before, except now whenever a registered function (torch.nn.ReLU.forward) is encountered, the corresponding converter (convert_ReLU) is also called afterwards. The converter is passed the arguments and return value of the original PyTorch function, as well as the TensorRT network that is being constructed. The input tensors to the original PyTorch function are modified to have a _trt attribute, which is the TensorRT counterpart to the PyTorch tensor. The conversion function uses this _trt to add layers to the TensorRT network, and then sets the _trt attribute for the relevant output tensors. Once the model is fully executed, the tensors it finally returns are marked as outputs of the TensorRT network, and the optimized TensorRT engine is built.
How to add (or override) a converter
Here we show how to add a converter for the ReLU module using the TensorRT Python API.
```python
import tensorrt as trt
from torch2trt import tensorrt_converter

@tensorrt_converter('torch.nn.ReLU.forward')
def convert_ReLU(ctx):
    input = ctx.method_args[1]  # method_args[0] is the ReLU module itself
    output = ctx.method_return

    # add the equivalent TensorRT layer and attach its output to the PyTorch tensor
    layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.RELU)
    output._trt = layer.get_output(0)
```
The converter takes one argument, a ConversionContext, which will contain the following
- ctx.network - The TensorRT network that is being constructed.
- ctx.method_args - Positional arguments that were passed to the specified PyTorch function. The _trt attribute is set for relevant input tensors.
- ctx.method_kwargs - Keyword arguments that were passed to the specified PyTorch function.
- ctx.method_return - The value returned by the specified PyTorch function. The converter must set the _trt attribute where relevant.
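The sketch below illustrates how these fields are typically used together. It is a hypothetical converter for torch.nn.functional.leaky_relu (torch2trt may already ship its own converter for this function); the point is only to show reading ctx.method_args and ctx.method_kwargs and setting _trt on the return value.

```python
import tensorrt as trt
from torch2trt import tensorrt_converter

# Hypothetical sketch: a converter whose parameter may arrive either
# positionally or as a keyword argument.
@tensorrt_converter('torch.nn.functional.leaky_relu')
def convert_leaky_relu(ctx):
    input = ctx.method_args[0]
    negative_slope = ctx.method_kwargs.get(
        'negative_slope',
        ctx.method_args[1] if len(ctx.method_args) > 1 else 0.01,
    )
    output = ctx.method_return

    # add the equivalent TensorRT layer and link its output back to the PyTorch tensor
    layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.LEAKY_RELU)
    layer.alpha = negative_slope
    output._trt = layer.get_output(0)
```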
Please see this folder for more examples.