Skip to content

Commit

Permalink
Feat/Split Operator (#2490)
Browse files Browse the repository at this point in the history
  • Loading branch information
agelas authored Nov 21, 2024
1 parent b4e8e45 commit d1398d6
Show file tree
Hide file tree
Showing 20 changed files with 813 additions and 17 deletions.
4 changes: 3 additions & 1 deletion burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
Expand Down Expand Up @@ -195,7 +197,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
| `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value)` |
| `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value)` |
| `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` |
| `tensor.greater(other)` | `tensor.gt(other)` |
| `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` |
Expand Down
12 changes: 12 additions & 0 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_chunk(tensor, chunks, dim)
}

// Forwards straight to the inner backend `B`: both the argument and the
// returned pieces are `BoolTensor<B>` values, so nothing is wrapped or
// unwrapped at this layer.
fn bool_split(tensor: BoolTensor<B>, split_size: usize, dim: usize) -> Vec<BoolTensor<B>> {
B::bool_split(tensor, split_size, dim)
}

// Forwards straight to the inner backend `B`, passing `split_sizes` and `dim`
// through unchanged; the pieces come back as `BoolTensor<B>` values.
fn bool_split_with_sizes(
tensor: BoolTensor<B>,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<BoolTensor<B>> {
B::bool_split_with_sizes(tensor, split_sizes, dim)
}

// Forwards the permutation to the inner backend `B`, handing over `tensor`
// and `axes` unchanged.
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
B::bool_permute(tensor, axes)
}
Expand Down
16 changes: 16 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_chunk(tensor, chunks, dim)
}

// Forwards to the inner backend's `int_split` with the arguments unchanged.
// NOTE(review): the primitive type is spelled `<Autodiff<B> as Backend>::IntTensorPrimitive`
// although this impl is for `Autodiff<B, C>`; presumably it resolves to the same
// type as `B`'s int primitive (the value is passed directly to `B::int_split`) —
// confirm, and consider a shorter alias if one is in scope.
fn int_split(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
split_size: usize,
dim: usize,
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive> {
B::int_split(tensor, split_size, dim)
}

// Forwards to the inner backend's `int_split_with_sizes` with the arguments unchanged.
// NOTE(review): the primitive type is spelled `<Autodiff<B> as Backend>::IntTensorPrimitive`
// although this impl is for `Autodiff<B, C>`; presumably it resolves to the same
// type as `B`'s int primitive (the value is passed directly to `B`) — confirm.
fn int_split_with_sizes(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive> {
B::int_split_with_sizes(tensor, split_sizes, dim)
}

fn int_random(
shape: Shape,
distribution: Distribution,
Expand Down
23 changes: 23 additions & 0 deletions crates/burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,29 @@ impl TchOps {
.collect()
}

/// Splits `tensor` along `dim` by calling `split` on the wrapped tensor.
///
/// `split_size` and `dim` are cast to `i64` for that call, and every piece it
/// returns is wrapped again with `TchTensor::new`.
pub fn split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
    let raw_parts = tensor.tensor.split(split_size as i64, dim as i64);
    raw_parts.into_iter().map(TchTensor::new).collect()
}

/// Splits `tensor` along `dim` by calling `split_with_sizes` on the wrapped tensor.
///
/// The requested sizes and `dim` are converted to `i64` for that call, and every
/// piece it returns is wrapped again with `TchTensor::new`.
pub fn split_with_sizes(
    tensor: TchTensor,
    split_sizes: Vec<usize>,
    dim: usize,
) -> Vec<TchTensor> {
    let sizes: Vec<i64> = split_sizes.into_iter().map(|size| size as i64).collect();
    let raw_parts = tensor.tensor.split_with_sizes(sizes, dim as i64);
    raw_parts.into_iter().map(TchTensor::new).collect()
}

pub fn powf(tensor: TchTensor, exponent: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
tensor,
Expand Down
12 changes: 12 additions & 0 deletions crates/burn-tch/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
TchOps::chunk(tensor, chunks, dim)
}

// Bool tensors use the shared `TchOps::split` helper; arguments are passed through unchanged.
fn bool_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
TchOps::split(tensor, split_size, dim)
}

// Bool tensors use the shared `TchOps::split_with_sizes` helper; arguments are
// passed through unchanged.
fn bool_split_with_sizes(
tensor: TchTensor,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<TchTensor> {
TchOps::split_with_sizes(tensor, split_sizes, dim)
}

// Bool tensors use the shared `TchOps::permute` helper; arguments are passed through unchanged.
fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
TchOps::permute(tensor, axes)
}
Expand Down
12 changes: 12 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,18 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
TchOps::chunk(tensor, chunks, dim)
}

