add gather (tracel-ai#947)
CohenAriel authored Nov 13, 2023
1 parent 831335a commit cb4c23b
Showing 11 changed files with 252 additions and 7 deletions.
2 changes: 1 addition & 1 deletion burn-import/SUPPORTED-ONNX-OPS.md
@@ -62,7 +62,7 @@ represent the corresponding Burn Op.
| [EyeLike][55] |||
| [Flatten][56] |||
| [Floor][57] |||
-| [Gather][58] | ||
+| [Gather][58] | ||
| [GatherElements][59] |||
| [GatherND][60] |||
| [Gelu][61] |||
1 change: 1 addition & 0 deletions burn-import/onnx-tests/build.rs
@@ -21,6 +21,7 @@ fn main() {
.input("tests/equal/equal.onnx")
.input("tests/erf/erf.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/linear/linear.onnx")
.input("tests/log_softmax/log_softmax.onnx")
18 changes: 18 additions & 0 deletions burn-import/onnx-tests/tests/gather/gather.onnx
@@ -0,0 +1,18 @@
(binary ONNX protobuf, not reproduced verbatim: a main_graph exported by pytorch 2.1.0 containing a single GatherElements node with an axis attribute, taking inputs onnx::GatherElements_0 and onnx::GatherElements_1)
47 changes: 47 additions & 0 deletions burn-import/onnx-tests/tests/gather/gather.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/gather/gather.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, index):
x = torch.gather(x, 1, index)
return x


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "gather.onnx"
dummy_input = torch.randn(2, 2, device=device)
dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)

torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0],
[3.0, 4.0]])
test_index = torch.tensor([[0, 0],
[1, 0]])

print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
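
For reference, `torch.gather` exports to the ONNX `GatherElements` op (which is why gather.onnx contains a GatherElements node and the importer maps `NodeType::GatherElements` below), and along `dim=1` it selects `out[i][j] = x[i][index[i][j]]`. A minimal sketch, assuming PyTorch is available as in the generator script above, that checks the expected test values by hand:

```python
import torch

# Same test data as in gather.py and the Rust integration test.
x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
index = torch.tensor([[0, 0],
                      [1, 0]])

# Manual gather along dim=1: out[i][j] = x[i][index[i][j]]
manual = torch.empty_like(x)
for i in range(x.size(0)):
    for j in range(x.size(1)):
        manual[i, j] = x[i, index[i, j]]

assert torch.equal(manual, torch.gather(x, 1, index))
print(manual)  # tensor([[1., 1.], [4., 3.]])
```

This is the same expected output, `[[1., 1.], [4., 3.]]`, asserted in the Rust integration test below.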
15 changes: 15 additions & 0 deletions burn-import/onnx-tests/tests/onnx_tests.rs
@@ -28,6 +28,7 @@ include_models!(
equal,
erf,
flatten,
gather,
global_avr_pool,
linear,
log_softmax,
@@ -246,6 +247,20 @@ mod tests {
output.to_data().assert_approx_eq(&expected.to_data(), 4);
}

#[test]
fn gather() {
// Initialize the model (gather has no weights, so default initialization is used)
let model: gather::Model<Backend> = gather::Model::default();

// Run the model
let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.]]);
let index = Tensor::<Backend, 2, Int>::from_ints([[0, 0], [1, 0]]);
let output = model.forward(input, index);
let expected = Data::from([[1., 1.], [4., 3.]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn globalavrpool_1d_2d() {
// The model contains 1d and 2d global average pooling nodes
8 changes: 6 additions & 2 deletions burn-import/src/burn/node/base.rs
@@ -1,8 +1,9 @@
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
dropout::DropoutNode, global_avg_pool::GlobalAvgPoolNode, linear::LinearNode,
matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
linear::LinearNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode,
unary::UnaryNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::record::PrecisionSettings;
@@ -81,6 +82,7 @@ pub enum Node<PS: PrecisionSettings> {
Conv1d(Conv1dNode<PS>),
Conv2d(Conv2dNode<PS>),
Dropout(DropoutNode),
Gather(GatherNode),
GlobalAvgPool(GlobalAvgPoolNode),
Linear(LinearNode<PS>),
Matmul(MatmulNode),
@@ -102,6 +104,7 @@ macro_rules! match_all {
Node::Conv1d(node) => $func(node),
Node::Conv2d(node) => $func(node),
Node::Dropout(node) => $func(node),
Node::Gather(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
Node::Linear(node) => $func(node),
Node::Matmul(node) => $func(node),
@@ -133,6 +136,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Conv1d(_) => "conv1d",
Node::Conv2d(_) => "conv2d",
Node::Dropout(_) => "dropout",
Node::Gather(_) => "gather",
Node::GlobalAvgPool(_) => "global_avg_pool",
Node::Linear(_) => "linear",
Node::Matmul(_) => "matmul",
106 changes: 106 additions & 0 deletions burn-import/src/burn/node/gather.rs
@@ -0,0 +1,106 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct GatherNode {
pub input: TensorType,
pub index: TensorType,
pub output: TensorType,
pub dim: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
vec![
Type::Tensor(self.input.clone()),
Type::Tensor(self.index.clone()),
]
}

fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
let dim = self.dim.to_tokens();
let input = scope.tensor_use_owned(&self.input, node_position);
let index = scope.tensor_use_owned(&self.index, node_position);
let output = &self.output.name;

quote! {
let #output = #input.gather(#dim, #index);
}
}

fn into_node(self) -> super::Node<PS> {
Node::Gather(self)
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{gather::GatherNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_gather() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(GatherNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
TensorType::new_float("tensor3", 2),
1,
));

graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new_with(record: ModelRecord<B>) -> Self {
Self {
phantom: core::marker::PhantomData,
}
}

#[allow(clippy::let_and_return)]
pub fn forward(&self, tensor1: Tensor<B, 2>, tensor2: Tensor<B, 2, Int>) -> Tensor<B, 2> {
let tensor3 = tensor1.gather(1, tensor2);

tensor3
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions burn-import/src/burn/node/mod.rs
@@ -9,6 +9,7 @@ pub(crate) mod constant;
pub(crate) mod conv1d;
pub(crate) mod conv2d;
pub(crate) mod dropout;
pub(crate) mod gather;
pub(crate) mod global_avg_pool;
pub(crate) mod linear;
pub(crate) mod matmul;
1 change: 1 addition & 0 deletions burn-import/src/onnx/dim_inference.rs
@@ -67,6 +67,7 @@ pub fn dim_inference(
NodeType::MaxPool2d => same_as_input(node),
NodeType::Linear => linear_update_outputs(node),
NodeType::Flatten => flatten_update_outputs(node),
NodeType::GatherElements => same_as_input(node),
NodeType::Relu => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
32 changes: 32 additions & 0 deletions burn-import/src/onnx/op_configuration.rs
@@ -197,6 +197,38 @@ pub fn flatten_config(curr: &Node) -> (usize, usize) {
(start_dim as usize, end_dim)
}

/// Create a GatherConfig from the attributes of the node
pub fn gather_config(curr: &Node) -> usize {
// Default: 0 per ONNX spec
let mut dim: i64 = 0;

// Gather requires exactly two inputs: the data tensor and the index tensor
if curr.inputs.len() != 2 {
panic!("Gather: index tensor must be present");
}

// extract the input tensor type (its rank is needed to normalize a negative axis)
let tensor = match curr.inputs.get(0).unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// extract the attributes
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"axis" => dim = value.clone().into_i64(),
_ => {}
}
}

// if dim is negative, it is counted from the end
if dim < 0 {
dim += tensor.dim as i64;
}

dim as usize
}
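
A brief note on the axis handling above: ONNX allows a negative `axis`, counted from the last dimension, so `gather_config` shifts it by the tensor rank. A small Python sketch of the same rule (the helper name is hypothetical, not part of this diff):

```python
# Hypothetical helper mirroring gather_config's axis handling: a negative
# ONNX axis counts from the end, so it is shifted by the tensor rank.
def normalize_axis(axis: int, rank: int) -> int:
    return axis + rank if axis < 0 else axis

assert normalize_axis(-1, 2) == 1   # last dim of a 2-D tensor
assert normalize_axis(0, 4) == 0    # non-negative axes pass through unchanged
```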

/// Create a LinearConfig from the attributes of the node
pub fn linear_config(node: &Node) -> LinearConfig {
if node.inputs.len() < 2 {
28 changes: 24 additions & 4 deletions burn-import/src/onnx/to_burn.rs
@@ -22,6 +22,7 @@ use crate::{
conv1d::Conv1dNode,
conv2d::Conv2dNode,
dropout::DropoutNode,
gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode,
linear::LinearNode,
matmul::MatmulNode,
@@ -37,15 +38,15 @@ use crate::{
from_onnx::convert_constant_value,
ir::{Node, NodeType},
op_configuration::{
batch_norm_config, conv1d_config, conv2d_config, flatten_config, linear_config,
log_softmax_config, max_pool2d_config,
batch_norm_config, conv1d_config, conv2d_config, flatten_config, gather_config,
linear_config, log_softmax_config, max_pool2d_config,
},
},
};

use super::{
from_onnx::parse_onnx,
ir::{ArgType, Argument, Data, ElementType, ONNXGraph},
ir::{self, ArgType, Argument, Data, ElementType, ONNXGraph},
op_configuration::{
avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
softmax_config,
@@ -243,6 +244,7 @@ impl ONNXGraph {
}
NodeType::Relu => graph.register(Self::relu_conversion(node)),
NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
NodeType::GatherElements => graph.register(Self::gather_conversion(node)),
NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
NodeType::Softmax => graph.register(Self::softmax_conversion(node)),
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
@@ -399,6 +401,15 @@ impl ONNXGraph {
UnaryNode::flatten(input, output, start_dim, end_dim)
}

fn gather_conversion(node: Node) -> GatherNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let index = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let dim = gather_config(&node);

GatherNode::new(input, index, output, dim)
}

fn transpose_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
@@ -629,7 +640,16 @@ fn serialize_data<E: Element>(data: Data, shape: Vec<usize>) -> DataSerialize<E>
impl Argument {
pub fn to_tensor_type(&self) -> TensorType {
match &self.ty {
ArgType::Tensor(tensor) => TensorType::new_float(self.name.clone(), tensor.dim),
ArgType::Tensor(ir::TensorType {
elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
dim,
..
}) => TensorType::new_float(self.name.clone(), *dim),
ArgType::Tensor(ir::TensorType {
elem_type: ElementType::Int32 | ElementType::Int64,
dim,
..
}) => TensorType::new_int(self.name.clone(), *dim),
_ => panic!("Can't transform to tensor."),
}
}
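
The `to_tensor_type` change above now dispatches on the ONNX element type, so integer tensors (such as gather indices) become `Int` Burn tensors rather than float ones. A schematic Python analogue of that dispatch, with hypothetical names, just to illustrate the mapping:

```python
# Hypothetical sketch of the element-type dispatch in to_tensor_type:
# float-like ONNX element types map to a float tensor, integer types to an
# int tensor, and anything else is rejected.
FLOAT_TYPES = {"float16", "float32", "float64"}
INT_TYPES = {"int32", "int64"}

def to_tensor_kind(elem_type: str) -> str:
    if elem_type in FLOAT_TYPES:
        return "float"
    if elem_type in INT_TYPES:
        return "int"
    raise ValueError("Can't transform to tensor.")

assert to_tensor_kind("int64") == "int"   # gather indices become an Int tensor
```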
