Account for additional cases in Linear node conversion #709

Merged: 2 commits, Aug 28, 2023. Changes shown from 1 commit.
Account for additional cases in Linear node conversion
antimora committed Aug 28, 2023
commit d8e7aa6c54942bcd43f1549b709d245b5f6ca1bc
Binary file modified burn-import/onnx-tests/tests/linear/linear.onnx
Binary file not shown.
46 changes: 33 additions & 13 deletions burn-import/onnx-tests/tests/linear/linear.py
@@ -10,14 +10,22 @@ class Model(nn.Module):
     def __init__(self):
         super(Model, self).__init__()
 
-        # When the input is 2D with shape [1, n] and bias is True, a Gemm ONNX node is used
-        self.fc1 = nn.Linear(16, 32, bias=True)
+        # Case 1: When the input is 2D with shape [1, n] and bias is True.
+        # This produces a single Gemm ONNX node with alpha=1, beta=1, and transB=1
+        # attributes, 3 inputs (the input, weight, and bias), and 1 output.
+        self.linear1_with_gemm = nn.Linear(3, 4, bias=True)
 
-        # TODO Test other cases that use matmul instead of Gemm
+        # Case 2: When the input is >= 2D but the linear layer has no bias.
+        self.linear2_with_matmul = nn.Linear(5, 6, bias=False)
 
-    def forward(self, x):
-        x = self.fc1(x)
-        return x
+        # Case 3: When the input is > 2D; whether the linear layer has a bias does not matter.
+        self.linear3_with_matmul = nn.Linear(7, 8, bias=True)
+
+    def forward(self, x1, x2, x3):
+        y1 = self.linear1_with_gemm(x1)
+        y2 = self.linear2_with_matmul(x2)
+        y3 = self.linear3_with_matmul(x3)
+        return y1, y2, y3
 
 
 def main():
@@ -33,20 +41,32 @@ def main():
     print("Made model")
 
     file_name = "linear.onnx"
-    test_input = torch.full((1, 16), 3.141592, device=device)
-    torch.onnx.export(model, test_input, file_name,
+    input1 = torch.full((4, 3), 3.14, device=device)
+    input2 = torch.full((2, 5), 3.14, device=device)
+    input3 = torch.full((3, 2, 7), 3.14, device=device)
+    torch.onnx.export(model, (input1, input2, input3), file_name,
                       verbose=False, opset_version=16)
 
     print("Finished exporting model to {}".format(file_name))
 
     # Output some test data for use in the test
-    print("Test input data shape of ones: {}".format(test_input.shape))
-    output = model.forward(test_input)
-    print("Test output data shape: {}".format(output.shape))
+    print("Test input1 data shape: {}".format(input1.shape))
+    print("Test input2 data shape: {}".format(input2.shape))
+    print("Test input3 data shape: {}".format(input3.shape))
+
+    output1, output2, output3 = model.forward(input1, input2, input3)
+
+    print("Test output1 data shape: {}".format(output1.shape))
+    print("Test output2 data shape: {}".format(output2.shape))
+    print("Test output3 data shape: {}".format(output3.shape))
 
-    sum = output.sum().item()
+    sum1 = output1.sum().item()
+    sum2 = output2.sum().item()
+    sum3 = output3.sum().item()
 
-    print("Test output sum: {}".format(sum))
+    print("Test output sum1: {}".format(sum1))
+    print("Test output sum2: {}".format(sum2))
+    print("Test output sum3: {}".format(sum3))
 
 
 if __name__ == "__main__":
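For reference, ONNX Gemm computes Y = alpha * A * B' + beta * C (with B transposed when transB=1), which matches x @ W.T + b, so case 1 maps onto a single node. To sanity-check which ops the three cases actually emit, the exported file can be inspected with the standard onnx package. A minimal sketch, not part of this PR:

import onnx
from collections import Counter

# Load the model exported by linear.py and tally the node kinds in its graph.
model = onnx.load("linear.onnx")
op_counts = Counter(node.op_type for node in model.graph.node)

# Per the case comments above, one would expect a single Gemm (case 1) and two
# MatMuls (cases 2 and 3), plus an Add carrying the bias for case 3.
print(op_counts)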
31 changes: 21 additions & 10 deletions burn-import/onnx-tests/tests/onnx_tests.rs
@@ -507,22 +507,33 @@ mod tests {
     fn linear() {
         // Initialize the model with weights (loaded from the exported file)
         let model: linear::Model<Backend> = linear::Model::default();
+        let input1 = Tensor::<Backend, 2>::full([4, 3], 3.14);
+        let input2 = Tensor::<Backend, 2>::full([2, 5], 3.14);
+        let input3 = Tensor::<Backend, 3>::full([3, 2, 7], 3.14);
 
-        // Run the model with 3.1416 as input for easier testing
-        let input = Tensor::<Backend, 2>::full([1, 16], 3.1416);
-
-        let output = model.forward(input);
+        let (output1, output2, output3) = model.forward(input1, input2, input3);
 
         // test the output shape
-        let expected_shape: Shape<2> = Shape::from([1, 32]);
-        assert_eq!(output.shape(), expected_shape);
+        let expected_shape1: Shape<2> = Shape::from([4, 4]);
+        let expected_shape2: Shape<2> = Shape::from([2, 6]);
+        let expected_shape3: Shape<3> = Shape::from([3, 2, 8]);
+        assert_eq!(output1.shape(), expected_shape1);
+        assert_eq!(output2.shape(), expected_shape2);
+        assert_eq!(output3.shape(), expected_shape3);
 
         // We are using the sum of the output tensor to test the correctness of the linear node
         // because the output tensor is too large to compare with the expected tensor.
-        let output_sum = output.sum().into_scalar();
-        let expected_sum = -3.205_825; // from pytorch
-        println!("output_sum: {}", output_sum);
-        assert!(expected_sum.approx_eq(output_sum, (1.0e-5, 2)));
+        let output_sum1 = output1.sum().into_scalar();
+        let output_sum2 = output2.sum().into_scalar();
+        let output_sum3 = output3.sum().into_scalar();
+
+        let expected_sum1 = -9.655_477; // from pytorch
+        let expected_sum2 = -8.053_822; // from pytorch
+        let expected_sum3 = 27.575_281; // from pytorch
+
+        assert!(expected_sum1.approx_eq(output_sum1, (1.0e-6, 2)));
+        assert!(expected_sum2.approx_eq(output_sum2, (1.0e-6, 2)));
+        assert!(expected_sum3.approx_eq(output_sum3, (1.0e-6, 2)));
     }
 
     #[test]
86 changes: 79 additions & 7 deletions burn-import/src/onnx/coalesce.rs
@@ -1,23 +1,32 @@
-use crate::onnx::ir::{ArgType, Data, TensorType};
+use std::{iter::Peekable, slice::IterMut};
 
 use super::ir::{AttributeValue, Node, NodeType};
+use crate::onnx::ir::{ArgType, Data, TensorType};
 
 /// The function transforms the graph so that eligible groups of nodes are coalesced into single nodes.
 pub fn coalesce(nodes: &mut Vec<Node>) {
-    for node in nodes.iter_mut() {
+    let mut iter_mut = nodes.iter_mut().peekable();
+    let mut nodes_to_remove: Vec<String> = vec![];
+    while let Some(node) = iter_mut.next() {
         match node.node_type {
-            NodeType::Gemm => convert_gemm(node),
-            // TODO Account when linear is converted into MatMul and Add nodes
+            NodeType::Gemm => convert_gemm_to_linear(node),
+            NodeType::MatMul => {
+                convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove);
+            }
             _ => {}
         }
     }
+
+    // Remove the nodes flagged for removal by the conversion functions
+    for node_to_remove in nodes_to_remove {
+        nodes.retain(|n| n.name != node_to_remove);
+    }
 }
 
 /// This function converts a Gemm node into a Linear node
 ///
-/// Warning: This function is not complete yet.
-/// It only supports the case where the Gemm node is a straight linear transformation.
-fn convert_gemm(node: &mut Node) {
+/// PyTorch and other frameworks use the Gemm node to represent a Linear layer.
+fn convert_gemm_to_linear(node: &mut Node) {
     if node.outputs.len() != 1 {
         panic!("Gemm node must have 1 output");
     }
@@ -101,3 +110,66 @@ fn transpose_flattened<T: Copy>(matrix: Vec<T>, rows: usize, cols: usize) -> Vec

     transposed
 }
+
+/// This function converts a MatMul node into a Linear node if possible.
+///
+/// PyTorch and other frameworks use the MatMul node to represent a Linear layer.
+///
+/// If possible, this function also folds the following Add node, which PyTorch
+/// uses to represent the bias, into the new Linear node.
+fn convert_matmul_to_linear(
+    node: &mut Node,
+    iter_mut: &mut Peekable<IterMut<Node>>,
+    nodes_to_remove: &mut Vec<String>,
+) {
+    if node.inputs.len() != 2 {
+        panic!("MatMul node must have 2 inputs");
+    }
+
+    // Do not convert if the second input does not have a value;
+    // treat it as a normal MatMul node instead
+    if node.inputs[1].value.is_none() {
+        return;
+    }
+
+    let weight = node.inputs[1]
+        .clone()
+        .into_tensor()
+        .expect("Tensor input is expected");
+
+    assert_eq!(weight.dim, 2, "Weight must be a 2D tensor");
+
+    // Convert the node to Linear
+    node.node_type = NodeType::Linear;
+
+    // The following block folds the next Add node (the bias in PyTorch) into this Linear node.
+    let peek_node = iter_mut.peek(); // Peek at the next node
+    if peek_node.is_some()
+        && peek_node.unwrap().node_type == NodeType::Add
+        && peek_node.unwrap().inputs.len() == 2
+        // Make sure the Add node has a constant value in one of its inputs and
+        // that its other input is the output of this MatMul node
+        && ((peek_node.unwrap().inputs[0].name == node.outputs[0].name
+            && peek_node.unwrap().inputs[1].value.is_some())
+            || (peek_node.unwrap().inputs[1].name == node.outputs[0].name
+                && peek_node.unwrap().inputs[0].value.is_some()))
[Inline review comment (Member), on the condition above: Convert the condition
into a function. You can early return if peek_node.is_none(), then call unwrap
only once.]
+    {
+        // Proceed with the iteration
+        let bias_node = iter_mut.next().unwrap();
+
+        // Copy the bias value from whichever input of the Add node carries it
+        if bias_node.inputs[0].value.is_some() {
+            node.inputs.push(bias_node.inputs[0].clone());
+        } else {
+            node.inputs.push(bias_node.inputs[1].clone());
+        }
+
+        // Rename the output of the MatMul node to the output of the Add node
+        node.outputs[0].name = bias_node.outputs[0].name.clone();
+
+        // Remove the Add node
+        nodes_to_remove.push(bias_node.name.clone());
+    };
+}
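One possible shape for the refactor the reviewer asks for, as a minimal sketch: is_foldable_bias is a hypothetical helper, not part of the merged diff, and the field accesses mirror the code above.

/// Returns true when `next` is an Add node that can be folded into the
/// preceding MatMul-turned-Linear node as its bias.
fn is_foldable_bias(next: &Node, matmul_output_name: &str) -> bool {
    // Only a two-input Add node qualifies.
    if next.node_type != NodeType::Add || next.inputs.len() != 2 {
        return false;
    }
    // One input must be this MatMul's output; the other must carry a constant value.
    (next.inputs[0].name == matmul_output_name && next.inputs[1].value.is_some())
        || (next.inputs[1].name == matmul_output_name && next.inputs[0].value.is_some())
}

The call site then becomes `if iter_mut.peek().map_or(false, |next| is_foldable_bias(next, &node.outputs[0].name)) { ... }`, leaving `iter_mut.next().unwrap()` as the only unwrap.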
17 changes: 7 additions & 10 deletions burn-import/src/onnx/to_burn.rs
@@ -537,22 +537,19 @@ impl ONNXGraph {
 /// * `node` - The node where the values are stored.
 #[track_caller]
 fn extract_data_serialize<E: Element>(input_index: usize, node: &Node) -> Option<DataSerialize<E>> {
-    if node.inputs.is_empty() || node.inputs.get(input_index).unwrap().value.is_none() {
+    if node.inputs.is_empty() {
         return None;
     }
 
-    let ty = node.inputs.get(input_index).unwrap().ty.clone();
+    let input = node.inputs.get(input_index);
+    input?;
+    let input = input.unwrap();
+    input.value.as_ref()?;
+    let ty = input.ty.clone();
 
     match ty {
         ArgType::Tensor(tensor_type) => {
-            let value = node
-                .inputs
-                .get(input_index)
-                .unwrap()
-                .value
-                .as_ref()
-                .expect("Value to be provided.")
-                .clone();
+            let value = input.value.as_ref().expect("Value to be provided.").clone();
 
             Some(serialize_data(
                 value.clone(),
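The `input?;` / `input.unwrap()` pair above can be collapsed by applying `?` directly to the Option returned by `get`. A minimal sketch of the same guards, assuming the function keeps its Option return type:

// Equivalent early exits with a single binding and no unwrap.
let input = node.inputs.get(input_index)?; // None when the index is out of range
input.value.as_ref()?;                     // None when no constant value is stored
let ty = input.ty.clone();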