From d8e7aa6c54942bcd43f1549b709d245b5f6ca1bc Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Mon, 28 Aug 2023 07:40:45 -0500 Subject: [PATCH 1/2] Account for additional cases in Linear node conversion --- .../onnx-tests/tests/linear/linear.onnx | Bin 2415 -> 1244 bytes burn-import/onnx-tests/tests/linear/linear.py | 46 +++++++--- burn-import/onnx-tests/tests/onnx_tests.rs | 31 +++++-- burn-import/src/onnx/coalesce.rs | 86 ++++++++++++++++-- burn-import/src/onnx/to_burn.rs | 17 ++-- 5 files changed, 140 insertions(+), 40 deletions(-) diff --git a/burn-import/onnx-tests/tests/linear/linear.onnx b/burn-import/onnx-tests/tests/linear/linear.onnx index 91f9958e67934e8e06b00e380b8bf10c23ecac46..17ddcf7106ee4bc7be6e43fe0457e67893791713 100644 GIT binary patch literal 1244 zcma)+e@IhN6vy}G)aS+7dRZ+EU1kw0<~B9Nd?#I$u@W+@%&7IYmYY6x&Sx6^prjwm zkSr_0%0^^VCPCS6-)Ut*M8BwI6#bC;qaPGSzlezFEh}{z^vAuId+zsq&N=sRFHJMy zuH{0R+vbP~HR|*_LwZL9*+!AbveMG(^z>|dNr^>I)V#CQ&bbYiDyQJE6iP;2mEBqB z5Xi9pCsrp{!BA;2!;|`*B?Wj1iomQLL_#^<<=}?OWNRjSDM}kf6jr;?vsaQZ!OfLc znADV)k`ptD>dq;L6LKnfi-8R4yH8={27OHI@2JM^TuV5iq>@kS_DKlqUI`z>G6u6! znZ)0mH%2JMHXftWLCxiZVn>CLvjzj*=B@yz^ zpl{EU^Iy(w#qy~;(A#*VeVp6DYPV8!C>oNXn z5gxE@g0@xdBJcF$vZC!c@$O~td}?HWr0>Qz^IAY0Q;+_evr&^j6Y8$r!<6D|6ub>$ z-qRQ2+^_p^`lp3(-QA3yIb%WZ&tgBUHDbNH6K5UuW%!5ksQpIK;lBzC$|=>{MG>O_ zyApHR+V~{UEI7e#S$YVZ?_0ps84Y^}ufRobtuWQM7HSU0LSw88Dtt$vIk5rMV8`Q` zS@6AiJQ|rcit)41iB6i}X>O7d#sl=|LDAuvGMr`%e!u>~!@D=BSuJ{c|7funM#Hyin zlpDPrtHs*WDwwQ11siscfYx`}s43?&Zp`(<#`$fi8gIY|&jrl2*FZsB-%o*wB9S!3 z&_qrvq!Ch}33^n9UJALS2rWnqQWzPAs3b;Yz+V~xRk(Esq0)M}2H*;7G?5;6FKa*= QW>rhp~Oq9>>2&(s3qEh%-|=Nh(PavV7lf#~DJRjx%JZLP9E;PHHG;)ot2Tie$Fy z%DUD3XlGf{Ip_0c&B|oLJeu3otd*>}+nK9%m~`9y2I23d5q zzH>YRvzjj%BJ1p z*coKHxwyYzAG-S7fVS5}FnMMXw&)$m)P=AuHzxw8D|`+t-b;> z4_l!$P{3p<&O8^q9k}k<7g3oD3>iV~gswZ;pc-z2fpxq1 zj_xtm8!`g1$9GYc-jVo48aA9d##FN!n676naXssp>0SclTyrCtkHWyom!q+;$DpbE zA)uvRYY;;;Bwlkt-M(zJC~?77tDTrQtJIKc+DtOt27Ay4hg_>LlsjCt|Qx#A`>@*zxY9nN%3n<-edo)+*@rujZ5 zps8$#HQjB26&b12;`ffBdQKAhf78YM7Nw#8wkvR9S{kWMbI=$wY{)uxpV>D?qVl^a z(q8?RZ%K(|4qbjI96kyi=D&fg`%~ZobJ&`MIL19uP-LSB-Q=rKwQvd6mH1HN1wXzx z$A=1st;ow`Ci#qAfYu`t;&d1FRnZ>2xy3vjKM+L0%}aE%jIIMEL z0G(%dp&tnKLugDADxdq&N|v^ng1C_rtw=AQ05kcn%ZZ)G;%X~z2}VA`sR?b`T@8; z@u3f%^)cD0GB8?y1S8fmT+nldalPHI`38dhps24;hOwl>}$LNtrIn@ z*X|fAm!E>WSqfD9z6;K>5jIh8WO7k9X#>{7`HyU94W`g|ZX$LCx>48fLx!-APr$%l zGk!eCgFJqYqVnmH=+*27p~*@-f2Nk1ZkB<>J%c6A_cJISy+m~j>OnJb8e~t;>sumI zS$yX#vbS7{r6-W;n!91!!a8PU@q`VGE(e$2%}MslRuD_Vq1_Ite4jg{*xd$IO*_O7 za;)OcS}MLhi>>MT1U$NROdWp}4&}R}u68PEO6CN(HhvAo2dnsE+gd*B$}@gk7ES{r zRs5$ROEKl+RP4)fK#9L4=t7aEw{OM#+YHt1#TM3eJUu{iMwa7SOj47HRB@1KC-{025*5r*+{ ze`bog4w`$uV7m1;7PH*A zrJg}eZ%Wz44eL;b3wfE9FO^DbS^xMuOnNMYX(vDDt3TU@($%l?MdI1S8Q1dh1&?9E z^b`i3Y#}YboHZd*Va{!qxTgsuH+*SAL2xL19eTwa;*FK=Bz@yD>#NuZq4Qs*;q{Z? znvqTl#~f;*J(RAGX2q@^m>ro-7uuBAakGOBjh+N)`wCwCO)3pVDKYaKOVs4-!?M=| z7&g9u!p^;h(zz;L(=h^d!K#S->1u8C_#J(Sb$2LopoU;b2@I{o6 z7fSaZ|Cvdf2O#%gK05i@q4qC>2JT=37~S$H`D_C54}XR-ZU%WMego5GA4+-eDGR%! zpoCLXG4Vo|q0~GC(=%sd|GDWXZJffFMpm-)vm2PiQ3z+gdJ6sPmeHfzu{cnF*HCdF zmRk3VDC0mfs%K|ljdVL@9`!-fDh^F=JECi2Jf1aKlB)hRn7`c(LlLLg`M6r}h%RA? zp|jEbxIidoOy{Z@oJSl8e4A|&Ue2LO=CGs zt8Re7H(nv@!w#hWRX~x!TQTK}!;CX}>pE;q%y_H-t*7N;qHPd~3!|y`1vzoTN?q&z z{Vb|@J`H?$nMrmB!RTjGscMNWF{0mA91psNU-ex Kj#9aqSo}}8Swd6* diff --git a/burn-import/onnx-tests/tests/linear/linear.py b/burn-import/onnx-tests/tests/linear/linear.py index e72a98b0f0..5f0c1c23c2 100755 --- a/burn-import/onnx-tests/tests/linear/linear.py +++ b/burn-import/onnx-tests/tests/linear/linear.py @@ -10,14 +10,22 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - # When input's 2d with [1, n] shape and bias is True, Gemm ONNX node is used - self.fc1 = nn.Linear(16, 32, bias=True) + # Case 1: When input's 2d with [1, n] shape and bias is True. + # This will produce a single Gemm ONNX node with alpha=1 and beta=1 and transB=1 attributes. + # with 3 inputs and 1 output. The 3 inputs are the input, weight, and bias. + self.linear1_with_gemm = nn.Linear(3, 4, bias=True) - # TODO Test other cases that use matmul instead of Gemm + # Case 2: When input >= 2D but linear does not have bias. + self.linear2_with_matmul = nn.Linear(5, 6, bias=False) - def forward(self, x): - x = self.fc1(x) - return x + # Case 3: When input > 2D and linear does have bias or does not have bias (doesn't matter). + self.linear3_with_matmul = nn.Linear(7, 8, bias=True) + + def forward(self, x1, x2, x3): + y1 = self.linear1_with_gemm(x1) + y2 = self.linear2_with_matmul(x2) + y3 = self.linear3_with_matmul(x3) + return y1, y2, y3 def main(): @@ -33,20 +41,32 @@ def main(): print("Made model") file_name = "linear.onnx" - test_input = torch.full((1, 16), 3.141592, device=device) - torch.onnx.export(model, test_input, file_name, + input1 = torch.full((4, 3), 3.14, device=device) + input2 = torch.full((2, 5), 3.14, device=device) + input3 = torch.full((3, 2, 7), 3.14, device=device) + torch.onnx.export(model, (input1, input2, input3), file_name, verbose=False, opset_version=16) print("Finished exporting model to {}".format(file_name)) # Output some test data for use in the test - print("Test input data shape of ones: {}".format(test_input.shape)) - output = model.forward(test_input) - print("Test output data shape: {}".format(output.shape)) + print("Test input1 data shape: {}".format(input1.shape)) + print("Test input2 data shape: {}".format(input2.shape)) + print("Test input3 data shape: {}".format(input3.shape)) + + output1, output2, output3 = model.forward(input1, input2, input3) + + print("Test output1 data shape: {}".format(output1.shape)) + print("Test output2 data shape: {}".format(output2.shape)) + print("Test output3 data shape: {}".format(output3.shape)) - sum = output.sum().item() + sum1 = output1.sum().item() + sum2 = output2.sum().item() + sum3 = output3.sum().item() - print("Test output sum: {}".format(sum)) + print("Test output sum1: {}".format(sum1)) + print("Test output sum2: {}".format(sum2)) + print("Test output sum3: {}".format(sum3)) if __name__ == "__main__": diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs index a8c84de84c..8992141ec6 100644 --- a/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/burn-import/onnx-tests/tests/onnx_tests.rs @@ -507,22 +507,33 @@ mod tests { fn linear() { // Initialize the model with weights (loaded from the exported file) let model: linear::Model = linear::Model::default(); + let input1 = Tensor::::full([4, 3], 3.14); + let input2 = Tensor::::full([2, 5], 3.14); + let input3 = Tensor::::full([3, 2, 7], 3.14); - // Run the model with 3.1416 as input for easier testing - let input = Tensor::::full([1, 16], 3.1416); - - let output = model.forward(input); + let (output1, output2, output3) = model.forward(input1, input2, input3); // test the output shape - let expected_shape: Shape<2> = Shape::from([1, 32]); - assert_eq!(output.shape(), expected_shape); + let expected_shape1: Shape<2> = Shape::from([4, 4]); + let expected_shape2: Shape<2> = Shape::from([2, 6]); + let expected_shape3: Shape<3> = Shape::from([3, 2, 8]); + assert_eq!(output1.shape(), expected_shape1); + assert_eq!(output2.shape(), expected_shape2); + assert_eq!(output3.shape(), expected_shape3); // We are using the sum of the output tensor to test the correctness of the conv1d node // because the output tensor is too large to compare with the expected tensor. - let output_sum = output.sum().into_scalar(); - let expected_sum = -3.205_825; // from pytorch - println!("output_sum: {}", output_sum); - assert!(expected_sum.approx_eq(output_sum, (1.0e-5, 2))); + let output_sum1 = output1.sum().into_scalar(); + let output_sum2 = output2.sum().into_scalar(); + let output_sum3 = output3.sum().into_scalar(); + + let expected_sum1 = -9.655_477; // from pytorch + let expected_sum2 = -8.053_822; // from pytorch + let expected_sum3 = 27.575_281; // from pytorch + + assert!(expected_sum1.approx_eq(output_sum1, (1.0e-6, 2))); + assert!(expected_sum2.approx_eq(output_sum2, (1.0e-6, 2))); + assert!(expected_sum3.approx_eq(output_sum3, (1.0e-6, 2))); } #[test] diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index 26e208ac6b..fecf165cdf 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -1,23 +1,32 @@ -use crate::onnx::ir::{ArgType, Data, TensorType}; +use std::{iter::Peekable, slice::IterMut}; use super::ir::{AttributeValue, Node, NodeType}; +use crate::onnx::ir::{ArgType, Data, TensorType}; /// The function transforms the graph into a new one where the nodes are coalesced into a single node. pub fn coalesce(nodes: &mut Vec) { - for node in nodes.iter_mut() { + let mut iter_mut = nodes.iter_mut().peekable(); + let mut nodes_to_remove: Vec = vec![]; + while let Some(node) = iter_mut.next() { match node.node_type { - NodeType::Gemm => convert_gemm(node), - // TODO Account when linear is converted into MatMul and Add nodes + NodeType::Gemm => convert_gemm_to_linear(node), + NodeType::MatMul => { + convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove); + } _ => {} } } + + // Remove nodes instructed by conversation functions + for node_to_remove in nodes_to_remove { + nodes.retain(|n| n.name != node_to_remove); + } } /// This function converts a Gemm node into a Linear node /// -/// Warning: This function is not complete yet. -/// It only supports the case where the Gemm node is a straight linear transformation. -fn convert_gemm(node: &mut Node) { +/// PyTorch and other frameworks use Gemm node to represent Linear layer. +fn convert_gemm_to_linear(node: &mut Node) { if node.outputs.len() != 1 { panic!("Gemm node must have 1 output"); } @@ -101,3 +110,66 @@ fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec transposed } + +/// This function converts a MatMul node into a Linear node if possible. +/// +/// PyTorch and other frameworks use MatMul node to represent Linear layer. +/// +/// This function also converts the following Add node into a Linear node if possible. +/// Add node is used to represent bias in PyTorch. +fn convert_matmul_to_linear( + node: &mut Node, + iter_mut: &mut Peekable>, + nodes_to_remove: &mut Vec, +) { + if node.inputs.len() != 2 { + panic!("MatMul node must have 2 inputs"); + } + + // Do not convert if the second input does not have a value, and + // treat it as a normal MatMul node + if node.inputs[1].value.is_none() { + return; + } + + let weight = node.inputs[1] + .clone() + .into_tensor() + .expect("Tensor input is expected"); + + assert_eq!(weight.dim, 2, "Weight must be a 2D tensor"); + + // Convert the node to Linear + node.node_type = NodeType::Linear; + + // The following block of code is used to convert the following Add node into this Linear node + // Add node is used to represent bias in PyTorch. + let peek_node = iter_mut.peek(); // Peek the next node + if peek_node.is_some() + && peek_node.unwrap().node_type == NodeType::Add + && peek_node.unwrap().inputs.len() == 2 + + // Make sure the Add node has a value in one of its inputs and + // the other input is the output of this MatMul node + && (peek_node.unwrap().inputs[0].name == node.outputs[0].name + && peek_node.unwrap().inputs[1].value.is_some()) + | (peek_node.unwrap().inputs[1].name == node.outputs[0].name + && peek_node.unwrap().inputs[0].value.is_some()) + { + // Proceed iteration + let bias_node = iter_mut.next().unwrap(); + + // Copy input value from one of the inputs of the Add node + if bias_node.inputs[0].value.is_some() { + node.inputs.push(bias_node.inputs[0].clone()); + } else { + node.inputs.push(bias_node.inputs[1].clone()); + } + + // Rename the output of MatMul node to the output of Add node + node.outputs[0].name = bias_node.outputs[0].name.clone(); + + // Remove the Add node + nodes_to_remove.push(bias_node.name.clone()); + }; +} diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs index f2b4973dcb..54d895f824 100644 --- a/burn-import/src/onnx/to_burn.rs +++ b/burn-import/src/onnx/to_burn.rs @@ -537,22 +537,19 @@ impl ONNXGraph { /// * `node` - The node where value are stored. #[track_caller] fn extract_data_serialize(input_index: usize, node: &Node) -> Option> { - if node.inputs.is_empty() || node.inputs.get(input_index).unwrap().value.is_none() { + if node.inputs.is_empty() { return None; } - let ty = node.inputs.get(input_index).unwrap().ty.clone(); + let input = node.inputs.get(input_index); + input?; + let input = input.unwrap(); + input.value.as_ref()?; + let ty = input.ty.clone(); match ty { ArgType::Tensor(tensor_type) => { - let value = node - .inputs - .get(input_index) - .unwrap() - .value - .as_ref() - .expect("Value to be provided.") - .clone(); + let value = input.value.as_ref().expect("Value to be provided.").clone(); Some(serialize_data( value.clone(), From 2c540a2271557ffe687d32e38200a5fd686d2ba1 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Mon, 28 Aug 2023 12:50:30 -0500 Subject: [PATCH 2/2] Small refactor per PR feedback --- burn-import/src/onnx/coalesce.rs | 75 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index fecf165cdf..623d7584f2 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -126,50 +126,57 @@ fn convert_matmul_to_linear( panic!("MatMul node must have 2 inputs"); } - // Do not convert if the second input does not have a value, and - // treat it as a normal MatMul node + // if the second input does not have a value, it is not a weight, then proceed to the next node if node.inputs[1].value.is_none() { return; } - let weight = node.inputs[1] - .clone() - .into_tensor() - .expect("Tensor input is expected"); - - assert_eq!(weight.dim, 2, "Weight must be a 2D tensor"); + // Check if the second input is a 2D tensor + if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { + assert_eq!(tensor_type.dim, 2, "Weight must be a 2D tensor"); + } else { + panic!("Tensor input is expected"); + } // Convert the node to Linear node.node_type = NodeType::Linear; - // The following block of code is used to convert the following Add node into this Linear node - // Add node is used to represent bias in PyTorch. - let peek_node = iter_mut.peek(); // Peek the next node - if peek_node.is_some() - && peek_node.unwrap().node_type == NodeType::Add - && peek_node.unwrap().inputs.len() == 2 - - // Make sure the Add node has a value in one of its inputs and - // the other input is the output of this MatMul node - && (peek_node.unwrap().inputs[0].name == node.outputs[0].name - && peek_node.unwrap().inputs[1].value.is_some()) - | (peek_node.unwrap().inputs[1].name == node.outputs[0].name - && peek_node.unwrap().inputs[0].value.is_some()) - { - // Proceed iteration - let bias_node = iter_mut.next().unwrap(); - - // Copy input value from one of the inputs of the Add node - if bias_node.inputs[0].value.is_some() { - node.inputs.push(bias_node.inputs[0].clone()); - } else { - node.inputs.push(bias_node.inputs[1].clone()); + // Check the next node for potential conversion + if let Some(peek_node) = iter_mut.peek() { + if is_add_node_with_bias(peek_node, node) { + convert_and_remove_add_node(iter_mut, nodes_to_remove, node); } + } +} - // Rename the output of MatMul node to the output of Add node - node.outputs[0].name = bias_node.outputs[0].name.clone(); +/// Helper function to check if the peeked node is an Add node with bias +fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool { + peek_node.node_type == NodeType::Add + && peek_node.inputs.len() == 2 + && ((peek_node.inputs[0].name == current_node.outputs[0].name + && peek_node.inputs[1].value.is_some()) + || (peek_node.inputs[1].name == current_node.outputs[0].name + && peek_node.inputs[0].value.is_some())) +} - // Remove the Add node - nodes_to_remove.push(bias_node.name.clone()); +/// Helper function to convert and remove the Add node +fn convert_and_remove_add_node( + iter_mut: &mut Peekable>, + nodes_to_remove: &mut Vec, + current_node: &mut Node, +) { + let bias_node = iter_mut.next().unwrap(); + + let bias_input = if bias_node.inputs[0].value.is_some() { + bias_node.inputs[0].clone() + } else { + bias_node.inputs[1].clone() }; + + // Push the bias input and update the output name + current_node.inputs.push(bias_input); + current_node.outputs[0].name = bias_node.outputs[0].name.clone(); + + // Remove the Add node + nodes_to_remove.push(bias_node.name.clone()); }