// Int tensors use the shared `TchOps::split` helper; arguments are passed through unchanged.
fn int_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
TchOps::split(tensor, split_size, dim)
}

// Int tensors use the shared `TchOps::split_with_sizes` helper; arguments are
// passed through unchanged.
fn int_split_with_sizes(
tensor: TchTensor,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<TchTensor> {
TchOps::split_with_sizes(tensor, split_sizes, dim)
}

fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
match distribution {
Distribution::Default => {
Expand Down
12 changes: 12 additions & 0 deletions crates/burn-tch/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,18 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
TchOps::chunk(tensor, chunks, dim)
}

// Float tensors use the shared `TchOps::split` helper; arguments are passed through unchanged.
fn float_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
TchOps::split(tensor, split_size, dim)
}

// Float tensors use the shared `TchOps::split_with_sizes` helper; arguments are
// passed through unchanged.
fn float_split_with_sizes(
tensor: TchTensor,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<TchTensor> {
TchOps::split_with_sizes(tensor, split_sizes, dim)
}

// Element-wise power for float tensors, implemented by the shared `TchOps::powf`
// helper; `lhs` is passed as the base tensor and `rhs` as the exponent tensor.
fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::powf(lhs, rhs)
}
Expand Down
191 changes: 189 additions & 2 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,7 @@ where
Self::new(narrow::<B, K>(self.primitive, dim, start, length))
}

/// Attempts to split the tensor along the given dimension into chunks.
/// Attempts to split the tensor into a specified number of chunks along a given dimension.
/// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
///
/// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
Expand Down Expand Up @@ -1247,6 +1247,91 @@ where
.collect()
}

/// Divides the tensor along `dim` into consecutive chunks of `split_size` elements each.
/// Each chunk is a view of the original tensor.
///
/// When the size of the tensor along `dim` is not a multiple of `split_size`,
/// the last chunk is smaller than the others.
///
/// # Arguments
///
/// * `split_size` - The size of each chunk along `dim`.
/// * `dim` - The dimension along which to split.
///
/// # Panics
///
/// If the specified dimension to split along is greater than the number of dimensions of the tensor.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     // Create a 1D tensor with 5 elements
///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
///     // Split the tensor into chunks of size 2 along dimension 0
///     let chunks = tensor.split(2, 0);
///     // The result is a vector of tensors:
///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
///     println!("{:?}", chunks);
/// }
/// ```
pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
    // Validate the arguments against the current shape before touching the primitive.
    check!(TensorCheck::split::<D>(
        self.shape().dims.as_ref(),
        split_size,
        dim
    ));
    let primitives = K::split(self.primitive, split_size, dim);
    primitives.into_iter().map(Self::new).collect()
}

/// Divides the tensor along `dim` into chunks whose sizes are given by `split_sizes`.
/// Each chunk is a view of the original tensor.
///
/// The entries of `split_sizes` must add up to the size of the tensor along `dim`.
///
/// # Arguments
///
/// * `split_sizes` - The size of each chunk along `dim`, in order.
/// * `dim` - The dimension along which to split.
///
/// # Panics
///
/// If the specified dimension to split along is greater than the number of dimensions of the tensor or
/// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Example
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     // Create a 1D tensor with 5 elements
///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
///     // The result is a vector of tensors:
///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
///     println!("{:?}", chunks);
/// }
/// ```
pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
    // Validate the arguments against the current shape before touching the primitive.
    check!(TensorCheck::split_with_sizes::<D>(
        self.shape().dims.as_ref(),
        &split_sizes,
        dim
    ));
    let primitives = K::split_with_sizes(self.primitive, split_sizes, dim);
    primitives.into_iter().map(Self::new).collect()
}

/// Tests if any element in the `tensor` evaluates to True.
///
/// # Arguments
Expand Down Expand Up @@ -2111,10 +2196,58 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// To split a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
/// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
/// which is more high-level and designed for public use.
fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;

