Full support for ONNX scalar operators and Constants (tracel-ai#578)
antimora authored Aug 4, 2023
1 parent ca9a880 commit 1554a3c
Showing 49 changed files with 1,463 additions and 562 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -11,6 +11,7 @@ members = [
"burn-dataset",
"burn-derive",
"burn-import",
"burn-import/onnx-tests",
"burn-ndarray",
"burn-no-std-tests",
"burn-tch",
@@ -29,6 +30,7 @@ dashmap = "5.4.0"
dirs = "5.0.1"
fake = "2.6.1"
flate2 = "1.0.26"
float-cmp = "0.9.0"
gix-tempfile = {version = "7.0.0", features = ["signals"]}
hashbrown = "0.14.0"
indicatif = "0.17.5"
6 changes: 6 additions & 0 deletions burn-core/src/module/param/base.rs
@@ -23,6 +23,12 @@ impl<T: Clone> Param<T> {
    pub fn val(&self) -> T {
        self.value.clone()
    }

    /// Execute the given function on the inner value.
    pub fn map<F: FnOnce(T) -> T>(mut self, func: F) -> Self {
        self.value = func(self.value);
        self
    }
}

impl<T> core::ops::Deref for Param<T> {
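For orientation, here is a minimal, self-contained sketch of the pattern the new `Param::map` method follows: the wrapper consumes itself, applies a closure to its inner value, and returns the rebuilt wrapper. The `Wrapper` type below is purely illustrative and is not burn's `Param`; only the `map` signature mirrors the diff above.

```rust
/// Illustrative stand-in for a `Param`-like wrapper (not burn's actual type).
struct Wrapper<T> {
    value: T,
}

impl<T: Clone> Wrapper<T> {
    /// Return a clone of the inner value.
    fn val(&self) -> T {
        self.value.clone()
    }

    /// Execute the given function on the inner value and rebuild the wrapper.
    fn map<F: FnOnce(T) -> T>(mut self, func: F) -> Self {
        self.value = func(self.value);
        self
    }
}

fn main() {
    let param = Wrapper { value: 2.0_f32 };
    // Transform the inner value without ever exposing a mutable reference.
    let param = param.map(|v| v * 3.0);
    assert_eq!(param.val(), 6.0);
}
```

Taking `self` by value lets callers chain such transformations while leaving the rest of the wrapper untouched.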
54 changes: 50 additions & 4 deletions burn-core/src/module/param/constant.rs
@@ -1,3 +1,5 @@
use core::marker::PhantomData;

use crate::{
    self as burn,
    module::{ADModule, Module, ModuleMapper, ModuleVisitor},
@@ -135,12 +137,10 @@ impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
    }

    fn into_record(self) -> Self::Record {
        // Treat as a constant and do not record
        ConstantRecord::new()
        ConstantRecord
    }

    fn load_record(self, _record: Self::Record) -> Self {
        // Treat as a constant and do not load
        self
    }
}
@@ -153,15 +153,49 @@ impl<const D: usize, B: ADBackend> ADModule<B> for Tensor<B, D> {
    }
}

impl<B: Backend> Module<B> for PhantomData<B> {
    type Record = ConstantRecord;

    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
        // Nothing to do
    }

    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
        self
    }

    fn load_record(self, _record: Self::Record) -> Self {
        self
    }

    fn into_record(self) -> Self::Record {
        ConstantRecord::new()
    }
}

impl<B: ADBackend> ADModule<B> for PhantomData<B> {
    type InnerModule = PhantomData<B::InnerBackend>;

    fn valid(&self) -> Self::InnerModule {
        PhantomData
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::backend::Backend;
    use burn_tensor::Tensor;

    use crate::module::Module;
    use crate::TestBackend;
    use crate::{
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
        TestADBackend,
    };
    use burn::module::Module;

    use crate as burn;

    #[test]
    fn tensor_load_record_setting() {
@@ -185,4 +219,16 @@ mod tests {
        assert!(!no_grad_is_require_grad);
        assert!(!with_default_is_require_grad);
    }

    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);
    }
}
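The new `Module` implementation for `PhantomData` lets a derived module carry its backend type parameter without any stateful field; the commit title and the `empty_module_with_phantom` test above suggest this is what ONNX graphs consisting only of constants compile down to. Below is a hedged sketch of such a module; the `AddFive` struct is illustrative and not taken from the diff, and the `burn::tensor` paths and `add_scalar` call are assumptions about burn's public API of this period.

```rust
use core::marker::PhantomData;

use burn::module::Module;
use burn::tensor::{backend::Backend, Tensor};

/// A module with no trainable state: the backend only appears via `PhantomData`,
/// so its record is a `ConstantRecord` and nothing gets serialized.
#[derive(Module, Debug)]
struct AddFive<B: Backend> {
    _phantom: PhantomData<B>,
}

impl<B: Backend> AddFive<B> {
    /// Add a constant scalar that is baked into the source code.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        input.add_scalar(5.0)
    }
}
```

As the test above asserts for `EmptyModule`, such a module is zero-sized.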
2 changes: 2 additions & 0 deletions burn-core/src/record/mod.rs
@@ -15,3 +15,5 @@ pub use settings::*;
mod file;
#[cfg(feature = "std")]
pub use file::*;

pub use primitive::ParamSerde;
6 changes: 3 additions & 3 deletions burn-import/README.md
@@ -32,17 +32,17 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
- [ ] BitwiseOr
- [ ] BitwiseXor
- [ ] BlackmanWindow
- [ ] Cast
- [x] Cast
- [ ] CastLike
- [ ] Ceil
- [ ] Celu
- [ ] CenterCropPad
- [ ] Clip
- [ ] Col
- [ ] Compress
- [ ] Concat
- [x] Concat
- [ ] ConcatFromSequence
- [ ] Constant
- [x] Constant
- [ ] ConstantOfShape
- [ ] Conv
- [ ] Conv1d
13 changes: 13 additions & 0 deletions burn-import/onnx-tests/Cargo.toml
@@ -0,0 +1,13 @@
[package]
name = "onnx-tests"
version = "0.9.0"
edition = "2021"

[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
serde = { workspace = true }
float-cmp = { workspace = true }

[build-dependencies]
burn-import = { path = "../" }
32 changes: 32 additions & 0 deletions burn-import/onnx-tests/README.md
@@ -0,0 +1,32 @@
# ONNX Tests

This crate contains ONNX models used to test the conversion of ONNX to Burn source code through the
`burn-import` crate. The tests are end-to-end: they verify that each ONNX model is accurately
converted into Burn source code, and, most importantly, that the generated code compiles without
errors and produces the same output as the original ONNX model.

Here is the directory structure of this crate:

- `tests/<model>`: This directory contains the ONNX model and the Python script that generates it.
- `tests/<model>/<model>.onnx`: The ONNX model generated by the script.
- `tests/<model>/<model>.py`: The Python script that generates the ONNX model using PyTorch.
- `tests/onnx_tests.rs`: The main test file, containing all of the tests.
- `build.rs`: The build script that converts the ONNX models into Burn source code via `burn-import`;
  `cargo test` runs it before executing the actual tests.

## Adding new tests

Here are the steps to add a new test:

1. Add your Python script to the `tests/<model>` directory. Refer to existing scripts for examples.
2. Run your Python script to generate the ONNX model and inspect the output of the model with the
test data. Use the inputs and outputs in your test.
3. Make sure the ONNX output contains the desired operators by verifying with the
[Netron](https://github.com/lutzroeder/netron) app. Sometimes PyTorch will optimize the model and
remove operators that are not necessary for the model to run. If this happens, you can disable
optimization by setting `torch.onnx.export(..., do_constant_folding=False)`.
4. Add an entry to the `build.rs` file to account for the generation of the new ONNX model.
5. Include a test in `tests/onnx_tests.rs` to test the new ONNX model (see the sketch after this list).
6. Run `cargo test` to ensure your test passes.
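A hedged sketch of what step 5 might look like is shown below. The module name, tensor values, and the `Model::default()` constructor are placeholders, and the `include!` path assumes `ModelGen` writes the generated source to `OUT_DIR/model/<model>.rs` as configured in `build.rs`; check the generated file and the values printed by your Python script for the real names and numbers.

```rust
use burn::tensor::{Data, Tensor};
use burn_ndarray::NdArrayBackend;

// Pull in the Burn source that build.rs generated from my_model.onnx.
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}

type Backend = NdArrayBackend<f32>;

#[test]
fn my_model_matches_onnx_output() {
    // Instantiate the generated model (the constructor name may differ).
    let model: my_model::Model<Backend> = my_model::Model::default();

    // Feed the same input the Python script used ...
    let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
    let output = model.forward(input);

    // ... and compare against the output the script printed (placeholder values here).
    let expected = Data::from([[[[2., 4., 6., 8.]]]]);
    output.to_data().assert_approx_eq(&expected, 3);
}
```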
19 changes: 19 additions & 0 deletions burn-import/onnx-tests/build.rs
@@ -0,0 +1,19 @@
use burn_import::onnx::ModelGen;

fn main() {
    // Re-run this build script if the onnx-tests directory changes.
    println!("cargo:rerun-if-changed=tests");

    // Add onnx models.
    ModelGen::new()
        .input("tests/add/add.onnx")
        .input("tests/sub/sub.onnx")
        .input("tests/mul/mul.onnx")
        .input("tests/div/div.onnx")
        .input("tests/concat/concat.onnx")
        .input("tests/conv2d/conv2d.onnx")
        .out_dir("model/")
        .run_from_script();

    // panic!("Purposefully failing build to output logs.");
}
1 change: 1 addition & 0 deletions burn-import/onnx-tests/src/lib.rs
@@ -0,0 +1 @@

Binary file added burn-import/onnx-tests/tests/add/add.onnx
57 changes: 57 additions & 0 deletions burn-import/onnx-tests/tests/add/add.py
@@ -0,0 +1,57 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/add/add.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        # Declare a constant float tensor with ones
        self.a = torch.ones(1, 1, 1, 4)

        # Declare a scalar
        self.b = 5.0
        super(Model, self).__init__()

    def forward(self, x, k):

        # Add a tensor input and a constant tensor
        x = x + self.a

        # Add a scalar constant and a scalar input
        d = self.b + k

        # Add a tensor and a scalar
        x = x + d

        return x


def main():

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "add.onnx"
    dummy_input = torch.randn(1, 2, 3, 4, device=device)

    scalar = 2.0

    torch.onnx.export(model, (dummy_input, scalar), onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])

    print("Test input data: {}, {}".format(test_input, scalar))
    output = model.forward(test_input, scalar)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()
(Diff of the regenerated concat ONNX model: binary protobuf content, not meaningfully rendered as text.)
19 changes: 15 additions & 4 deletions burn-import/tests/data/concat/concat.py → ...-import/onnx-tests/tests/concat/concat.py
100644 → 100755
@@ -1,9 +1,9 @@
# used to generate model: burn-import/tests/data/conv2d/conv2d.onnx
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/concat/concat.onnx

import torch
import torch.nn as nn
import onnx
from onnxoptimizer import optimize

class Model(nn.Module):
    def __init__(self):
@@ -24,9 +24,20 @@ def main():
    model.eval()
    device = torch.device("cpu")
    onnx_name = "concat.onnx"
    dummy_input = torch.randn(1,256,13,13, device=device)
    dummy_input = torch.randn(1,2,3,5, device=device)
    torch.onnx.export(model, dummy_input, onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    test_input = torch.randn(1,2,3,5, device=device)
    print("Test input data shape: {}".format(test_input.shape))
    output = model.forward(test_input)

    print("Test output data shape: {}".format(output.shape))



if __name__ == '__main__':
    main()
Binary file added burn-import/onnx-tests/tests/conv2d/conv2d.onnx
43 changes: 43 additions & 0 deletions burn-import/onnx-tests/tests/conv2d/conv2d.py
@@ -0,0 +1,43 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/conv2d/conv2d.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(4, 6, (3, 5), groups = 2, stride=(2, 1), padding=(4, 2), dilation=(3, 1))

    def forward(self, x):
        x = self.conv1(x)
        return x

def main():

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")

    file_name = "conv2d.onnx"
    test_input = torch.ones(2, 4, 10, 15, device=device)
    torch.onnx.export(model, test_input, file_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(file_name))

    # Output some test data for use in the test
    print("Test input data shape of ones: {}".format(test_input.shape))
    output = model.forward(test_input)
    print("Test output data shape: {}".format(output.shape))

    sum = output.sum().item()

    print("Test output sum: {}".format(sum))



if __name__ == '__main__':
    main()
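For scalar results like the output sum printed above, the `float-cmp` crate added to the dev-dependencies suggests comparisons use an approximate-equality margin rather than `==`. Below is a minimal sketch of that idea with placeholder values; the margin and the assertion style are assumptions, not taken from the diff.

```rust
use float_cmp::{ApproxEq, F32Margin};

fn main() {
    // Placeholders: paste the sum printed by conv2d.py and the sum computed
    // by the generated Burn model.
    let onnx_sum: f32 = 0.0;
    let burn_sum: f32 = 0.0;

    // Compare with an (epsilon, ulps) margin, since PyTorch and Burn kernels
    // are not guaranteed to be bit-identical.
    assert!(burn_sum.approx_eq(onnx_sum, F32Margin { epsilon: 1.0e-4, ulps: 2 }));
}
```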
File renamed without changes.
Binary file added burn-import/onnx-tests/tests/div/div.onnx
(Additional changed files not shown in this view.)
