Skip to content

Commit

Permalink
Check nonzero stride, dilation and groups (#2540)
Browse files Browse the repository at this point in the history
* Check nonzero stride, dilation and groups

* Fix typos

* Fix another typo
  • Loading branch information
laggui authored Nov 27, 2024
1 parent a6a5c22 commit ec6d853
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 24 deletions.
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/on_write/trace_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl FuseOnWriteTraceBuilder {
let meta = 1;
let inputs = self.inputs.len() as u32;
let outputs = self.output_tensors().len() as u32;
// In the future, scalars could be packed into 1 bufer or into the metadata, but currently take up
// In the future, scalars could be packed into 1 buffer or into the metadata, but currently take up
// one slot per scalar.
let scalar = self.scalars.len() as u32;
meta + inputs + outputs + scalar
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2286,7 +2286,7 @@ where
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For performing the modul operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
/// For performing the modulo operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
/// which is more high-level and designed for public use.
fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

Expand Down
180 changes: 166 additions & 14 deletions crates/burn-tensor/src/tensor/ops/modules/base.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use core::num::NonZeroUsize;

use super::{conv, pool, unfold::unfold4d_using_conv2d};
use crate::{
backend::Backend,
Expand Down Expand Up @@ -84,45 +86,88 @@ pub struct MaxPool2dWithIndices<B: Backend> {
pub indices: IntTensor<B>,
}

/// Check that the parameter value is non-zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
NonZeroUsize::new(value).expect(msg);
value
}

/// Convolution options.
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
/// Stride.
/// Stride (non-zero).
pub stride: [usize; N],

/// Padding.
pub padding: [usize; N],

/// Dilation.
/// Dilation (non-zero).
pub dilation: [usize; N],

/// Groups.
/// Groups (non-zero).
pub groups: usize,
}

impl<const N: usize> ConvOptions<N> {
/// Constructs a new `ConvOptions`.
pub fn new(
stride: [usize; N],
padding: [usize; N],
dilation: [usize; N],
groups: usize,
) -> Self {
Self {
stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
padding,
dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
groups: check_nonzero(groups, "groups must be non-zero"),
}
}
}

/// Convolution options.
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
/// Stride.
/// Stride (non-zero).
pub stride: [usize; N],

/// Padding.
pub padding: [usize; N],

/// Dilation.
/// Dilation (non-zero).
pub dilation: [usize; N],

/// Weight Groups.
/// Weight Groups (non-zero).
pub weight_groups: usize,

/// Offset Groups.
/// Offset Groups (non-zero).
pub offset_groups: usize,
}

impl<const N: usize> DeformConvOptions<N> {
/// Constructs a new `DeformConvOptions`.
pub fn new(
stride: [usize; N],
padding: [usize; N],
dilation: [usize; N],
weight_groups: usize,
offset_groups: usize,
) -> Self {
Self {
stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
padding,
dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
}
}
}

/// Transposed convolution options.
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
/// Stride.
/// Stride (non-zero).
pub stride: [usize; N],

/// Padding.
Expand All @@ -131,15 +176,34 @@ pub struct ConvTransposeOptions<const N: usize> {
/// Padding out.
pub padding_out: [usize; N],

/// Dilation.
/// Dilation (non-zero).
pub dilation: [usize; N],

/// Groups.
/// Groups (non-zero).
pub groups: usize,
}

impl<const N: usize> ConvTransposeOptions<N> {
/// Constructs a new `ConvTransposeOptions`.
pub fn new(
stride: [usize; N],
padding: [usize; N],
padding_out: [usize; N],
dilation: [usize; N],
groups: usize,
) -> Self {
Self {
stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
padding,
padding_out,
dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
groups: check_nonzero(groups, "groups must be non-zero"),
}
}
}

/// Unfold operation options.
#[derive(new, Debug, Clone)]
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
/// The number of positions to slide over the input tensor in each dimension.
/// A stride of `[1, 1]` will slide the kernel one pixel at a time.
Expand All @@ -152,6 +216,17 @@ pub struct UnfoldOptions {
pub dilation: [usize; 2],
}

impl UnfoldOptions {
/// Constructs a new `UnfoldOptions`.
pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
Self {
stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
padding,
dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
}
}
}

/// Algorithm used for upsampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
Expand Down Expand Up @@ -690,3 +765,80 @@ pub trait ModuleOps<B: Backend> {
options: InterpolateOptions,
) -> FloatTensor<B>;
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
#[should_panic = "stride must be non-zero"]
fn conv_options_stride_zero() {
let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
}

#[test]
#[should_panic = "dilation must be non-zero"]
fn conv_options_dilation_zero() {
let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
}

#[test]
#[should_panic = "groups must be non-zero"]
fn conv_options_groups_zero() {
let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
}

#[test]
#[should_panic = "stride must be non-zero"]
fn conv_transpose_options_stride_zero() {
let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
}

#[test]
#[should_panic = "dilation must be non-zero"]
fn conv_transpose_options_dilation_zero() {
let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
}

#[test]
#[should_panic = "groups must be non-zero"]
fn conv_transpose_options_groups_zero() {
let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
}

#[test]
#[should_panic = "stride must be non-zero"]
fn deform_conv_options_stride_zero() {
let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
}

#[test]
#[should_panic = "dilation must be non-zero"]
fn deform_conv_options_dilation_zero() {
let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
}

#[test]
#[should_panic = "weight groups must be non-zero"]
fn deform_conv_options_weights_groups_zero() {
let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
}

#[test]
#[should_panic = "offset groups must be non-zero"]
fn deform_conv_options_offset_groups_zero() {
let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
}

#[test]
#[should_panic = "stride must be non-zero"]
fn unfold_options_stride_zero() {
let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
}

#[test]
#[should_panic = "dilation must be non-zero"]
fn unfold_options_dilation_zero() {
let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
}
}
7 changes: 1 addition & 6 deletions crates/burn-tensor/src/tensor/ops/modules/unfold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@ pub(crate) fn unfold4d_using_conv2d<B: Backend>(
x,
weight,
None,
ConvOptions {
stride: options.stride,
padding: options.padding,
dilation: options.dilation,
groups: 1,
},
ConvOptions::new(options.stride, options.padding, options.dilation, 1),
);

let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims();
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tensor/quantization/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub fn pack_i8s_to_u32s(bytes: &[u8]) -> Vec<u32> {
.collect()
}

/// Unpack 32-bit unsiged integer values into a sequence of signed 8-bit integers.
/// Unpack 32-bit unsigned integer values into a sequence of signed 8-bit integers.
///
/// # Note
/// This assumes that the bytes represent `u32` values.
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-train/src/metric/auroc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ mod tests {
let input = AurocInput::new(
Tensor::from_data(
[
[0.1, 0.9], // All positives perdictions
[0.1, 0.9], // All positives predictions
[0.2, 0.8],
[0.3, 0.7],
[0.4, 0.6],
Expand Down

0 comments on commit ec6d853

Please sign in to comment.