Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Implement ONNX RandomUniform + RandomNormal in burn-import #1806

Merged
merged 4 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ represent the corresponding Burn Op.
| [QLinearConv][123] | ❌ | ❌ |
| [QLinearMatMul][124] | ❌ | ❌ |
| [QuantizeLinear][125] | ❌ | ❌ |
| [RandomNormal][126] | ❌ | ✅ |
| [RandomNormal][126] | ✅ | ✅ |
| [RandomNormalLike][127] | ❌ | ✅ |
| [RandomUniform][128] | ❌ | ✅ |
| [RandomUniform][128] | ✅ | ✅ |
| [RandomUniformLike][129] | ❌ | ✅ |
| [Range][130] | ❌ | ✅ |
| [Reciprocal][131] | ✅ | ✅ |
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ fn main() {
.input("tests/mask_where/mask_where.onnx")
.input("tests/squeeze/squeeze_opset16.onnx")
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/random_normal/random_normal.onnx")
.out_dir("model/")
.run_from_script();

Expand Down
22 changes: 21 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ include_models!(
unsqueeze_opset16,
unsqueeze_opset11,
squeeze_opset16,
squeeze_opset13
squeeze_opset13,
random_uniform,
random_normal
);

#[cfg(test)]
Expand Down Expand Up @@ -1410,4 +1412,22 @@ mod tests {
let output = model.forward(input);
assert_eq!(expected_shape, output.shape());
}

#[test]
fn random_uniform() {
let device = Default::default();
let model = random_uniform::Model::<Backend>::new(&device);
let expected_shape = Shape::from([2, 3]);
let output = model.forward();
assert_eq!(expected_shape, output.shape());
}

#[test]
fn random_normal() {
let device = Default::default();
let model = random_normal::Model::<Backend>::new(&device);
let expected_shape = Shape::from([2, 3]);
let output = model.forward();
assert_eq!(expected_shape, output.shape());
}
}
Binary file not shown.
48 changes: 48 additions & 0 deletions crates/burn-import/onnx-tests/tests/random_normal/random_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

# used to generate model: random_normal.onnx

# torch doesn't generate RandomNormal operations in ONNX,
# but always uses RandomNormalLike.
# Hence this model is exported using onnx directly

import onnx
import onnx.helper


def build_model():
return onnx.helper.make_model(
ir_version=8,
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
graph=onnx.helper.make_graph(name="main_graph", nodes=[
onnx.helper.make_node(
"RandomNormal",
inputs=[],
outputs=["output1"],
name="/RandomNormal",
mean=2.0,
scale=1.5,
shape=[2, 3]
),
],
inputs=[],
outputs=[
onnx.helper.make_value_info(
name="output1",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
),
)
]),
)
hexd0t marked this conversation as resolved.
Show resolved Hide resolved


def main():
onnx_model = build_model()
file_name = "random_normal.onnx"

onnx.save(onnx_model, file_name)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pytorch2.3.0:c
@1/RandomUniform"RandomUniform*
dtype *
shape@@ 
main_graphb
1


B
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

# used to generate model: random_uniform.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, _in):
return torch.rand(2, 3)


def main():
# Set seed for reproducibility
torch.manual_seed(42)

torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "random_uniform.onnx"
test_input = torch.empty(0)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)

print(f"Finished exporting model to {file_name}")

# Output some test data for use in the test
print(f"Test input data: {test_input}")
print(f"Test input data shape: {test_input.shape}")
output = model.forward(test_input)
print(f"Test output data shape: {output.shape}")
print(f"Test output data: {output}")


if __name__ == '__main__':
main()
9 changes: 8 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use super::{
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
Expand Down Expand Up @@ -99,6 +100,8 @@ pub enum Node<PS: PrecisionSettings> {
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
RandomUniform(RandomUniformNode),
RandomNormal(RandomNormalNode),
}

macro_rules! match_all {
Expand Down Expand Up @@ -129,6 +132,8 @@ macro_rules! match_all {
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
Node::RandomNormal(node) => $func(node),
Node::RandomUniform(node) => $func(node),
}
}};
}
Expand Down Expand Up @@ -169,6 +174,8 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
Node::RandomNormal(_) => "random_normal",
Node::RandomUniform(_) => "random_uniform",
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_uniform;
pub(crate) mod reshape;
pub(crate) mod squeeze;
pub(crate) mod unary;
Expand Down
128 changes: 128 additions & 0 deletions crates/burn-import/src/burn/node/random_normal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone)]
pub struct RandomNormalNode {
pub mean: f64,
pub scale: f64,
pub output_ty: TensorType,
}

impl RandomNormalNode {
pub fn new(output_ty: TensorType, mean: f64, scale: f64) -> Self {
Self {
mean,
scale,
output_ty,
}
}

fn get_output_shape(&self) -> TokenStream {
let shape_it = self
.output_ty
.shape
.as_ref()
.expect("RandomNormal output has no shape!")
.iter();
quote! { Shape::new([#(#shape_it),*]) }
}

fn get_distribution(&self) -> TokenStream {
let std_deviation = self.scale; // ONNX spec defines `scale` == `standard deviation`
let mean = self.mean;
quote! { Distribution::Normal(#mean, #std_deviation) }
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomNormalNode {
fn input_types(&self) -> Vec<Type> {
Vec::with_capacity(0)
}

fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output_ty.clone())]
}

fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
let output = &self.output_ty.name;
let shape = self.get_output_shape();
let dist = self.get_distribution();
quote! {
let #output = Tensor::random(#shape, #dist, &*self.device);
}
}

fn into_node(self) -> Node<PS> {
Node::RandomNormal(self)
}

fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::Distribution");
imports.register("burn::prelude::Shape");
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{random_normal::RandomNormalNode, test::assert_tokens},
TensorKind, TensorType,
};

#[test]
fn test_codegen_nodes() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(RandomNormalNode::new(
TensorType::new("tensor1", 2, TensorKind::Float, Some(vec![2, 3])),
0.0f64,
1.0f64,
));

graph.register_input_output(vec![], vec!["tensor1".to_string()]);

let expected = quote! {
use burn::prelude::Shape;
use burn::tensor::Distribution;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self) -> Tensor<B, 2> {
let tensor1 = Tensor::random(
Shape::new([2usize, 3usize]),
Distribution::Normal(0f64, 1f64),
&*self.device,
);

tensor1
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
Loading
Loading