`burn-import` is a crate that simplifies importing models trained in other machine learning frameworks into Burn. It generates Rust source code that reproduces the imported model as a Burn model, and it converts the model's tensor data into a format Burn can load.
Currently, `burn-import` supports importing ONNX models with a limited set of operators, as it is still under development.
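The intro mentions converting tensor data into a format Burn can use. To show roughly what that involves, the sketch below decodes the raw bytes of an ONNX FLOAT tensor into `f32` values and checks the count against a shape. ONNX stores `raw_data` as little-endian values. The sketch uses only the standard library. `decode_f32_le`, `to_flat_tensor`, and `FlatTensor` are hypothetical names and do not reflect burn-import's internals. A real importer must also handle other element types and tensors stored in typed fields such as `float_data` instead of `raw_data`.

```rust
/// Decode ONNX-style `raw_data` bytes into `f32` values.
/// Assumption: the buffer holds tightly packed little-endian IEEE 754 floats.
/// Illustrative sketch only, not burn-import's code.
fn decode_f32_le(raw: &[u8]) -> Result<Vec<f32>, String> {
    // Every f32 occupies exactly 4 bytes; any other length is malformed.
    if raw.len() % 4 != 0 {
        return Err(format!("raw data length {} is not a multiple of 4", raw.len()));
    }
    Ok(raw
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect())
}

/// Hypothetical stand-in for a backend tensor: a shape plus row-major data.
#[derive(Debug, PartialEq)]
struct FlatTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

/// Decode the bytes and check that the element count matches the shape.
fn to_flat_tensor(raw: &[u8], shape: &[usize]) -> Result<FlatTensor, String> {
    let data = decode_f32_le(raw)?;
    let expected: usize = shape.iter().product();
    if data.len() != expected {
        return Err(format!(
            "shape {:?} expects {} values, got {}",
            shape,
            expected,
            data.len()
        ));
    }
    Ok(FlatTensor { shape: shape.to_vec(), data })
}

fn main() {
    // 1.0f32 is 0x3F800000 and 2.0f32 is 0x40000000, stored least significant byte first.
    let raw: [u8; 8] = [0x00, 0x00, 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40];
    let tensor = to_flat_tensor(&raw, &[1, 2]).unwrap();
    println!("{:?}", tensor);
}
```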
ONNX operators (list taken from the ONNX operator reference):
- Abs
- Acos
- Acosh
- Add
- And
- ArgMax
- ArgMin
- Asin
- Asinh
- Atan
- Atanh
- AveragePool1d
- AveragePool2d
- BatchNormalization
- Bernoulli
- BitShift
- BitwiseAnd
- BitwiseNot
- BitwiseOr
- BitwiseXor
- BlackmanWindow
- Cast
- CastLike
- Ceil
- Celu
- CenterCropPad
- Clip
- Col
- Compress
- Concat
- ConcatFromSequence
- Constant
- ConstantOfShape
- Conv
- Conv1d
- Conv2d
- ConvInteger
- ConvTranspose
- Cos
- Cosh
- CumSum
- DepthToSpace
- DequantizeLinear
- Det
- DFT
- Div
- Dropout
- DynamicQuantizeLinear
- Einsum
- Elu
- Equal
- Erf
- Exp
- Expand
- EyeLike
- Flatten
- Floor
- Gather
- GatherElements
- GatherND
- Gelu
- Gemm (Linear Layer)
- GlobalAveragePool
- GlobalLpPool
- GlobalMaxPool
- Greater
- GreaterOrEqual
- GridSample
- GroupNormalization
- GRU
- HammingWindow
- HannWindow
- Hardmax
- HardSigmoid
- HardSwish
- Identity
- If
- Im
- InstanceNormalization
- IsInf
- IsNaN
- LayerNormalization
- LeakyRelu
- Less
- LessOrEqual
- Linear
- Log
- LogSoftmax
- Loop
- LpNormalization
- LpPool
- LRN
- LSTM
- MatMul
- MatMulInteger
- Max
- MaxPool1d
- MaxPool2d
- MaxRoiPool
- MaxUnpool
- Mean
- MeanVarianceNormalization
- MelWeightMatrix
- Min
- Mish
- Mod
- Mul
- Multinomial
- Neg
- NegativeLogLikelihoodLoss
- NonMaxSuppression
- NonZero
- Not
- OneHot
- Optional
- OptionalGetElement
- OptionalHasElement
- Or
- Pad
- Pow
- PRelu
- QLinearConv
- QLinearMatMul
- QuantizeLinear
- RandomNormal
- RandomNormalLike
- RandomUniform
- RandomUniformLike
- Range
- Reciprocal
- ReduceL
- ReduceLogSum
- ReduceLogSumExp
- ReduceMax
- ReduceMean
- ReduceMin
- ReduceProd
- ReduceSum
- ReduceSumSquare
- Relu
- Reshape
- Resize
- ReverseSequence
- RNN
- RoiAlign
- Round
- Scan
- Scatter
- ScatterElements
- ScatterND
- Selu
- SequenceAt
- SequenceConstruct
- SequenceEmpty
- SequenceErase
- SequenceInsert
- SequenceLength
- SequenceMap
- Shape
- Shrink
- Sigmoid
- Sign
- Sin
- Sinh
- Size
- Slice
- Softmax
- SoftmaxCrossEntropyLoss
- Softplus
- Softsign
- SpaceToDepth
- Split
- SplitToSequence
- Sqrt
- Squeeze
- STFT
- StringNormalizer
- Sub
- Sum
- Tan
- Tanh
- TfIdfVectorizer
- ThresholdedRelu
- Tile
- TopK
- Transpose
- Trilu
- Unique
- Unsqueeze
- Upsample
- Where
- Xor
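Only some operators are supported. Before importing, it can therefore help to compare a model's operator types against a set of supported ones. The sketch below does this with the standard library only. The function name `unsupported_ops` and the example operator sets are hypothetical, and the sketch is not part of burn-import's API. In practice the operator types would come from the ONNX graph's nodes.

```rust
use std::collections::BTreeSet;

/// Return the operator types used by a model that are not in `supported`,
/// sorted and without duplicates. Hypothetical helper for illustration.
fn unsupported_ops(model_ops: &[&str], supported: &[&str]) -> Vec<String> {
    // Owned strings allow lookups by `&str` because `String: Borrow<str>`.
    let supported: BTreeSet<String> = supported.iter().map(|s| s.to_string()).collect();
    // A BTreeSet both removes duplicates and keeps the result sorted.
    let missing: BTreeSet<String> = model_ops
        .iter()
        .filter(|op| !supported.contains(**op))
        .map(|op| op.to_string())
        .collect();
    missing.into_iter().collect()
}

fn main() {
    // Hypothetical inputs: op types read from a model and an example supported set.
    let model_ops = ["Conv", "Relu", "MaxPool", "Conv", "Einsum", "Softmax", "Einsum"];
    let supported = ["Conv", "Relu", "MaxPool", "Softmax"];
    println!("{:?}", unsupported_ops(&model_ops, &supported));
}
```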
To import ONNX models, follow these steps:
- Add the following code to your `build.rs` file:

  ```rust
  use burn_import::onnx::ModelGen;

  fn main() {
      // Generate the model code and state file from the ONNX file.
      ModelGen::new()
          .input("src/model/mnist.onnx") // Path to the ONNX model
          .out_dir("model/") // Directory for the generated Rust source file (under target/)
          .run_from_script();
  }
  ```
- Add the following code to the `mod.rs` file under `src/model`:

  ```rust
  pub mod mnist {
      include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
  }
  ```
- Use the imported model in your code as shown below:

  ```rust
  mod model;

  use burn::tensor;
  use burn_ndarray::NdArrayBackend;
  use model::mnist::Model;

  fn main() {
      // Create a new model
      let model: Model<NdArrayBackend<f32>> = Model::new();

      // Create a new input tensor (all zeros for demonstration purposes)
      let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros([1, 1, 28, 28]);

      // Run the model
      let output = model.forward(input);

      // Print the output
      println!("{:?}", output);
  }
  ```
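The all-zeros input above only demonstrates the call. For a classifier such as MNIST, the output is typically a vector of class scores, and the prediction is the index of the highest score. The sketch below assumes the 10 scores have already been copied out of the output tensor into a slice of `f32`. That extraction step is not shown and depends on the Burn version. `predicted_class` is a hypothetical helper.

```rust
/// Return the index of the largest score, i.e. the predicted class.
/// Assumption: `scores` holds one prediction's per-class values, already
/// copied out of the model's output tensor. NaN scores are ignored.
/// On ties, the last maximal index is returned.
fn predicted_class(scores: &[f32]) -> Option<usize> {
    scores
        .iter()
        .enumerate()
        .filter(|(_, s)| !s.is_nan())
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
}

fn main() {
    // Made-up scores for illustration only.
    let scores: [f32; 10] = [0.1, -1.2, 0.3, 4.5, 0.0, -0.7, 1.1, 0.2, 2.9, -3.0];
    println!("{:?}", predicted_class(&scores)); // Some(3)
}
```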
A working example can be found in the `examples/onnx-inference` directory.
To add support for new operators to `burn-import`, follow these steps:
- Optimize the ONNX model using onnxoptimizer. This will remove unnecessary operators and constants, making the model easier to understand.
- Use the Netron app to visualize the ONNX model.
- Generate artifact files for the ONNX model (`my-model.onnx`) and its components:

  ```shell
  cargo r -- ./my-model.onnx ./
  ```
- Implement the missing operators when you encounter an error stating that an operator is not supported. Ideally, the `my-model.graph.txt` file is generated before the error occurs, providing information about the ONNX model.
- The newly generated `my-model.graph.txt` file contains IR information about the model, while the `my-model.rs` file contains the actual Burn model in Rust code. The `my-model.json` file contains the model data.
- The `src/onnx` directory contains the following ONNX modules:
  - `coalesce.rs`: Coalesces multiple ONNX operators into a single Burn operator. This is useful when a pattern of ONNX operators maps naturally onto one Burn operator.
  - `op_configuration.rs`: Contains helper functions for configuring Burn operators from operator nodes.
  - `shape_inference.rs`: Contains helper functions for inferring the shapes of tensors for the inputs and outputs of operators.
- Add unit tests for the new operator in the `burn-import/tests/onnx_tests.rs` file. Add the ONNX file and expected output to the `tests/data` directory. Ensure the ONNX file is small, as large files increase the repository size and make it harder to maintain and clone. Refer to existing unit tests for examples.
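`coalesce.rs` merges several ONNX operators into one Burn operator. The sketch below illustrates that idea but is not the file's actual code. It folds a `MatMul` whose result is immediately added to another tensor into a single `Linear` node, using the `Add`'s other operand as the bias. The `Node` enum and `coalesce_linear` are hypothetical and far simpler than burn-import's IR. A real pass would also have to confirm that the `MatMul` output has no other consumers.

```rust
/// Hypothetical, heavily simplified IR node (burn-import's real IR is richer).
#[derive(Debug, Clone, PartialEq)]
enum Node {
    MatMul { input: String, weight: String, output: String },
    Add { lhs: String, rhs: String, output: String },
    Linear { input: String, weight: String, bias: Option<String>, output: String },
    Other { op_type: String },
}

/// Fold each `MatMul` that is immediately followed by an `Add` consuming its
/// output into one `Linear` node whose bias is the Add's other operand.
/// Omitted for brevity: checking that the MatMul output has no other consumers.
fn coalesce_linear(nodes: Vec<Node>) -> Vec<Node> {
    let mut out = Vec::with_capacity(nodes.len());
    let mut iter = nodes.into_iter().peekable();
    while let Some(node) = iter.next() {
        if let Node::MatMul { input, weight, output } = &node {
            if let Some(Node::Add { lhs, rhs, output: add_out }) = iter.peek() {
                // Add is commutative, so the MatMul result may be either operand.
                let bias = if lhs == output {
                    Some(rhs.clone())
                } else if rhs == output {
                    Some(lhs.clone())
                } else {
                    None
                };
                if let Some(bias) = bias {
                    let fused = Node::Linear {
                        input: input.clone(),
                        weight: weight.clone(),
                        bias: Some(bias),
                        output: add_out.clone(),
                    };
                    iter.next(); // skip the Add that was folded into `fused`
                    out.push(fused);
                    continue;
                }
            }
        }
        out.push(node);
    }
    out
}

fn main() {
    let nodes = vec![
        Node::MatMul { input: "x".into(), weight: "w".into(), output: "t".into() },
        Node::Add { lhs: "t".into(), rhs: "b".into(), output: "y".into() },
        Node::Other { op_type: "Relu".into() },
    ];
    println!("{:#?}", coalesce_linear(nodes));
}
```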
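`shape_inference.rs` infers tensor shapes for operator inputs and outputs. For elementwise binary operators such as `Add` or `Mul`, ONNX uses NumPy-style (multidirectional) broadcasting. The output shape therefore follows directly from the two input shapes. The sketch below covers only that one rule, with concrete `usize` dimensions, and `broadcast_shape` is a hypothetical name. Shape inference in an importer also has to handle operators that change rank, such as `Reshape` or `Squeeze`, which this sketch does not attempt.

```rust
/// Infer the output shape of an elementwise binary op under NumPy-style
/// broadcasting: align shapes from the trailing dimension, and require each
/// pair of dimensions to be equal or to contain a 1.
/// Returns None when the shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // Missing leading dimensions behave like size 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[rank - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}

fn main() {
    println!("{:?}", broadcast_shape(&[2, 1, 4], &[3, 1])); // Some([2, 3, 4])
    println!("{:?}", broadcast_shape(&[2, 3], &[4])); // None
}
```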
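`op_configuration.rs` builds Burn operator configurations from ONNX nodes. The sketch below illustrates that kind of work but is not the file's actual code. It turns the attributes of an ONNX `Conv` node into a hypothetical `Conv2dConfig`. When an attribute is omitted, it applies the defaults from the ONNX specification: strides and dilations of 1, zero padding, `group` of 1, and a kernel shape taken from the weight tensor. The `Attr` enum, `Conv2dConfig`, and `conv2d_config` are assumptions made for this example. For simplicity, the sketch ignores `auto_pad` and rejects asymmetric padding.

```rust
use std::collections::HashMap;

/// Hypothetical, reduced attribute type; real ONNX attributes have more variants.
#[derive(Debug, Clone)]
enum Attr {
    Int(i64),
    Ints(Vec<i64>),
}

/// Hypothetical configuration for a 2-D convolution.
#[derive(Debug, PartialEq)]
struct Conv2dConfig {
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    groups: usize,
}

/// Read an integer-list attribute, rejecting values that do not fit in usize.
fn ints(attrs: &HashMap<String, Attr>, name: &str) -> Result<Option<Vec<usize>>, String> {
    match attrs.get(name) {
        Some(Attr::Ints(v)) => v
            .iter()
            .map(|&x| usize::try_from(x).map_err(|_| format!("{name}: invalid value {x}")))
            .collect::<Result<Vec<_>, _>>()
            .map(Some),
        Some(other) => Err(format!("{name}: expected a list of ints, got {other:?}")),
        None => Ok(None),
    }
}

/// Take exactly two values, or use `default` for both when the attribute is absent.
fn pair(v: Option<Vec<usize>>, default: usize, name: &str) -> Result<[usize; 2], String> {
    match v {
        None => Ok([default; 2]),
        Some(v) if v.len() == 2 => Ok([v[0], v[1]]),
        Some(v) => Err(format!("{name}: expected 2 values, got {}", v.len())),
    }
}

fn conv2d_config(attrs: &HashMap<String, Attr>, weight_shape: &[usize]) -> Result<Conv2dConfig, String> {
    // ONNX Conv weights are laid out as (M, C/group, kH, kW).
    if weight_shape.len() != 4 {
        return Err(format!("expected a 4-D weight, got shape {weight_shape:?}"));
    }
    // kernel_shape is optional in ONNX; fall back to the weight's spatial dims.
    let kernel_size = match ints(attrs, "kernel_shape")? {
        Some(k) => pair(Some(k), 0, "kernel_shape")?,
        None => [weight_shape[2], weight_shape[3]],
    };
    let stride = pair(ints(attrs, "strides")?, 1, "strides")?;
    let dilation = pair(ints(attrs, "dilations")?, 1, "dilations")?;
    // ONNX pads are [h_begin, w_begin, h_end, w_end]; only symmetric padding is accepted here.
    let padding = match ints(attrs, "pads")? {
        None => [0, 0],
        Some(p) if p.len() == 4 && p[0] == p[2] && p[1] == p[3] => [p[0], p[1]],
        Some(p) => return Err(format!("pads: unsupported value {p:?}")),
    };
    let groups = match attrs.get("group") {
        None => 1,
        Some(Attr::Int(g)) if *g > 0 => *g as usize,
        Some(other) => return Err(format!("group: invalid value {other:?}")),
    };
    Ok(Conv2dConfig { kernel_size, stride, padding, dilation, groups })
}

fn main() {
    let mut attrs = HashMap::new();
    attrs.insert("strides".to_string(), Attr::Ints(vec![2, 2]));
    attrs.insert("pads".to_string(), Attr::Ints(vec![1, 1, 1, 1]));
    // kernel_shape, dilations and group are omitted, so their defaults apply.
    let config = conv2d_config(&attrs, &[16, 3, 3, 3]).unwrap();
    println!("{config:?}");
}
```

With `strides = [2, 2]`, `pads = [1, 1, 1, 1]`, and a weight of shape `[16, 3, 3, 3]`, this yields a 3x3 kernel, stride 2, padding 1, dilation 1, and a single group.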