/// Splits the tensor into chunks of a specified size along a given dimension.
/// Each chunk is a view of the original tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor primitive to split.
/// * `split_size` - The size of each chunk along `dim`.
/// * `dim` - The dimension along which to split.
///
/// # Panics
///
/// If the dimension to split along is greater than the number of dimensions of the tensor.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Remarks
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
/// which is more high-level and designed for public use.
fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;

/// Splits the tensor into chunks with the specified sizes along a given dimension.
/// Each chunk is a view of the original tensor.
///
/// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
/// in `split_sizes` must equal the size of the tensor along the specified dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor primitive to split.
/// * `split_sizes` - The size of each chunk along `dim`.
/// * `dim` - The dimension along which to split.
///
/// # Panics
///
/// If the dimension to split along is greater than the number of dimensions of the tensor or
/// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
///
/// # Returns
///
/// A vector of tensors.
///
/// # Remarks
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
/// which is more high-level and designed for public use.
fn split_with_sizes(
tensor: Self::Primitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<Self::Primitive>;

/// Equates the given tensors.
///
/// # Arguments
Expand Down Expand Up @@ -2437,6 +2570,36 @@ impl<B: Backend> BasicOps<B> for Float {
.collect(),
}
}

// Dispatches on the float representation: regular tensors go through
// `float_split`, quantized tensors through `q_split`, and every resulting
// piece is wrapped back into the variant it came from.
fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
    match tensor {
        TensorPrimitive::Float(float) => {
            let parts = B::float_split(float, split_size, dim);
            parts.into_iter().map(TensorPrimitive::Float).collect()
        }
        TensorPrimitive::QFloat(quantized) => {
            let parts = B::q_split(quantized, split_size, dim);
            parts.into_iter().map(TensorPrimitive::QFloat).collect()
        }
    }
}

// Dispatches on the float representation: regular tensors go through
// `float_split_with_sizes`, quantized tensors through `q_split_with_sizes`,
// and every resulting piece is wrapped back into the variant it came from.
fn split_with_sizes(
    tensor: Self::Primitive,
    split_sizes: Vec<usize>,
    dim: usize,
) -> Vec<Self::Primitive> {
    match tensor {
        TensorPrimitive::Float(float) => {
            let parts = B::float_split_with_sizes(float, split_sizes, dim);
            parts.into_iter().map(TensorPrimitive::Float).collect()
        }
        TensorPrimitive::QFloat(quantized) => {
            let parts = B::q_split_with_sizes(quantized, split_sizes, dim);
            parts.into_iter().map(TensorPrimitive::QFloat).collect()
        }
    }
}
}

impl<B: Backend> BasicOps<B> for Int {
Expand Down Expand Up @@ -2536,6 +2699,18 @@ impl<B: Backend> BasicOps<B> for Int {
// Int kind: static dispatch to the backend's `int_chunk`, arguments unchanged.
fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
B::int_chunk(tensor, chunks, dim)
}

// Int kind: static dispatch to the backend's `int_split`, arguments unchanged.
fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
B::int_split(tensor, split_size, dim)
}

// Int kind: static dispatch to the backend's `int_split_with_sizes`, arguments unchanged.
fn split_with_sizes(
tensor: Self::Primitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<Self::Primitive> {
B::int_split_with_sizes(tensor, split_sizes, dim)
}
}

impl<B: Backend> BasicOps<B> for Bool {
Expand Down Expand Up @@ -2635,6 +2810,18 @@ impl<B: Backend> BasicOps<B> for Bool {
// Bool kind: static dispatch to the backend's `bool_chunk`, arguments unchanged.
fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
B::bool_chunk(tensor, chunks, dim)
}

// Bool kind: static dispatch to the backend's `bool_split`, arguments unchanged.
fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
B::bool_split(tensor, split_size, dim)
}

// Bool kind: static dispatch to the backend's `bool_split_with_sizes`, arguments unchanged.
fn split_with_sizes(
tensor: Self::Primitive,
split_sizes: Vec<usize>,
dim: usize,
) -> Vec<Self::Primitive> {
B::bool_split_with_sizes(tensor, split_sizes, dim)
}
}

/// Trait used for movedim arguments
Expand Down
Loading

0 comments on commit d1398d6

Please sign in to comment.