Skip to content

Commit

Permalink
Feat/squeeze dims (tracel-ai#1779)
Browse files Browse the repository at this point in the history
  • Loading branch information
agelas authored May 22, 2024
1 parent 76fe0ed commit 81ecd14
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 31 deletions.
6 changes: 3 additions & 3 deletions crates/burn-import/src/burn/node/squeeze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;

let axis = &self.axes.first().unwrap().to_tokens();
let axes_arg = &self.axes.to_tokens();

quote! {
let #output = #input.squeeze(#axis);
let #output = #input.squeeze_dims(&#axes_arg);
}
}

Expand Down Expand Up @@ -81,7 +81,7 @@ mod tests {
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 3>) -> Tensor<B, 2> {
let tensor2 = tensor1.squeeze(1);
let tensor2 = tensor1.squeeze_dims(&[1]);
tensor2
}
}
Expand Down
29 changes: 2 additions & 27 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ pub fn transpose_config(curr: &Node) -> Vec<i64> {
}

pub fn squeeze_config(curr: &Node) -> Vec<i64> {
let mut axes = curr
let axes = curr
.attrs
.iter()
.filter_map(|(key, value)| {
Expand All @@ -986,35 +986,10 @@ pub fn squeeze_config(curr: &Node) -> Vec<i64> {
.next()
.unwrap_or_else(Vec::new);

// If axes are not found in attributes, try to extract them from input tensor
if axes.is_empty() {
assert!(!curr.inputs.is_empty(), "Squeeze: input must be present");

let input_value = &curr.inputs[1];
match &input_value.ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Squeeze: axes tensor must be 1D");
if let Some(Data::Int64s(data)) = &input_value.value {
axes.clone_from(data)
} else {
panic!("Squeeze: Tensor data type must be int64");
}
}
_ => panic!("Squeeze: Argument for axes must be a tensor"),
}
}

let tensor = match curr.inputs.first().unwrap().clone().ty {
match curr.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Adjust negative axes
axes.iter_mut().for_each(|x| {
if *x < 0 {
*x += tensor.dim as i64;
}
});

axes
}
89 changes: 89 additions & 0 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,95 @@ where
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
}

/// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
/// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
/// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
/// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
/// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
/// from the back.
///
/// # Arguments
///
/// - `dims`: The dimension(s) to be squeezed.
///
/// # Type Parameters
///
/// - 'D2': The resulting number of dimensions in the squeezed tensor.
///
/// # Returns
///
/// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
///
/// # Example
///
/// ```rust
///
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
///
/// // Given a 4D tensor with dimensions (2, 1, 4, 1), squeeze the 1 and 3 dimensions
/// let squeezed_tensor: Tensor::<B, 2> = tensor.squeeze_dims(&[1, 3]);
///
/// // Resulting tensor will have dimensions (2, 4)
/// println!("{:?}", squeezed_tensor.shape());
/// }
/// ```
pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
let current_dims = self.shape().dims;
let mut dim_indices: Vec<usize>;

// Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
if dims.is_empty() {
dim_indices = current_dims
.iter()
.enumerate()
.filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
.collect();
} else {
// If negative dims, count from the back
dim_indices = dims
.iter()
.map(|&d| {
if d < 0 {
(current_dims.len() as isize + d) as usize
} else {
d as usize
}
})
.collect();
}

// Sort indices and remove duplicates
dim_indices.sort_unstable();
dim_indices.dedup();

// Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
check!(TensorCheck::squeeze_dims_input::<D2>(
&dim_indices,
&current_dims
));

// Calculate new dimensions
let mut new_dims = Vec::new();
for (index, &dim_size) in current_dims.iter().enumerate() {
// Exclude the dimension if it's explicitly marked for squeezing
if dim_indices.contains(&index) {
check!(TensorCheck::squeeze::<D2>(index, &current_dims));
continue;
}
new_dims.push(dim_size);
}

// Check that after squeezing, we still respect the D2 size
check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
}

