forked from tracel-ai/burn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Full support for ONNX scalar operators and Constants (tracel-ai#578)
- Loading branch information
Showing
49 changed files
with
1,463 additions
and
562 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,5 @@ pub use settings::*; | |
mod file; | ||
#[cfg(feature = "std")] | ||
pub use file::*; | ||
|
||
pub use primitive::ParamSerde; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
[package] | ||
name = "onnx-tests" | ||
version = "0.9.0" | ||
edition = "2021" | ||
|
||
[dev-dependencies] | ||
burn = { path = "../../burn" } | ||
burn-ndarray = { path = "../../burn-ndarray" } | ||
serde = { workspace = true } | ||
float-cmp = { workspace = true } | ||
|
||
[build-dependencies] | ||
burn-import = { path = "../" } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# ONNX Tests | ||
|
||
This crate contains ONNX models that are utilized in testing the conversion of ONNX to Burn source | ||
code through the `burn-import` crate. The tests are designed as end-to-end tests, ensuring that ONNX | ||
models are accurately converted into Burn source code. Of utmost importance is verifying that the | ||
converted Burn source code compiles without errors and produces the same output as the original ONNX | ||
model. | ||
|
||
Here is the directory structure of this crate: | ||
|
||
- `tests/<model>`: This directory contains the ONNX model and the Python script to generate it. | ||
- `tests/<model>/<model>.onnx`: The ONNX model is generated by the script. | ||
- `tests/<model>/<model>.py`: This is the Python script responsible for generating the ONNX model | ||
using PyTorch. | ||
- `tests/onnx_tests.rs`: This is the main test file, where all the tests are contained. | ||
- `build.rs`: This build script generates the ONNX models and is executed by `cargo test` before | ||
running the actual tests. | ||
|
||
## Adding new tests | ||
|
||
Here are the steps to add a new test: | ||
|
||
1. Add your Python script to the `tests/<model>` directory. Refer to existing scripts for examples. | ||
2. Run your Python script to generate the ONNX model and inspect the output of the model with the | ||
test data. Use the inputs and outputs in your test. | ||
3. Make sure the ONNX output contains the desired operators by verifying with the | ||
[Netron](https://github.com/lutzroeder/netron) app. Sometimes PyTorch will optimize the model and | ||
remove operators that are not necessary for the model to run. If this happens, you can disable | ||
optimization by setting `torch.onnx.export(..., do_constant_folding=False)`. | ||
4. Add an entry to the `build.rs` file to account for the generation of the new ONNX model. | ||
5. Include a test in `tests/onnx_tests.rs` to test the new ONNX model. | ||
6. Run `cargo test` to ensure your test passes. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
use burn_import::onnx::ModelGen; | ||
|
||
fn main() { | ||
// Re-run this build script if the onnx-tests directory changes. | ||
println!("cargo:rerun-if-changed=tests"); | ||
|
||
// Add onnx models. | ||
ModelGen::new() | ||
.input("tests/add/add.onnx") | ||
.input("tests/sub/sub.onnx") | ||
.input("tests/mul/mul.onnx") | ||
.input("tests/div/div.onnx") | ||
.input("tests/concat/concat.onnx") | ||
.input("tests/conv2d/conv2d.onnx") | ||
.out_dir("model/") | ||
.run_from_script(); | ||
|
||
// panic!("Purposefully failing build to output logs."); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# used to generate model: onnx-tests/tests/add/add.onnx | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class Model(nn.Module): | ||
def __init__(self): | ||
# Declare a constant float tensor with ones | ||
self.a = torch.ones(1, 1, 1, 4) | ||
|
||
# Declare a scalar | ||
self.b = 5.0 | ||
super(Model, self).__init__() | ||
|
||
def forward(self, x, k): | ||
|
||
# Add a tensor input and a constant tensor | ||
x = x + self.a | ||
|
||
# Add a scalar constant and a scalar input | ||
d = self.b + k | ||
|
||
# Add a tensor and a scalar | ||
x = x + d | ||
|
||
return x | ||
|
||
|
||
def main(): | ||
|
||
# Export to onnx | ||
model = Model() | ||
model.eval() | ||
device = torch.device("cpu") | ||
onnx_name = "add.onnx" | ||
dummy_input = torch.randn(1, 2, 3, 4, device=device) | ||
|
||
scalar = 2.0 | ||
|
||
torch.onnx.export(model, (dummy_input, scalar), onnx_name, | ||
verbose=False, opset_version=16) | ||
|
||
print("Finished exporting model to {}".format(onnx_name)) | ||
|
||
# Output some test data for use in the test | ||
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]]) | ||
|
||
print("Test input data: {}, {}".format(test_input, scalar)) | ||
output = model.forward(test_input, scalar) | ||
print("Test output data: {}".format(output)) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# used to generate model: onnx-tests/tests/conv2d/conv2d.onnx | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
class Model(nn.Module): | ||
def __init__(self): | ||
super(Model, self).__init__() | ||
self.conv1 = nn.Conv2d(4, 6, (3, 5), groups = 2, stride=(2, 1), padding=(4, 2), dilation=(3, 1)) | ||
|
||
def forward(self, x): | ||
x = self.conv1(x) | ||
return x | ||
|
||
def main(): | ||
|
||
# Export to onnx | ||
model = Model() | ||
model.eval() | ||
device = torch.device("cpu") | ||
|
||
file_name = "conv2d.onnx" | ||
test_input = torch.ones(2, 4, 10, 15, device=device) | ||
torch.onnx.export(model, test_input, file_name, | ||
verbose=False, opset_version=16) | ||
|
||
print("Finished exporting model to {}".format(file_name)) | ||
|
||
# Output some test data for use in the test | ||
print("Test input data shape of ones: {}".format(test_input.shape)) | ||
output = model.forward(test_input) | ||
print("Test output data shape: {}".format(output.shape)) | ||
|
||
sum = output.sum().item() | ||
|
||
print("Test output sum: {}".format(sum)) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
main() |
File renamed without changes.
Binary file not shown.
Oops, something went wrong.