Support count_include_pad attr in avg_pool2d ONNX (tracel-ai#978)
antimora authored Nov 21, 2023
1 parent cb616ed commit 445f41b
Showing 6 changed files with 58 additions and 22 deletions.
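For context, `count_include_pad` controls whether zero padding contributes to the averaging denominator. A minimal PyTorch sketch (not part of this commit) of the difference the flag captures:

```python
import torch
import torch.nn as nn

# Toy input: a 2x2 block of ones, pooled with a 2x2 kernel and padding of 1.
x = torch.ones(1, 1, 2, 2)

# Padding zeros are counted in the denominator: a corner window holds one real
# element over four window slots, so the corner outputs become 0.25.
pool_incl = nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=True)

# Padding zeros are excluded from the denominator: corner outputs stay 1.0.
pool_excl = nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False)

print(pool_incl(x))
print(pool_excl(x))
```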
2 changes: 1 addition & 1 deletion burn-core/src/nn/pool/avg_pool2d.rs
@@ -20,7 +20,7 @@ pub struct AvgPool2dConfig {
pub padding: PaddingConfig2d,
/// If the padding is counted in the denominator when computing the average.
#[config(default = "true")]
count_include_pad: bool,
pub count_include_pad: bool,
}

/// Applies a 2D avg pooling over input tensors.
Binary file modified burn-import/onnx-tests/tests/avg_pool2d/avg_pool2d.onnx
Binary file not shown.
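Since the regenerated ONNX binary cannot be displayed in the diff, here is a short sketch (assuming the `onnx` Python package is installed) for inspecting the AveragePool attributes it carries:

```python
import onnx
from onnx import helper

# Load the exported model and print the attributes of each AveragePool node
# (kernel_shape, strides, pads, count_include_pad), which burn-import now reads.
model = onnx.load("avg_pool2d.onnx")
for node in model.graph.node:
    if node.op_type == "AveragePool":
        for attr in node.attribute:
            print(attr.name, helper.get_attribute_value(attr))
```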
38 changes: 24 additions & 14 deletions burn-import/onnx-tests/tests/avg_pool2d/avg_pool2d.py
@@ -10,14 +10,20 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

# TODO when https://github.com/burn-rs/burn/issues/636 is resolved, test this with a model
# that uses `count_include_pad=False` and padding=(2, 1)
self.pool2d = nn.AvgPool2d((4, 2), stride=(
2, 1), padding=(0, 0), count_include_pad=False)
self.pool2d1 = nn.AvgPool2d((4, 2), stride=(
2, 1))

def forward(self, x):
x = self.pool2d(x)
return x
self.pool2d2 = nn.AvgPool2d((4, 2), stride=(
2, 1), padding=(2, 1), count_include_pad=True)

self.pool2d3 = nn.AvgPool2d((4, 2), stride=(
2, 1), padding=(2, 1), count_include_pad=False)

def forward(self, x1, x2, x3):
y1 = self.pool2d1(x1)
y2 = self.pool2d2(x2)
y3 = self.pool2d3(x3)
return y1, y2, y3


def main():
@@ -33,18 +39,22 @@ def main():
device = torch.device("cpu")

file_name = "avg_pool2d.onnx"
test_input = torch.randn(1, 1, 5, 5, device=device)
torch.onnx.export(model, test_input, file_name,
input1 = torch.randn(1, 1, 5, 5, device=device)
torch.onnx.export(model, (input1, input1, input1), file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
print("Test input data of ones: {}".format(test_input))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))
print("Test input data shape: {}".format(input1.shape))
print("Test input data: {}".format(input1))
output1, output2, output3 = model.forward(input1, input1, input1)
print("Test output1 data shape: {}".format(output1.shape))
print("Test output2 data shape: {}".format(output2.shape))
print("Test output3 data shape: {}".format(output3.shape))
print("Test output1: {}".format(output1))
print("Test output2: {}".format(output2))
print("Test output3: {}".format(output3))


if __name__ == '__main__':
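As a cross-check on the updated export script, a hedged sketch (assuming `onnxruntime` and `numpy` are available) that feeds the exported file and prints the three output shapes, which should come out as (1, 1, 1, 4) for the unpadded pool and (1, 1, 3, 6) for the two padded ones:

```python
import numpy as np
import onnxruntime as ort

# Run the exported model on a random 1x1x5x5 input and report each output shape.
session = ort.InferenceSession("avg_pool2d.onnx")
x = np.random.randn(1, 1, 5, 5).astype(np.float32)
feeds = {inp.name: x for inp in session.get_inputs()}
for i, out in enumerate(session.run(None, feeds), start=1):
    print(f"output{i} shape: {out.shape}")
```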
26 changes: 23 additions & 3 deletions burn-import/onnx-tests/tests/onnx_tests.rs
@@ -362,10 +362,30 @@ mod tests {
[-1.805, -0.476, 0.205, 0.338, 1.353],
[0.374, 0.013, 0.774, -0.109, -0.271],
]]]);
let output = model.forward(input);
let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]);
let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
let expected1 = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]);
let expected2 = Data::from([[[
[-0.045, 0.202, -0.050, -0.295, 0.162, 0.160],
[-0.176, 0.008, -0.131, -0.208, 0.425, 0.319],
[-0.084, -0.146, 0.017, 0.170, 0.216, 0.125],
]]]);
let expected3 = Data::from([[[
[-0.182, 0.404, -0.100, -0.590, 0.324, 0.638],
[-0.352, 0.008, -0.131, -0.208, 0.425, 0.638],
[-0.224, -0.195, 0.023, 0.226, 0.288, 0.335],
]]]);

let expected_shape1 = Shape::from([1, 1, 1, 4]);
let expected_shape2 = Shape::from([1, 1, 3, 6]);
let expected_shape3 = Shape::from([1, 1, 3, 6]);

assert_eq!(output1.shape(), expected_shape1);
assert_eq!(output2.shape(), expected_shape2);
assert_eq!(output3.shape(), expected_shape3);

output.to_data().assert_approx_eq(&expected, 3);
output1.to_data().assert_approx_eq(&expected1, 3);
output2.to_data().assert_approx_eq(&expected2, 3);
output3.to_data().assert_approx_eq(&expected3, 3);
}

#[test]
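The hard-coded `expected1`/`expected2`/`expected3` tensors come from running the same fixed input through PyTorch. A sketch of how they could be regenerated; the `dump_expected` helper and its rounding are illustrative, not part of the test suite:

```python
import torch
import torch.nn as nn

# Illustrative helper: run a fixed 1x1x5x5 input through the three pooling
# configurations used in the test and print values rounded to three decimals,
# matching assert_approx_eq(&expected, 3) on the Rust side.
def dump_expected(x: torch.Tensor) -> None:
    pools = [
        nn.AvgPool2d((4, 2), stride=(2, 1)),
        nn.AvgPool2d((4, 2), stride=(2, 1), padding=(2, 1), count_include_pad=True),
        nn.AvgPool2d((4, 2), stride=(2, 1), padding=(2, 1), count_include_pad=False),
    ]
    for i, pool in enumerate(pools, start=1):
        print(f"expected{i} =", torch.round(pool(x) * 1000) / 1000)
```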
3 changes: 3 additions & 0 deletions burn-import/src/burn/node/avg_pool2d.rs
@@ -51,6 +51,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool2dNode {
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.strides.to_tokens();
let padding = self.config.padding.to_tokens();
let count_include_pad = self.config.count_include_pad;

let init_line = quote! {
init();
@@ -60,6 +61,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool2dNode {
let #name = AvgPool2dConfig::new(#kernel_size)
.with_strides(#strides)
.with_padding(#padding)
.with_count_include_pad(#count_include_pad)
.#init_line
};

@@ -137,6 +139,7 @@ mod tests {
let avg_pool2d = AvgPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Valid)
.with_count_include_pad(true)
.init();

Self {
11 changes: 7 additions & 4 deletions burn-import/src/onnx/op_configuration.rs
@@ -129,26 +129,29 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
let mut strides = vec![1, 1];
let mut pads = vec![0, 0, 0, 0];
let mut count_include_pad: i64 = 0;
let mut ceil_mode: i64 = 0;

for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
"strides" => strides = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"count_include_pad" => count_include_pad = value.clone().into_i64(),
"ceil_mode" => ceil_mode = value.clone().into_i64(),
_ => {}
}
}

let padding = padding_config(&pads);

if count_include_pad == 1 && padding != PaddingConfig2d::Valid {
todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636");
if ceil_mode == 1 {
panic!("ceil_mode is not supported");
}

let padding = padding_config(&pads);

AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
.with_strides([strides[0] as usize, strides[1] as usize])
.with_padding(padding)
.with_count_include_pad(count_include_pad == 1)
}

/// Create a FlattenConfig from the attributes of the node
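For reference, a sketch (using the standard `onnx.helper` API, not code from this repository) of an AveragePool node carrying the attributes that `avg_pool2d_config` now reads; `count_include_pad=1` maps to `.with_count_include_pad(true)`, while `ceil_mode=1` triggers the new panic:

```python
from onnx import helper

# Minimal AveragePool node with the attributes avg_pool2d_config consumes.
node = helper.make_node(
    "AveragePool",
    inputs=["x"],
    outputs=["y"],
    kernel_shape=[4, 2],
    strides=[2, 1],
    pads=[2, 1, 2, 1],     # [top, left, bottom, right]
    count_include_pad=1,   # becomes .with_count_include_pad(true)
    ceil_mode=0,           # ceil_mode=1 is rejected with a panic
)
for attr in node.attribute:
    print(attr.name, helper.get_attribute_value(attr))
```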
