Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing ONNX support for sqrt #991

Merged
merged 6 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ represent the corresponding Burn Op.
| [SpaceToDepth][172] |||
| [Split][173] |||
| [SplitToSequence][174] |||
| [Sqrt][175] | ||
| [Sqrt][175] | ||
| [Squeeze][176] |||
| [STFT][177] |||
| [StringNormalizer][178] |||
Expand Down
4 changes: 4 additions & 0 deletions burn-import/onnx-tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Here is the directory structure of this crate:
- `build.rs`: This build script generates the ONNX models and is executed by `cargo test` before
running the actual tests.

## Setting up your python environment

You need to install `onnx==1.15.0` and `torch-2.1.1` in your python environment to add a new test

## Adding new tests

Here are the steps to add a new test:
Expand Down
1 change: 1 addition & 0 deletions burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fn main() {
.input("tests/reshape/reshape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/softmax/softmax.onnx")
.input("tests/sqrt/sqrt.onnx")
.input("tests/sub/sub_int.onnx")
.input("tests/sub/sub.onnx")
.input("tests/tanh/tanh.onnx")
Expand Down
12 changes: 12 additions & 0 deletions burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
reshape,
sigmoid,
softmax,
sqrt,
sub_int,
sub,
tanh,
Expand Down Expand Up @@ -326,6 +327,17 @@
assert_eq!(output.to_data(), expected);
}

