Skip to content

Commit

Permalink
Onnx tests and bug fixes - BatchNorm, Identity, Relu, Sigmoid and Tra…
Browse files Browse the repository at this point in the history
…nspose (tracel-ai#661)
  • Loading branch information
antimora authored Aug 21, 2023
1 parent a557caa commit b06fe66
Show file tree
Hide file tree
Showing 25 changed files with 869 additions and 74 deletions.
2 changes: 1 addition & 1 deletion burn-import/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
- [ ] Hardmax
- [ ] HardSigmoid
- [ ] HardSwish
- [ ] Identity
- [x] Identity
- [ ] If
- [ ] Im
- [ ] InstanceNormalization
Expand Down
7 changes: 7 additions & 0 deletions burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,28 @@ fn main() {
// Add onnx models.
ModelGen::new()
.input("tests/add/add.onnx")
.input("tests/add/add_int.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/concat/concat.onnx")
.input("tests/conv1d/conv1d.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/div/div.onnx")
.input("tests/dropout/dropout_opset16.onnx")
.input("tests/dropout/dropout_opset7.onnx")
.input("tests/equal/equal.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/softmax/softmax.onnx")
.input("tests/sub/sub.onnx")
.input("tests/sub/sub_int.onnx")
.input("tests/transpose/transpose.onnx")
.out_dir("model/")
.run_from_script();

Expand Down
Binary file added burn-import/onnx-tests/tests/add/add_int.onnx
Binary file not shown.
60 changes: 60 additions & 0 deletions burn-import/onnx-tests/tests/add/add_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/add/add.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):

# TODO enable this after https://github.com/burn-rs/burn/issues/665 is fixed
# Declare a constant int tensor with ones
# self.a = torch.ones(1, 1, 1, 4, dtype=torch.int32)

# Declare a scalar
self.b = 5
super(Model, self).__init__()

def forward(self, x, k):

# Add tensor inputs
x = x + x

# Add a scalar constant and a scalar input
d = self.b + k

# Add a tensor and a scalar
x = x + d

return x


def main():

# set seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "add_int.onnx"
# Output some test data for use in the test
test_input = torch.tensor([[[[1, 2, 3, 4]]]], dtype=torch.int32)

scalar = 2

torch.onnx.export(model, (test_input, scalar), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {}, {}".format(test_input, scalar))
output = model.forward(test_input, scalar)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
Binary file not shown.
51 changes: 51 additions & 0 deletions burn-import/onnx-tests/tests/batch_norm/batch_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3

# used to generate model: batch_norm.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.batch_norm1d = nn.BatchNorm1d(20)
self.batch_norm2d = nn.BatchNorm2d(5)

def forward(self, x):
x = self.batch_norm1d(x)
x = x.reshape(1, 5, 2, 2)
x = self.batch_norm2d(x)
return x


def main():

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

# reproducibility
torch.manual_seed(0)

file_name = "batch_norm.onnx"
test_input = torch.ones(1, 20, 1, device=device)
torch.onnx.export(model, test_input, file_name,
# do_constant_folding=False,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

sum = output.sum().item()

print("Test output sum: {}".format(sum))


if __name__ == '__main__':
main()
Binary file added burn-import/onnx-tests/tests/equal/equal.onnx
Binary file not shown.
52 changes: 52 additions & 0 deletions burn-import/onnx-tests/tests/equal/equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python3

# used to generate model: equal.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
# Declare a constant float tensor with ones
self.a = torch.ones(1, 1, 1, 4)

# Declare a scalar
self.b = 5.0
super(Model, self).__init__()

def forward(self, x, k):

x = x == self.a

k = k == self.b

return x, k


def main():

# Set seed for reproducibility
torch.manual_seed(42)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "equal.onnx"
input = torch.ones(1, 1, 1, 4, device=device)

scalar = 2.0

torch.onnx.export(model, (input, scalar), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {}, {}".format(input, scalar))
output = model.forward(input, scalar)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
128 changes: 126 additions & 2 deletions burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,35 @@ macro_rules! include_models {
// ATTENTION: Modify this macro to include all models in the `model` directory.
include_models!(
add,
add_int,
avg_pool2d,
batch_norm,
concat,
conv1d,
conv2d,
div,
dropout_opset16,
dropout_opset7,
equal,
flatten,
global_avr_pool,
log_softmax,
maxpool2d,
mul,
relu,
reshape,
sigmoid,
softmax,
sub
sub,
sub_int,
transpose
);

#[cfg(test)]
mod tests {
use super::*;

use burn::tensor::{Data, Shape, Tensor};
use burn::tensor::{Data, Int, Shape, Tensor};

use float_cmp::ApproxEq;

Expand All @@ -53,6 +60,20 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn add_scalar_to_int_tensor_and_int_tensor_to_int_tensor() {
// Initialize the model with weights (loaded from the exported file)
let model: add_int::Model<Backend> = add_int::Model::default();

// Run the model
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]]);
let scalar = 2;
let output = model.forward(input, scalar);
let expected = Data::from([[[[9, 11, 13, 15]]]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn sub_scalar_from_tensor_and_tensor_from_tensor() {
// Initialize the model with weights (loaded from the exported file)
Expand All @@ -67,6 +88,19 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn sub_scalar_from_int_tensor_and_int_tensor_from_tensor() {
// Initialize the model with weights (loaded from the exported file)
let model: sub_int::Model<Backend> = sub_int::Model::default();

// Run the model
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]]);
let scalar = 3;
let output = model.forward(input, scalar);
let expected = Data::from([[[[6, 6, 6, 6]]]]);

assert_eq!(output.to_data(), expected);
}
#[test]
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
// Initialize the model with weights (loaded from the exported file)
Expand Down Expand Up @@ -323,4 +357,94 @@ mod tests {
let expected_shape = Shape::from([1, 75]);
assert_eq!(expected_shape, output.shape());
}

#[test]
fn batch_norm() {
let model: batch_norm::Model<Backend> = batch_norm::Model::default();

// Run the model with ones as input for easier testing
let input = Tensor::<Backend, 3>::ones([1, 20, 1]);
let output = model.forward(input);

let expected_shape = Shape::from([1, 5, 2, 2]);
assert_eq!(output.shape(), expected_shape);

let output_sum = output.sum().into_scalar();
let expected_sum = 19.999801635742188; // from pytorch
assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2)));
}

#[test]
fn relu() {
// Initialize the model without weights (because the exported file does not contain them)
let model: relu::Model<Backend> = relu::Model::new();

// Run the model
let input = Tensor::<Backend, 2>::from_floats([
[0.33669037, 0.12880941, 0.23446237],
[0.23033303, -1.12285638, -0.18632829],
]);
let output = model.forward(input);
let expected = Data::from([
[0.33669037, 0.12880941, 0.23446237],
[0.23033303, 0.00000000, 0.00000000],
]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn sigmoid() {
// Initialize the model without weights (because the exported file does not contain them)
let model: sigmoid::Model<Backend> = sigmoid::Model::new();

// Run the model
let input = Tensor::<Backend, 2>::from_floats([
[0.33669037, 0.12880941, 0.23446237],
[0.23033303, -1.12285638, -0.18632829],
]);
let output = model.forward(input);
let expected = Data::from([
[0.58338636, 0.53215790, 0.55834854],
[0.55733001, 0.24548186, 0.45355222],
]);

output.to_data().assert_approx_eq(&expected, 7);
}

#[test]
fn transpose() {
// Initialize the model without weights (because the exported file does not contain them)
let model: transpose::Model<Backend> = transpose::Model::new();

// Run the model
let input = Tensor::<Backend, 2>::from_floats([
[0.33669037, 0.12880941, 0.23446237],
[0.23033303, -1.12285638, -0.18632829],
]);
let output = model.forward(input);
let expected = Data::from([
[0.33669037, 0.23033303],
[0.12880941, -1.12285638],
[0.23446237, -0.18632829],
]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn equal_scalar_to_scalar_and_tensor_to_tensor() {
// Initialize the model with weights (loaded from the exported file)
let model: equal::Model<Backend> = equal::Model::default();

// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1., 1., 1., 1.]]]]);
let scalar = 2f64;
let (tensor_out, scalar_out) = model.forward(input, scalar);
let expected_tensor = Data::from([[[[true, true, true, true]]]]);
let expected_scalar = false;

assert_eq!(tensor_out.to_data(), expected_tensor);
assert_eq!(scalar_out, expected_scalar);
}
}
11 changes: 11 additions & 0 deletions burn-import/onnx-tests/tests/relu/relu.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
pytorch2.0.1:X

input1 /relu1/Relu"Relu torch_jitZ
input


b
1


B
Loading

0 comments on commit b06fe66

Please sign in to comment.