Implementing ONNX support for sqrt (tracel-ai#991)
* Implementing ONNX support for sqrt

* Formatting, my bad

* Implementing feedback from pull request

* Fixing codegen

* Fixing tests

* Fixing tests
edmondop authored Nov 22, 2023
1 parent 17f5905 commit b86bc58
Showing 8 changed files with 92 additions and 1 deletion.
2 changes: 1 addition & 1 deletion burn-import/SUPPORTED-ONNX-OPS.md
@@ -178,7 +178,7 @@ represent the corresponding Burn Op.
| [SpaceToDepth][172] | ❌ | ❌ |
| [Split][173] | ❌ | ❌ |
| [SplitToSequence][174] | ❌ | ❌ |
| [Sqrt][175] | ❌ | ❌ |
| [Sqrt][175] | ✅ | ✅ |
| [Squeeze][176] | ❌ | ❌ |
| [STFT][177] | ❌ | ❌ |
| [StringNormalizer][178] | ❌ | ❌ |
4 changes: 4 additions & 0 deletions burn-import/onnx-tests/README.md
@@ -16,6 +16,10 @@ Here is the directory structure of this crate:
- `build.rs`: This build script generates the ONNX models and is executed by `cargo test` before
running the actual tests.

## Setting up your Python environment

To add a new test, you need `onnx==1.15.0` and `torch==2.1.1` installed in your Python environment (a quick version check is sketched after this diff).

## Adding new tests

Here are the steps to add a new test:
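A quick way to confirm the environment note above before regenerating models; this is only a sketch, assuming the packages were installed with the exact pins from the README (PyTorch version strings may carry a local suffix such as `+cpu`, hence the prefix check):

```python
# Sanity check for the pinned test dependencies (sketch, not part of the repo).
import onnx
import torch

assert onnx.__version__ == "1.15.0", f"unexpected onnx version: {onnx.__version__}"
assert torch.__version__.startswith("2.1.1"), f"unexpected torch version: {torch.__version__}"
print(f"onnx {onnx.__version__} | torch {torch.__version__}")
```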
1 change: 1 addition & 0 deletions burn-import/onnx-tests/build.rs
@@ -32,6 +32,7 @@ fn main() {
.input("tests/reshape/reshape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/softmax/softmax.onnx")
.input("tests/sqrt/sqrt.onnx")
.input("tests/sub/sub_int.onnx")
.input("tests/sub/sub.onnx")
.input("tests/tanh/tanh.onnx")
13 changes: 13 additions & 0 deletions burn-import/onnx-tests/tests/onnx_tests.rs
@@ -39,6 +39,7 @@ include_models!(
reshape,
sigmoid,
softmax,
sqrt,
sub_int,
sub,
tanh,
@@ -326,6 +327,18 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn sqrt() {
let model: sqrt::Model<Backend> = sqrt::Model::new();

let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);

let output = model.forward(input);
let expected = Data::from([[[[1.0, 2.0, 3.0, 5.0]]]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn maxpool2d() {
// Initialize the model without weights (because the exported file does not contain them)
16 changes: 16 additions & 0 deletions burn-import/onnx-tests/tests/sqrt/sqrt.onnx
@@ -0,0 +1,16 @@
(Binary ONNX protobuf, not human-readable here: a `main_graph` exported by PyTorch 2.1.1 containing a single `Sqrt` node with input `onnx::Sqrt_0`.)
42 changes: 42 additions & 0 deletions burn-import/onnx-tests/tests/sqrt/sqrt.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/sqrt/sqrt.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return torch.sqrt(x)


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "sqrt.onnx"
dummy_input = torch.randn(1, 4, 9, 25, device=device)

torch.onnx.export(model, (dummy_input), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])

print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
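Once `sqrt.py` has been run, the exported graph can be spot-checked with the `onnx` package the tests already require; a minimal sketch, not part of the committed test suite (the file name matches the one exported above):

```python
# Sketch: verify the exported model is well formed and contains a single Sqrt node.
import onnx

model = onnx.load("sqrt.onnx")
onnx.checker.check_model(model)
print([node.op_type for node in model.graph.node])  # expected output: ['Sqrt']
```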
7 changes: 7 additions & 0 deletions burn-import/src/burn/node/unary.rs
@@ -28,6 +28,7 @@ pub enum UnaryNodeKind {
Relu,
Reciprocal,
Sigmoid,
Sqrt,
Tanh,
Transpose,
}
@@ -43,6 +44,7 @@ impl UnaryNodeKind {
Self::Relu => "relu",
Self::Reciprocal => "reciprocal",
Self::Sigmoid => "sigmoid",
Self::Sqrt => "sqrt",
Self::Tanh => "tanh",
Self::Transpose => "transpose",
}
@@ -133,6 +135,11 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Softmax, Rc::new(function))
}

pub(crate) fn sqrt(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.sqrt()};
Self::new(input, output, UnaryNodeKind::Sqrt, Rc::new(function))
}

pub(crate) fn tanh(input: Type, output: Type) -> Self {
let function = move |input| quote! { burn::tensor::activation::tanh(#input)};
Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function))
8 changes: 8 additions & 0 deletions burn-import/src/onnx/to_burn.rs
@@ -247,6 +247,7 @@ impl ONNXGraph {
NodeType::GatherElements => graph.register(Self::gather_conversion(node)),
NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
NodeType::Softmax => graph.register(Self::softmax_conversion(node)),
NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
@@ -471,6 +472,13 @@ impl ONNXGraph {
UnaryNode::softmax(input, output, dim)
}

fn sqrt_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();

UnaryNode::sqrt(input, output)
}

fn tanh_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
