Skip to content

Commit

Permalink
feature(tensor): Add unsqueeze_dim helper (tracel-ai#966)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcvz authored Nov 20, 2023
1 parent 20e9066 commit 49e16b6
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 0 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.device()` | `tensor.device` |
Expand Down
34 changes: 34 additions & 0 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,40 @@ where
self.reshape(shape)
}

/// Creates a new tensor with a dimension of size one inserted at the specified position.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
/// let tensor: Tensor<B, 3> = tensor.unsqueeze_dim(1);
/// println!("{:?}", tensor.shape());
/// // Shape { dims: [3, 1, 3] }
/// }
/// ```
pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
check!(TensorCheck::unsqueeze_dim::<{ D }>(dim));

let mut dims = [1; D2];
let shape = self.shape();

dims[0..dim].copy_from_slice(&shape.dims[0..dim]);

if dim < D {
dims[dim] = 1;
dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
} else {
dims[dim] = 1;
}

let shape = Shape::new(dims);
self.reshape(shape)
}

/// Returns a tensor containing the elements selected from the given ranges.
///
/// # Panics
Expand Down
15 changes: 15 additions & 0 deletions burn-tensor/src/tensor/api/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,21 @@ impl TensorCheck {
check
}

pub(crate) fn unsqueeze_dim<const D: usize>(dim: usize) -> Self {
let mut check = Self::Ok;
if dim > D {
check = check.register(
"Unsqueeze",
TensorError::new(format!(
"Can't unsqueeze at dimension {}, exceeds tensor dimensions (D={})",
dim, D
)),
);
}

check
}

pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
let mut check = Self::Ok;

Expand Down
35 changes: 35 additions & 0 deletions burn-tensor/src/tests/ops/squeeze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,39 @@ mod tests {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(2);
}

/// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor.
#[test]
fn should_unsqueeze_dim() {
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 4, 1]));
let unsqueezed_tensor: Tensor<TestBackend, 4> = tensor.unsqueeze_dim(1);
let expected_shape = Shape::new([2, 1, 4, 1]);
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
}

/// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor.
#[test]
fn should_unsqueeze_dim_first() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(0);
let expected_shape = Shape::new([1, 2, 3, 4, 5]);
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
}

/// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor.
#[test]
fn should_unsqueeze_dim_last() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([5, 4, 3, 2]));
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(4);
let expected_shape = Shape::new([5, 4, 3, 2, 1]);
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
}

/// Test if the function panics when the unsqueezed dimension is out of bounds.
#[test]
#[should_panic]
fn should_unsqueeze_dim_panic() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(5);
}
}

0 comments on commit 49e16b6

Please sign in to comment.