/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
///
/// If the output size is higher than the current tensor.
Expand Down
32 changes: 31 additions & 1 deletion crates/burn-tensor/src/tensor/api/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,36 @@ impl TensorCheck {
check
}

pub(crate) fn squeeze_dims_input<const D2: usize>(
dim_indices: &[usize],
current_dims: &[usize],
) -> Self {
let mut check = Self::Ok;
if dim_indices.len() >= current_dims.len() {
check = check.register(
"Squeeze",
TensorError::new("Attempted to squeeze too many dimensions!"),
);
}

check
}

pub(crate) fn squeeze_dims_len<const D2: usize>(new_dims_len: usize) -> Self {
let mut check = Self::Ok;
if new_dims_len != D2 {
check = check.register(
"Squeeze",
TensorError::new(format!(
"Resulting dimensions {} do not match the required D2 size {}.",
new_dims_len, D2
)),
);
}

check
}

pub(crate) fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
let mut check = Self::Ok;
if D2 < D1 {
Expand Down Expand Up @@ -283,7 +313,7 @@ impl TensorCheck {
//contains is right exclusive, so this is to spec
if !(-output_rank..output_rank).contains(&dim) {
check = check.register(
"Unsqeeze",
"Unsqueeze",
TensorError::new(format!(
"unsqueeze arg {} is out of range for the output tensor of rank {}",
dim, output_rank
Expand Down
53 changes: 53 additions & 0 deletions crates/burn-tensor/src/tests/ops/squeeze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,59 @@ mod tests {
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(2);
}

/// Test if the function works with an empty slice
#[test]
fn should_squeeze_dims_with_empty_slice() {
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([1, 1, 3]), &Default::default());
let squeezed_tensor: Tensor<TestBackend, 1> = tensor.squeeze_dims(&[]);
let expected_shape = Shape::new([3]);
assert_eq!(squeezed_tensor.shape(), expected_shape);
}

/// Test if the function works with positive indices
#[test]
fn should_squeeze_dims_with_positive_indices() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([1, 3, 1, 5]), &Default::default());
let squeezed_tensor: Tensor<TestBackend, 2> = tensor.squeeze_dims(&[0, 2]);
let expected_shape = Shape::new([3, 5]);
assert_eq!(squeezed_tensor.shape(), expected_shape);
}

/// Test if the function works with negative indices
#[test]
fn should_squeeze_dims_with_negative_indices() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 1, 3, 1]), &Default::default());
let squeezed_tensor: Tensor<TestBackend, 2> = tensor.squeeze_dims(&[-3, -1]);
let expected_shape = Shape::new([2, 3]);
assert_eq!(squeezed_tensor.shape(), expected_shape);
}

/// Test to make sure the function panics if a non-singleton dimension is squeezed
#[test]
#[should_panic]
fn should_squeeze_dims_work_if_non_singleton() {
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 3, 4]), &Default::default());
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze_dims(&[1]);
let expected_shape = Shape::new([2, 3, 4]);
assert_eq!(squeezed_tensor.shape(), expected_shape);
}

/// Test to make sure the function panics if too many dimensions are requested to be squeezed
#[test]
#[should_panic]
fn should_squeeze_dims_panic_on_too_many_dimensions() {
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([1, 1, 1]), &Default::default());
let _: Tensor<TestBackend, 1> = tensor.squeeze_dims(&[0, 1, 2]);
}

/// Test to make sure function panics if dimensions are mismatched
#[test]
#[should_panic]
fn should_squeeze_dims_dimension_mismatch_panic() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([1, 3, 1, 5]), &Default::default());
let _: Tensor<TestBackend, 3> = tensor.squeeze_dims(&[0, 2]);
}

/// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor.
#[test]
fn should_unsqueeze_dim() {
Expand Down

0 comments on commit 81ecd14

Please sign in to comment.