Commit
Extend cast onnx support to tensors
laggui committed Apr 15, 2024
1 parent e303e31 commit f855c1f
Showing 7 changed files with 220 additions and 27 deletions.
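In short: the ONNX `Cast` importer previously handled only scalar-to-scalar casts; this commit extends it to tensors, emitting `.bool()`, `.int()`, or `.float()` calls when the input and output element kinds differ. As a rough sketch of what the new codegen path produces for a float-to-int cast (a free function here for brevity; the real output is a model method, per the unit tests below):

use burn::tensor::{backend::Backend, Int, Tensor};

// Illustrative stand-in for the generated model method: a 4-D float tensor
// cast to Int, matching the expected codegen asserted in the tests below.
pub fn forward<B: Backend>(tensor1: Tensor<B, 4>) -> Tensor<B, 4, Int> {
    let tensor2 = tensor1.int();
    tensor2
}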
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -10,6 +10,7 @@ fn main() {
.input("tests/add/add.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/cast/cast.onnx")
.input("tests/clip/clip_opset16.onnx")
.input("tests/clip/clip_opset7.onnx")
.input("tests/concat/concat.onnx")
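For context, these `.input` calls chain off burn-import's `ModelGen` builder in the test build script. A minimal sketch of that pattern, assuming the builder API documented for burn-import (the out_dir value is illustrative):

use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust model sources from the ONNX fixtures at build time;
    // each .input() registers one test model, as in the diff above.
    ModelGen::new()
        .input("tests/cast/cast.onnx")
        .out_dir("model/")
        .run_from_script();
}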
Binary file added crates/burn-import/onnx-tests/tests/cast/cast.onnx
Binary file not shown.
64 changes: 64 additions & 0 deletions crates/burn-import/onnx-tests/tests/cast/cast.py
@@ -0,0 +1,64 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/cast/cast.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(
self,
x_bool,
x_int,
x_float,
x_scalar,
):
# NOTE: we clone same-type casts for int and bool, otherwise the exporter would
# link other type casts to the output of the bool cast, leading to additional casts
return (
x_bool.clone().bool(),
x_bool.int(),
x_bool.float(),
x_int.bool(),
x_int.clone().int(),
x_int.float(),
x_float.bool(),
x_float.int(),
x_float.float(),
x_scalar.int(),
)


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "cast.onnx"
test_bool = torch.ones((2, 1), device=device, dtype=torch.bool)
test_int = torch.ones((2, 1), device=device, dtype=torch.int)
test_float = torch.ones((2, 1), device=device, dtype=torch.float)
test_scalar = torch.ones(1, device=device, dtype=torch.float).squeeze()
test_input = (test_bool, test_int, test_float, test_scalar)

# NOTE: torch exports logical_not with a cast node even if the input is already bool
# https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L2204-L2207
torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)

print(f"Finished exporting model to {onnx_name}")

# Output some test data for use in the test
print(f"Test input data: {test_input}")
output = model.forward(*test_input)
print(f"Test output data: {output}")


if __name__ == "__main__":
main()
51 changes: 50 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -17,6 +17,7 @@ include_models!(
add,
avg_pool2d,
batch_norm,
cast,
clip_opset16,
clip_opset7,
concat,
@@ -62,7 +63,7 @@ mod tests {

use super::*;

use burn::tensor::{Data, Int, Shape, Tensor};
use burn::tensor::{Bool, Data, Int, Shape, Tensor};

use float_cmp::ApproxEq;

@@ -877,4 +878,52 @@ mod tests {
let output = model.forward(input);
assert_eq!(output.shape(), expected_shape);
}

#[test]
fn cast() {
let device = Default::default();
let model: cast::Model<Backend> = cast::Model::new(&device);

let input_bool =
Tensor::<Backend, 2, Bool>::from_bool(Data::from([[true], [true]]), &device);
let input_int = Tensor::<Backend, 2, Int>::from_ints([[1], [1]], &device);
let input_float = Tensor::<Backend, 2>::from_floats([[1.], [1.]], &device);
let input_scalar = 1f32;

let (
output1,
output2,
output3,
output4,
output5,
output6,
output7,
output8,
output9,
output_scalar,
) = model.forward(
input_bool.clone(),
input_int.clone(),
input_float.clone(),
input_scalar,
);
let expected_bool = input_bool.to_data();
let expected_int = input_int.to_data();
let expected_float = input_float.to_data();
let expected_scalar = 1;

assert_eq!(output1.to_data(), expected_bool);
assert_eq!(output2.to_data(), expected_int);
output3.to_data().assert_approx_eq(&expected_float, 4);

assert_eq!(output4.to_data(), expected_bool);
assert_eq!(output5.to_data(), expected_int);
output6.to_data().assert_approx_eq(&expected_float, 4);

assert_eq!(output7.to_data(), expected_bool);
assert_eq!(output8.to_data(), expected_int);
output9.to_data().assert_approx_eq(&expected_float, 4);

assert_eq!(output_scalar, expected_scalar);
}
}
115 changes: 93 additions & 22 deletions crates/burn-import/src/burn/node/unary.rs
@@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, Scope, ToTokens, Type};
use crate::burn::{BurnImports, Scope, TensorKind, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
@@ -20,7 +20,8 @@ pub struct UnaryNode {
/// Type of unary node.
#[derive(Clone)]
pub enum UnaryNodeKind {
Cast,
// Input and output tensor types (required for codegen imports)
Cast(Option<TensorKind>, Option<TensorKind>),
Cos,
Erf,
Exp,
@@ -42,7 +43,7 @@ pub enum UnaryNodeKind {
impl UnaryNodeKind {
pub fn as_str(&self) -> &str {
match self {
Self::Cast => "cast",
Self::Cast(..) => "cast",
Self::Cos => "cos",
Self::Erf => "erf",
Self::Exp => "exp",
@@ -116,6 +117,18 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
UnaryNodeKind::Neg => {
imports.register("core::ops::Neg");
}
UnaryNodeKind::Cast(input, output) => {
if let Some(input_kind) = input {

[GitHub Actions / clippy: check failure on line 121 of crates/burn-import/src/burn/node/unary.rs. `clippy::collapsible_match` (denied by `-D warnings`): this `if let` can be collapsed into the outer `match`; the suggested collapsed form is sketched after this hunk.]
if let Some(output_kind) = output {
if input_kind == TensorKind::Bool || output_kind == TensorKind::Bool {
imports.register("burn::tensor::Bool");
}
if input_kind == TensorKind::Int || output_kind == TensorKind::Int {
imports.register("burn::tensor::Int");
}
}
}
}
_ => {}
}
}
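Clippy's suggested fix is mechanical: bind both `Some`s in the match pattern instead of nesting `if let`s inside the arm. A self-contained sketch of the collapsed form (local stand-in types, not the crate's actual definitions):

#[derive(Clone, Copy, PartialEq, Eq)]
enum TensorKind {
    Int,
    Float,
    Bool,
}

enum UnaryNodeKind {
    Cast(Option<TensorKind>, Option<TensorKind>),
}

fn register_imports(kind: &UnaryNodeKind, mut register: impl FnMut(&str)) {
    match kind {
        // Collapsed per clippy::collapsible_match: both Options are matched
        // in the arm's pattern, so the nested `if let`s disappear.
        UnaryNodeKind::Cast(Some(input), Some(output)) => {
            if *input == TensorKind::Bool || *output == TensorKind::Bool {
                register("burn::tensor::Bool");
            }
            if *input == TensorKind::Int || *output == TensorKind::Int {
                register("burn::tensor::Int");
            }
        }
        _ => {}
    }
}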
@@ -209,41 +222,54 @@ impl UnaryNode {
}

/// Casts the input to the output type.
///
/// Currently this function only supports the following conversions:
/// 1) scalar -> scalar
///
/// TODO: Implement the following conversions:
/// 2) tensor int -> tensor float
/// 3) tensor float -> tensor int
/// 4) tensor -> scalar
/// 5) scalar -> tensor
pub(crate) fn cast(input: Type, output: Type) -> Self {
match (input.clone(), output.clone()) {
(Type::Scalar(input_scalar), Type::Scalar(output_scalar)) => {
if input_scalar.kind == output_scalar.kind {
// If the input and output types are the same, we don't need to cast.
Self::new(input, output, UnaryNodeKind::Cast, Rc::new(|input| input))
Self::new(
input,
output,
UnaryNodeKind::Cast(None, None),
Rc::new(|input| input),
)
} else {
// If the input and output types are different, we need to cast.
let ty = output_scalar.ty();
Self::new(
input,
output,
UnaryNodeKind::Cast,
UnaryNodeKind::Cast(None, None),
Rc::new(move |input| quote! { #input as #ty }),
)
}
}
(Type::Tensor(_input_tensor), Type::Tensor(_output_tensor)) => {
// TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023)
// TODO: If the input is scalar and the output type is a tensor,
// we should generate another code block. (@antimora 8/4/2023)
// Tensor::from_data(Data::from([#input]).convert()).unsqueeze();
todo!()
}
(Type::Tensor(input_tensor), Type::Tensor(output_tensor)) => {
if input_tensor.kind == output_tensor.kind {
// If the input and output types are the same, we don't need to cast.
Self::new(
input,
output,
UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)),
Rc::new(|input| input),
)
} else {
// If the input and output types are different, we need to cast.
let function = match output_tensor.kind {
TensorKind::Bool => move |input| quote! { #input.bool()},
TensorKind::Int => move |input| quote! { #input.int()},
TensorKind::Float => move |input| quote! { #input.float()},
};

_ => panic!("output must be a tensor"),
Self::new(
input,
output,
UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)),
Rc::new(function),
)
}
}
_ => panic!("output must be a tensor or scalar"),
}
}
}
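One detail worth noting: when the input and output kinds match, the closure returns its input unchanged, so the generated statement is just a rebinding. Illustrative only (assuming the generated wrapper mirrors the cast tests below):

use burn::tensor::{backend::Backend, Int, Tensor};

// A same-kind cast (int to int) degenerates to an identity rebinding in the
// generated code; redundant-looking but harmless output is expected here.
pub fn forward<B: Backend>(tensor1: Tensor<B, 4, Int>) -> Tensor<B, 4, Int> {
    let tensor2 = tensor1;
    tensor2
}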
@@ -481,6 +507,51 @@ mod tests {
vec!["scalar1".to_string()],
vec!["scalar2".to_string()],
);
one_node_graph(
UnaryNode::cast(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_int("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4, Int> {
let tensor2 = tensor1.int();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
one_node_graph(
UnaryNode::cast(
Type::Tensor(TensorType::new_int("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4, Int>) -> Tensor<B, 4> {
let tensor2 = tensor1.float();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
one_node_graph(
UnaryNode::cast(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_bool("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4, Bool> {
let tensor2 = tensor1.bool();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}

#[test]
2 changes: 1 addition & 1 deletion crates/burn-import/src/burn/ty.rs
@@ -13,7 +13,7 @@ pub struct TensorType {
pub shape: Option<Vec<usize>>,
}

#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorKind {
Int,
Float,
14 changes: 11 additions & 3 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -134,6 +135,7 @@ fn cast_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Cast: multiple inputs are not supported");
}
let input = &mut node.inputs[0];
let output = &mut node.outputs[0];

// Extract cast type and update the output tensor
@@ -144,26 +145,33 @@
DataType::INT32 => ElementType::Int32,
DataType::INT64 => ElementType::Int64,
DataType::DOUBLE => ElementType::Float64,
DataType::BOOL => ElementType::Bool,
_ => panic!("Cast: unsupported type"),
},
_ => panic!("'to' attribute must be an Int64"),
},
None => panic!("Constant node must have a value attribute"),
};

match output.ty.clone() {
match input.ty.clone() {
ArgType::Tensor(tensor) => {
if tensor.dim == 0 {
// treat 0-dim tensor as scalar
output.ty = ArgType::Scalar(elem_type);
input.ty = ArgType::Scalar(tensor.elem_type);
} else {
todo!("Cast: support casting from different tensor types");
// Cast input and output are the same shape, but possibly different types
output.ty = ArgType::Tensor(TensorType {
elem_type,
dim: tensor.dim,
shape: tensor.shape.clone(),
});
}
}
ArgType::Scalar(_scalar) => {
output.ty = ArgType::Scalar(elem_type);
}
_ => panic!("Cast: only scalar input is valid"),
_ => panic!("Cast: only scalar and tensor inputs are valid"),
}
}
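For reference, ONNX's `Cast` carries its target type as an integer `to` attribute following onnx.proto's `DataType` enum, which is what the match above decodes. A minimal standalone sketch of that mapping (local `ElementType` stand-in; the integer values are from onnx.proto and worth double-checking there):

// Map ONNX Cast's `to` attribute (an onnx.proto DataType integer) to the
// element type burn-import tracks; unsupported types panic, as above.
#[derive(Debug, PartialEq)]
enum ElementType {
    Float32,
    Float64,
    Int32,
    Int64,
    Bool,
}

fn cast_target(to: i64) -> ElementType {
    match to {
        1 => ElementType::Float32,  // DataType::FLOAT
        6 => ElementType::Int32,    // DataType::INT32
        7 => ElementType::Int64,    // DataType::INT64
        9 => ElementType::Bool,     // DataType::BOOL
        11 => ElementType::Float64, // DataType::DOUBLE
        _ => panic!("Cast: unsupported type"),
    }
}

fn main() {
    assert_eq!(cast_target(9), ElementType::Bool);
}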
