`burn-import` is a crate designed to simplify importing models trained in other machine learning frameworks into the Burn framework. It generates a Rust source file that aligns the imported model with Burn's model and converts tensor data into a format compatible with Burn.
Currently, `burn-import` supports importing ONNX models with only a limited set of operators, since it is still under development. The supported operators are:
- BatchNorm
- Conv2d
- Flatten
- Gemm (Linear layer)
- LogSoftmax
- Relu
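The list maps ONNX's Gemm operator to a linear layer. The sketch below illustrates why that mapping works. It implements the ONNX Gemm formula `Y = alpha * A' * B' + beta * C`, where `A'` and `B'` are optionally transposed via the `transA`/`transB` attributes. With `alpha = beta = 1` and `transB = 1`, this reduces to the linear-layer computation `y = x * W^T + b`. PyTorch's exporter typically emits `nn.Linear` this way, with the weight stored as `[out_features, in_features]`. The `gemm` and `transpose` functions and the nested-vector `Matrix` type are hypothetical helpers for illustration, not part of burn-import. `C` is also limited to a single bias row rather than ONNX's full broadcasting rules.

```rust
/// Row-major matrix as a vector of rows (hypothetical helper type for this sketch).
type Matrix = Vec<Vec<f32>>;

/// Transpose a row-major matrix.
fn transpose(m: &Matrix) -> Matrix {
    if m.is_empty() {
        return Vec::new();
    }
    let (rows, cols) = (m.len(), m[0].len());
    (0..cols)
        .map(|j| (0..rows).map(|i| m[i][j]).collect())
        .collect()
}

/// Hypothetical sketch of ONNX Gemm: Y = alpha * A' * B' + beta * C.
/// `c` is a single bias row broadcast over every output row
/// (only a subset of ONNX's broadcasting rules).
fn gemm(a: &Matrix, b: &Matrix, c: &[f32], alpha: f32, beta: f32, trans_a: bool, trans_b: bool) -> Matrix {
    let a = if trans_a { transpose(a) } else { a.clone() };
    let b = if trans_b { transpose(b) } else { b.clone() };
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    assert!(a.iter().all(|row| row.len() == k), "inner dimensions must match");
    assert_eq!(c.len(), n, "bias length must equal output width");
    (0..m)
        .map(|i| {
            (0..n)
                .map(|j| {
                    let dot: f32 = (0..k).map(|p| a[i][p] * b[p][j]).sum();
                    alpha * dot + beta * c[j]
                })
                .collect()
        })
        .collect()
}

fn main() {
    // Input batch x: one sample with 3 features.
    let x: Matrix = vec![vec![1.0, 2.0, 3.0]];
    // Weight stored as [out_features, in_features] (2 x 3).
    let w: Matrix = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 0.0]];
    let bias = [0.5, -0.5];

    // transB = 1 turns Gemm into a linear layer: y = x * W^T + b.
    let y = gemm(&x, &w, &bias, 1.0, 1.0, false, true);
    println!("{:?}", y); // [[4.5, 1.5]]
}
```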
To import ONNX models, follow these steps:
- Add the following code to your `build.rs` file:

  ```rust
  use burn_import::onnx::ModelGen;

  fn main() {
      // Generate the model code and state file from the ONNX file.
      ModelGen::new()
          .input("src/model/mnist.onnx") // Path to the ONNX model
          .out_dir("model/") // Directory for the generated Rust source file (under target/)
          .run_from_script();
  }
  ```
- Add the following code to the `mod.rs` file under `src/model`:

  ```rust
  pub mod mnist {
      include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
  }
  ```
- Use the imported model in your code as shown below:

  ```rust
  use burn::tensor;
  use burn_ndarray::NdArrayBackend;
  use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};

  fn main() {
      // Create a new model
      let model: Model<NdArrayBackend<f32>> = Model::new();

      // Create a new input tensor (all zeros for demonstration purposes)
      let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);

      // Run the model
      let output = model.forward(input);

      // Print the output
      println!("{:?}", output);
  }
  ```
A working example can be found in the `examples/onnx-inference` directory.
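The usage example above only prints the raw output. For an MNIST classifier whose last operator is LogSoftmax, the predicted digit is the index of the largest log-probability. Log-softmax preserves the ordering of the logits, so this gives the same class as an argmax over probabilities, and `exp` turns the winning value back into a probability. The sketch below assumes one sample's output has already been copied into a plain `&[f32]` slice. How to extract those values depends on the Burn version, so that step is left out. `predict_class` is a hypothetical helper, not part of burn-import.

```rust
/// Hypothetical helper: return the index and probability of the most likely class,
/// given one sample's log-probabilities (e.g. the output of a LogSoftmax layer).
/// Returns `None` for an empty slice.
fn predict_class(log_probs: &[f32]) -> Option<(usize, f32)> {
    log_probs
        .iter()
        .copied()
        .enumerate()
        // total_cmp gives a total order on f32, so a NaN cannot cause a panic here.
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        // exp() turns the log-probability back into a probability.
        .map(|(idx, lp)| (idx, lp.exp()))
}

fn main() {
    // Made-up log-probabilities for the 10 MNIST digits of a single image.
    let log_probs = [-5.1, -6.0, -4.2, -0.05, -7.3, -3.9, -6.6, -5.5, -4.8, -6.1];
    match predict_class(&log_probs) {
        Some((digit, p)) => println!("predicted digit {digit} with probability {p:.3}"),
        None => println!("empty output"),
    }
}
```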
To add support for new operators to `burn-import`, follow these steps:
- Optimize the ONNX model using onnxoptimizer. This removes unnecessary operators and constants, making the model easier to understand.
- Use the Netron app to visualize the ONNX model.
- Generate artifact files for the ONNX model (`my-model.onnx`) and its components:

  ```shell
  cargo r -- ./my-model.onnx ./
  ```
- Implement the missing operators when you encounter an error stating that an operator is not supported. Ideally, the `my-model.graph.txt` file is generated before the error occurs, providing information about the ONNX model.
- The newly generated `my-model.graph.txt` file contains IR information about the model, while the `my-model.rs` file contains an actual Burn model in Rust code. The `my-model.json` file contains the model data.
- The `src/onnx` directory contains the following ONNX modules:
  - `coalesce.rs`: Coalesces multiple ONNX operators into a single Burn operator. This is useful for operators that are not supported by Burn but can be represented by a combination of supported operators.
  - `op_configuration.rs`: Contains helper functions for configuring Burn operators from operator nodes.
  - `shape_inference.rs`: Contains helper functions for inferring the shapes of the input and output tensors of operators.
- Add unit tests for the new operator in the `burn-import/tests/onnx_tests.rs` file, and add the ONNX file and its expected output to the `tests/data` directory. Keep the ONNX file small: large files increase the repository size and make it harder to maintain and clone. Refer to the existing unit tests for examples.
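To illustrate what "coalescing" means, the sketch below runs a pass over a toy graph. It fuses a `MatMul` whose output feeds directly into the next `Add` into one `Linear` node, with inputs `[x, weight, bias]`. The `Op` and `Node` types, the `coalesce_matmul_add` and `node` functions, and the specific MatMul+Add pattern are all hypothetical. They do not mirror burn-import's internal IR, and the actual `coalesce.rs` may target different operator patterns. The graph is also simplified to a linear sequence of nodes, so only adjacent nodes are considered.

```rust
/// Hypothetical operator kinds in a toy IR.
#[derive(Debug, Clone, PartialEq)]
enum Op {
    MatMul,
    Add,
    Relu,
    /// Fused node produced by the coalescing pass.
    Linear,
}

/// Hypothetical IR node: operator kind plus input/output tensor names.
#[derive(Debug, Clone, PartialEq)]
struct Node {
    op: Op,
    inputs: Vec<String>,
    outputs: Vec<String>,
}

/// Convenience constructor for the toy IR (hypothetical).
fn node(op: Op, inputs: &[&str], outputs: &[&str]) -> Node {
    Node {
        op,
        inputs: inputs.iter().map(|s| s.to_string()).collect(),
        outputs: outputs.iter().map(|s| s.to_string()).collect(),
    }
}

/// Fuse every MatMul whose first output is consumed by the immediately
/// following Add into a single Linear node with inputs [x, weight, bias].
fn coalesce_matmul_add(nodes: Vec<Node>) -> Vec<Node> {
    let mut result = Vec::with_capacity(nodes.len());
    let mut iter = nodes.into_iter().peekable();
    while let Some(node) = iter.next() {
        // Only fuse when the next node is an Add that reads this MatMul's output.
        let fuse = node.op == Op::MatMul
            && !node.outputs.is_empty()
            && matches!(iter.peek(), Some(next)
                if next.op == Op::Add && next.inputs.contains(&node.outputs[0]));
        if fuse {
            let add = iter.next().expect("peeked Add node");
            // The Add operand that is not the MatMul output is treated as the bias.
            let bias = add.inputs.iter().find(|i| **i != node.outputs[0]).cloned();
            let mut inputs = node.inputs;
            inputs.extend(bias);
            result.push(Node { op: Op::Linear, inputs, outputs: add.outputs });
        } else {
            result.push(node);
        }
    }
    result
}

fn main() {
    // x -> MatMul(W) -> Add(b) -> Relu
    let graph = vec![
        node(Op::MatMul, &["x", "W"], &["mm_out"]),
        node(Op::Add, &["mm_out", "b"], &["add_out"]),
        node(Op::Relu, &["add_out"], &["y"]),
    ];
    for n in coalesce_matmul_add(graph) {
        println!("{:?} {:?} -> {:?}", n.op, n.inputs, n.outputs);
    }
}
```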
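To show the kind of computation a shape-inference helper performs, here is a standalone sketch of output-shape inference for a Conv2d node. It uses the standard convolution arithmetic per spatial dimension, `out = floor((in + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride) + 1`, with ONNX's `[h_begin, w_begin, h_end, w_end]` padding order. The `Conv2dAttrs` struct, the `infer_conv2d_output_shape` function, and the NCHW `[usize; 4]` shape representation are assumptions for illustration. They do not reflect the actual API of `shape_inference.rs`.

```rust
/// Hypothetical struct holding the ONNX Conv attributes that affect the output shape.
struct Conv2dAttrs {
    /// Weight shape [out_channels, in_channels / groups, kernel_h, kernel_w].
    weight_shape: [usize; 4],
    strides: [usize; 2],
    /// ONNX order: [h_begin, w_begin, h_end, w_end].
    pads: [usize; 4],
    dilations: [usize; 2],
}

/// Infer the NCHW output shape of a Conv2d node from its NCHW input shape.
/// Returns None if an attribute is zero or the kernel does not fit the padded input.
fn infer_conv2d_output_shape(input: [usize; 4], attrs: &Conv2dAttrs) -> Option<[usize; 4]> {
    let [batch, _in_channels, h, w] = input;
    let [out_channels, _, kh, kw] = attrs.weight_shape;

    // Output size along one spatial dimension (standard convolution arithmetic).
    let dim = |size: usize, kernel: usize, stride: usize, pad_begin: usize, pad_end: usize, dilation: usize| {
        if stride == 0 || kernel == 0 || dilation == 0 {
            return None;
        }
        let padded = size + pad_begin + pad_end;
        let effective_kernel = dilation * (kernel - 1) + 1;
        if effective_kernel > padded {
            return None;
        }
        Some((padded - effective_kernel) / stride + 1)
    };

    let out_h = dim(h, kh, attrs.strides[0], attrs.pads[0], attrs.pads[2], attrs.dilations[0])?;
    let out_w = dim(w, kw, attrs.strides[1], attrs.pads[1], attrs.pads[3], attrs.dilations[1])?;
    Some([batch, out_channels, out_h, out_w])
}

fn main() {
    // MNIST-style input: batch of 1, 1 channel, 28x28 pixels.
    let input = [1, 1, 28, 28];
    // 8 filters of size 3x3, stride 1, no padding, no dilation.
    let attrs = Conv2dAttrs {
        weight_shape: [8, 1, 3, 3],
        strides: [1, 1],
        pads: [0, 0, 0, 0],
        dilations: [1, 1],
    };
    println!("{:?}", infer_conv2d_output_shape(input, &attrs)); // Some([1, 8, 26, 26])
}
```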