#[test]
fn sqrt() {
let model: sqrt::Model<Backend> = sqrt::Model::new();

let input = Tensor::<Backend, 2>::from_floats([[0.0, 4.0, 9.0], [25.0, 36.0, 49.0]]);
let output = model.forward(input);

Check failure on line 335 in burn-import/onnx-tests/tests/onnx_tests.rs

View workflow job for this annotation

GitHub Actions / tests (stable, std)

[clippy] reported by reviewdog 🐶 error[E0308]: mismatched types --> burn-import/onnx-tests/tests/onnx_tests.rs:335:36 | 335 | let output = model.forward(input); | ------- ^^^^^ expected `4`, found `2` | | | arguments to this method are incorrect | = note: expected struct `burn::tensor::Tensor<_, _, 4>` found struct `burn::tensor::Tensor<_, _, 2>` note: method defined here --> /home/runner/work/burn/burn/target/debug/build/onnx-tests-23a76c21bf8f5d8f/out/model/sqrt.rs:47:12 | 47 | pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 4> { | ^^^^^^^ -------------------- Raw Output: burn-import/onnx-tests/tests/onnx_tests.rs:335:36:e:error[E0308]: mismatched types --> burn-import/onnx-tests/tests/onnx_tests.rs:335:36 | 335 | let output = model.forward(input); | ------- ^^^^^ expected `4`, found `2` | | | arguments to this method are incorrect | = note: expected struct `burn::tensor::Tensor<_, _, 4>` found struct `burn::tensor::Tensor<_, _, 2>` note: method defined here --> /home/runner/work/burn/burn/target/debug/build/onnx-tests-23a76c21bf8f5d8f/out/model/sqrt.rs:47:12 | 47 | pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 4> { | ^^^^^^^ -------------------- __END__
let expected = Data::from([[0.0, 2.0, 3.0], [5.0, 6.0, 7.0]]);

Check failure on line 336 in burn-import/onnx-tests/tests/onnx_tests.rs

View workflow job for this annotation

GitHub Actions / tests (stable, std)

[clippy] reported by reviewdog 🐶 error[E0277]: the trait bound `burn::tensor::Data<f32, 4>: core::convert::From<[[{float}; 3]; 2]>` is not satisfied --> burn-import/onnx-tests/tests/onnx_tests.rs:336:24 | 336 | let expected = Data::from([[0.0, 2.0, 3.0], [5.0, 6.0, 7.0]]); | ^^^^ the trait `core::convert::From<[[{float}; 3]; 2]>` is not implemented for `burn::tensor::Data<f32, 4>` | = help: the following other types implement trait `core::convert::From<T>`: <burn::tensor::Data<E, D> as core::convert::From<burn::tensor::DataSerialize<E>>> <burn::tensor::Data<E, 1> as core::convert::From<[E; A]>> <burn::tensor::Data<E, 2> as core::convert::From<[[E; B]; A]>> <burn::tensor::Data<E, 3> as core::convert::From<[[[E; C]; B]; A]>> <burn::tensor::Data<E, 4> as core::convert::From<[[[[E; D]; C]; B]; A]>> <burn::tensor::Data<E, D> as core::convert::From<&burn::tensor::DataSerialize<E>>> <burn::tensor::Data<E, 1> as core::convert::From<&[E]>> Raw Output: burn-import/onnx-tests/tests/onnx_tests.rs:336:24:e:error[E0277]: the trait bound `burn::tensor::Data<f32, 4>: core::convert::From<[[{float}; 3]; 2]>` is not satisfied --> burn-import/onnx-tests/tests/onnx_tests.rs:336:24 | 336 | let expected = Data::from([[0.0, 2.0, 3.0], [5.0, 6.0, 7.0]]); | ^^^^ the trait `core::convert::From<[[{float}; 3]; 2]>` is not implemented for `burn::tensor::Data<f32, 4>` | = help: the following other types implement trait `core::convert::From<T>`: <burn::tensor::Data<E, D> as core::convert::From<burn::tensor::DataSerialize<E>>> <burn::tensor::Data<E, 1> as core::convert::From<[E; A]>> <burn::tensor::Data<E, 2> as core::convert::From<[[E; B]; A]>> <burn::tensor::Data<E, 3> as core::convert::From<[[[E; C]; B]; A]>> <burn::tensor::Data<E, 4> as core::convert::From<[[[[E; D]; C]; B]; A]>> <burn::tensor::Data<E, D> as core::convert::From<&burn::tensor::DataSerialize<E>>> <burn::tensor::Data<E, 1> as core::convert::From<&[E]>> __END__

assert_eq!(output.to_data(), expected);
}

#[test]
fn maxpool2d() {
// Initialize the model without weights (because the exported file does not contain them)
Expand Down
16 changes: 16 additions & 0 deletions burn-import/onnx-tests/tests/sqrt/sqrt.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
pytorch2.1.1:q

onnx::Sqrt_01/Sqrt"Sqrt
main_graphZ&
onnx::Sqrt_0




b
1




B
42 changes: 42 additions & 0 deletions burn-import/onnx-tests/tests/sqrt/sqrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/sqrt/sqrt.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return torch.sqrt(x)


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "sqrt.onnx"
dummy_input = torch.randn(1, 4, 9, 25, device=device)

torch.onnx.export(model, (dummy_input), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])

print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
7 changes: 7 additions & 0 deletions burn-import/src/burn/node/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum UnaryNodeKind {
Relu,
Reciprocal,
Sigmoid,
Sqrt,
Tanh,
Transpose,
}
Expand All @@ -43,6 +44,7 @@ impl UnaryNodeKind {
Self::Relu => "relu",
Self::Reciprocal => "reciprocal",
Self::Sigmoid => "sigmoid",
Self::Sqrt => "sqrt",
Self::Tanh => "tanh",
Self::Transpose => "transpose",
}
Expand Down Expand Up @@ -133,6 +135,11 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Softmax, Rc::new(function))
}

pub(crate) fn sqrt(input: Type, output: Type) -> Self {
let function = move |input| quote! { burn::tensor::activation::sqrt(#input)};
Self::new(input, output, UnaryNodeKind::Sqrt, Rc::new(function))
}

pub(crate) fn tanh(input: Type, output: Type) -> Self {
let function = move |input| quote! { burn::tensor::activation::tanh(#input)};
Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function))
Expand Down
8 changes: 8 additions & 0 deletions burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ impl ONNXGraph {
NodeType::GatherElements => graph.register(Self::gather_conversion(node)),
NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
NodeType::Softmax => graph.register(Self::softmax_conversion(node)),
NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
Expand Down Expand Up @@ -471,6 +472,13 @@ impl ONNXGraph {
UnaryNode::softmax(input, output, dim)
}

fn sqrt_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();

UnaryNode::sqrt(input, output)
}

fn tanh_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
Expand Down
Loading