diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 53ef658d2c..f5fba9eb1d 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -142,6 +142,8 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. | `tensor.any()` | `tensor.any()` | | `tensor.any_dim(dim)` | `tensor.any(dim)` | | `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | +| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | +| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | | `tensor.device()` | `tensor.device` | | `tensor.dims()` | `tensor.size()` | | `tensor.equal(other)` | `x == y` | @@ -195,7 +197,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.div(other)` or `tensor / other` | `tensor / other` | | `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` | | `tensor.equal_elem(other)` | `tensor.eq(other)` | -| `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value) | +| `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value) | | `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` | | `tensor.greater(other)` | `tensor.gt(other)` | | `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` | diff --git a/crates/burn-autodiff/src/ops/bool_tensor.rs b/crates/burn-autodiff/src/ops/bool_tensor.rs index cf241cf18c..1b40a1af93 100644 --- a/crates/burn-autodiff/src/ops/bool_tensor.rs +++ b/crates/burn-autodiff/src/ops/bool_tensor.rs @@ -88,6 +88,18 @@ impl BoolTensorOps for Autodiff { B::bool_chunk(tensor, chunks, dim) } + fn bool_split(tensor: BoolTensor, split_size: usize, dim: usize) -> Vec> { + B::bool_split(tensor, split_size, dim) + } + + fn bool_split_with_sizes( + tensor: BoolTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec> { + B::bool_split_with_sizes(tensor, split_sizes, dim) + } + fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { B::bool_permute(tensor, axes) } diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index c445633f3a..5f3c80199e 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -285,6 +285,22 @@ impl IntTensorOps for Autodiff { B::int_chunk(tensor, chunks, dim) } + fn int_split( + tensor: as Backend>::IntTensorPrimitive, + split_size: usize, + dim: usize, + ) -> Vec< as Backend>::IntTensorPrimitive> { + B::int_split(tensor, split_size, dim) + } + + fn int_split_with_sizes( + tensor: as Backend>::IntTensorPrimitive, + split_sizes: Vec, + dim: usize, + ) -> Vec< as Backend>::IntTensorPrimitive> { + B::int_split_with_sizes(tensor, split_sizes, dim) + } + fn int_random( shape: Shape, distribution: Distribution, diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 5a561fe714..19459731b4 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -416,6 +416,29 @@ impl TchOps { .collect() } + pub fn split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec { + tensor + .tensor + .split(split_size as i64, dim as i64) + .into_iter() + .map(TchTensor::new) + .collect() + } + + pub fn split_with_sizes( + tensor: TchTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec { + let split_sizes_i64: Vec = split_sizes.iter().map(|&s| s as i64).collect(); + tensor + .tensor + .split_with_sizes(split_sizes_i64, dim as i64) + .into_iter() + .map(TchTensor::new) + .collect() + } + pub fn powf(tensor: TchTensor, exponent: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( tensor, diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index afcace5349..b31ef57356 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ b/crates/burn-tch/src/ops/bool_tensor.rs @@ -93,6 +93,18 @@ impl BoolTensorOps for LibTorch { TchOps::chunk(tensor, chunks, dim) } + fn bool_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec { + TchOps::split(tensor, split_size, dim) + } + + fn bool_split_with_sizes( + tensor: TchTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec { + TchOps::split_with_sizes(tensor, split_sizes, dim) + } + fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor { TchOps::permute(tensor, axes) } diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 8c387a8937..5ddb9c21f2 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -349,6 +349,18 @@ impl IntTensorOps for LibTorch { TchOps::chunk(tensor, chunks, dim) } + fn int_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec { + TchOps::split(tensor, split_size, dim) + } + + fn int_split_with_sizes( + tensor: TchTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec { + TchOps::split_with_sizes(tensor, split_sizes, dim) + } + fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor { match distribution { Distribution::Default => { diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index 5cc9b9a65a..8f460fccaa 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -417,6 +417,18 @@ impl FloatTensorOps for LibTorch { TchOps::chunk(tensor, chunks, dim) } + fn float_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec { + TchOps::split(tensor, split_size, dim) + } + + fn float_split_with_sizes( + tensor: TchTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec { + TchOps::split_with_sizes(tensor, split_sizes, dim) + } + fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::powf(lhs, rhs) } diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index fd3706fe6a..b0267e08ff 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -1199,7 +1199,7 @@ where Self::new(narrow::(self.primitive, dim, start, length)) } - /// Attempts to split the tensor along the given dimension into chunks. + /// Attempts to split the tensor into a specified number of chunks along a given dimension. /// May return less chunks than requested if the tensor size is not divisible by the number of chunks. /// /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size. @@ -1247,6 +1247,91 @@ where .collect() } + /// Splits the tensor into chunks of a specified size along a given dimension. + /// Each chunk is a view of the original tensor. + /// + /// If the tensor size along the given dimension is not divisible by `split_size`, + /// then the last chunk will be smaller. + /// + /// # Panics + /// + /// If the specified dimension to split along is greater than the number of dimensions of the tensor. + /// + /// # Returns + /// + /// A vector of tensors. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 1D tensor with 5 elements + /// let tensor = Tensor::::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device); + /// // Split the tensor into chunks of size 2 along dimension 0 + /// let chunks = tensor.split(2, 0); + /// // The result is a vector of tensors: + /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])] + /// println!("{:?}", chunks); + /// } + /// ``` + pub fn split(self, split_size: usize, dim: usize) -> Vec { + check!(TensorCheck::split::( + self.shape().dims.as_ref(), + split_size, + dim + )); + K::split(self.primitive, split_size, dim) + .into_iter() + .map(Self::new) + .collect() + } + + /// Splits the tensor into chunks with the specified sizes along a given dimension. + /// Each chunk is a view of the original tensor. + /// + /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes + /// in `split_sizes` must equal the size of the tensor along the specified dimension. + /// + /// # Panics + /// + /// If the specified dimension to split along is greater than the number of dimensions of the tensor or + /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`. + /// + /// # Returns + /// + /// A vector of tensors. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example() { + /// let device = Default::default(); + /// // Create a 1D tensor with 5 elements + /// let tensor = Tensor::::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device); + /// // Split the tensor into chunks with sizes [2, 3] along dimension 0 + /// let chunks = tensor.split_with_sizes(vec![2, 3], 0); + /// // The result is a vector of tensors: + /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])] + /// println!("{:?}", chunks); + /// } + /// ``` + pub fn split_with_sizes(self, split_sizes: Vec, dim: usize) -> Vec { + check!(TensorCheck::split_with_sizes::( + self.shape().dims.as_ref(), + &split_sizes, + dim + )); + K::split_with_sizes(self.primitive, split_sizes, dim) + .into_iter() + .map(Self::new) + .collect() + } + /// Tests if any element in the `tensor` evaluates to True. /// /// # Arguments @@ -2111,10 +2196,58 @@ pub trait BasicOps: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// - /// To split a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function, + /// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function, /// which is more high-level and designed for public use. fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec; + /// Splits the tensor into chunks of a specified size along a given dimension. + /// Each chunk is a view of the original tensor. + /// + /// # Panics + /// + /// If the dimension to split along is greater than the number of dimensions of the tensor. + /// + /// # Returns + /// + /// A vector of tensors. + /// + /// # Remarks + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function, + /// which is more high-level and designed for public use. + fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec; + + /// Splits the tensor into chunks with the specified sizes along a given dimension. + /// Each chunk is a view of the original tensor. + /// + /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes + /// in `split_sizes` must equal the size of the tensor along the specified dimension. + /// + /// # Panics + /// + /// If the dimension to split along is greater than the number of dimensions of the tensor or + /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`. + /// + /// # Returns + /// + /// A vector of tensors. + /// + /// # Remarks + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function, + /// which is more high-level and designed for public use. + fn split_with_sizes( + tensor: Self::Primitive, + split_sizes: Vec, + dim: usize, + ) -> Vec; + /// Equates the given tensors. /// /// # Arguments @@ -2437,6 +2570,36 @@ impl BasicOps for Float { .collect(), } } + + fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec { + match tensor { + TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim) + .into_iter() + .map(TensorPrimitive::Float) + .collect(), + TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim) + .into_iter() + .map(TensorPrimitive::QFloat) + .collect(), + } + } + + fn split_with_sizes( + tensor: Self::Primitive, + split_sizes: Vec, + dim: usize, + ) -> Vec { + match tensor { + TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim) + .into_iter() + .map(TensorPrimitive::Float) + .collect(), + TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim) + .into_iter() + .map(TensorPrimitive::QFloat) + .collect(), + } + } } impl BasicOps for Int { @@ -2536,6 +2699,18 @@ impl BasicOps for Int { fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec { B::int_chunk(tensor, chunks, dim) } + + fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec { + B::int_split(tensor, split_size, dim) + } + + fn split_with_sizes( + tensor: Self::Primitive, + split_sizes: Vec, + dim: usize, + ) -> Vec { + B::int_split_with_sizes(tensor, split_sizes, dim) + } } impl BasicOps for Bool { @@ -2635,6 +2810,18 @@ impl BasicOps for Bool { fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec { B::bool_chunk(tensor, chunks, dim) } + + fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec { + B::bool_split(tensor, split_size, dim) + } + + fn split_with_sizes( + tensor: Self::Primitive, + split_sizes: Vec, + dim: usize, + ) -> Vec { + B::bool_split_with_sizes(tensor, split_sizes, dim) + } } /// Trait used for movedim arguments diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 0f02a0c060..d4ab13faf4 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1006,6 +1006,66 @@ impl TensorCheck { check } + pub(crate) fn split( + tensor_dims: &[usize], + split_size: usize, + dim: usize, + ) -> Self { + let mut check = Self::Ok; + let op = "split"; + let tensor_rank = tensor_dims.len(); + + if dim >= tensor_rank { + check = check.register( + op, + TensorError::new("Given dimension is greater than or equal to the tensor rank.") + .details(format!("Tensor rank: '{D}', given dimension: '{dim}'")), + ); + } else { + let tensor_size = tensor_dims[dim]; + if split_size == 0 && tensor_size != 0 { + check = check.register( + op, + TensorError::new("split_size must be greater than 0 unless the tensor size along the dimension is 0.") + .details(format!("split_size: '{split_size}', tensor size along dim '{dim}': '{tensor_size}'.")), + ); + } + } + + check + } + + pub(crate) fn split_with_sizes( + tensor_dims: &[usize], + split_sizes: &[usize], + dim: usize, + ) -> Self { + let mut check = Self::Ok; + let op = "split_with_sizes"; + let tensor_rank = tensor_dims.len(); + + if dim >= tensor_rank { + check = check.register( + op, + TensorError::new("Given dimension is greater than or equal to the tensor rank.") + .details(format!("Tensor rank: '{D}', given dimension: '{dim}'.")), + ); + } else { + // Validate split_sizes add up to size of dimension to split along + let tensor_size = tensor_dims[dim]; + let total_split_size: usize = split_sizes.iter().sum(); + if total_split_size != tensor_size { + check = check.register( + op, + TensorError::new("The sum of split_sizes must equal the tensor size along the specified dimension.") + .details(format!("Sum of split_sizes: '{total_split_size}', tensor size along dim '{dim}': '{tensor_size}'.")), + ); + } + } + + check + } + /// The goal is to minimize the cost of checks when there are no error, but it's way less /// important when an error occurred, crafting a comprehensive error message is more important /// than optimizing string manipulation. diff --git a/crates/burn-tensor/src/tensor/api/chunk.rs b/crates/burn-tensor/src/tensor/api/chunk.rs index 21063faa2a..c3735f8a57 100644 --- a/crates/burn-tensor/src/tensor/api/chunk.rs +++ b/crates/burn-tensor/src/tensor/api/chunk.rs @@ -7,16 +7,16 @@ use alloc::vec::Vec; /// # Arguments /// /// * `tensor` - The tensor. -/// * `chunks` - The number of chunks to be produced -/// * `times` - The dimension along which the tensor will be split. +/// * `chunks` - The number of chunks to be produced. +/// * `dim` - The dimension along which the tensor will be split. /// /// # Returns /// -/// A vectors of tensors +/// A vectors of tensors. /// /// # Remarks /// -/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. +/// This is a fallback solution that is used only when the backend doesn't have the corresponding implementation. /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index 60272d80bd..cdca97d4c7 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -12,6 +12,7 @@ mod kind; mod narrow; mod numeric; mod sort; +mod split; pub use argwhere::argwhere_data; pub use autodiff::*; @@ -22,3 +23,4 @@ pub use kind::*; pub use narrow::narrow; pub use numeric::*; pub use sort::{argsort, sort, sort_with_indices}; +pub use split::{split, split_with_sizes}; diff --git a/crates/burn-tensor/src/tensor/api/split.rs b/crates/burn-tensor/src/tensor/api/split.rs new file mode 100644 index 0000000000..b316faa334 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/split.rs @@ -0,0 +1,76 @@ +use super::narrow::narrow; +use crate::{backend::Backend, BasicOps, TensorKind}; +use alloc::vec::Vec; + +/// Splits the tensor along the given dimension into equally sized chunks (if possible) +/// with size `split_size`. Last chunk will be smaller if the tensor size along the given +/// dimension `dim` is not divisible by `split_size`. +/// +/// # Arguments +/// +/// * `tensor` - The tensor. +/// * `split_size` - The size of a single chunk. +/// * `dim` - The dimension along which to split the tensor. +/// +/// # Returns +/// +/// A vector of tensors. +/// +/// # Remarks +/// +/// This (and the following) are fallback solutions that is used only when the backend doesn't have the corresponding implementation. +/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved +/// by static dispatch. It is not designed for direct usage by users, and not recommended to import +/// or use this function directly. +pub fn split + BasicOps>( + tensor: K::Primitive, + split_size: usize, + dim: usize, +) -> Vec { + let size = K::shape(&tensor).dims[dim]; + let mut tensors = Vec::new(); + + let mut start = 0; + while start < size { + let length = usize::min(split_size, size - start); + tensors.push(narrow::(tensor.clone(), dim, start, length)); + start += length; + } + + tensors +} + +/// Splits the tensor along the given dimension into chunks with sizes in +/// `dim` according to `split_sizes`. +/// +/// # Arguments +/// +/// * `tensor` - The tensor. +/// * `split_sizes` - Vector of sizes for each chunk. +/// * `dim` - The dimension along which to split the tensor. +/// +/// # Returns +/// +/// A vector of tensors. +/// +/// # Remarks +/// +/// Fallback solution for backends with no equivalent functionality. +pub fn split_with_sizes + BasicOps>( + tensor: K::Primitive, + split_sizes: Vec, + dim: usize, +) -> Vec { + let mut tensors = Vec::new(); + + let mut start = 0; + for length in split_sizes { + if length == 0 { + continue; + } + tensors.push(narrow::(tensor.clone(), dim, start, length)); + start += length; + } + + tensors +} diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index a8c2c1772e..31ed6d1ede 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -3,8 +3,8 @@ use super::{ FloatTensor, IntTensor, }; use crate::{ - argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, - TensorData, + argwhere_data, backend::Backend, chunk, narrow, split, split_with_sizes, tensor::Shape, Bool, + ElementConversion, TensorData, }; use alloc::{vec, vec::Vec}; use core::{future::Future, ops::Range}; @@ -279,16 +279,51 @@ pub trait BoolTensorOps { /// # Arguments /// /// * `tensor` - The tensor. - /// * `chunks` - The number of chunks to be produced + /// * `chunks` - The number of chunks to be produced. /// * `times` - The dimension along which the tensor will be split. /// /// # Returns /// - /// A vector of tensors + /// A vector of tensors. fn bool_chunk(tensor: BoolTensor, chunks: usize, dim: usize) -> Vec> { chunk::(tensor, chunks, dim) } + /// Split the tensor along the given dimension into chunks of `split_size`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_size` - The size of a single chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn bool_split(tensor: BoolTensor, split_size: usize, dim: usize) -> Vec> { + split::(tensor, split_size, dim) + } + + /// Split the tensor along the given dimension into chunks with sizes in + /// `dim` according to `split_sizes`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_sizes` - Vector of sizes for each chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn bool_split_with_sizes( + tensor: BoolTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec> { + split_with_sizes::(tensor, split_sizes, dim) + } + /// Tests if any element in the boolean `tensor` evaluates to True. /// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 2df545dba4..f62c06b467 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -2,8 +2,8 @@ use super::cat::cat_with_slice_assign; use super::repeat_dim::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use crate::cast::ToElement; +use crate::tensor::api::{chunk, narrow, split, split_with_sizes}; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData}; -use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; use core::future::Future; use core::ops::Range; @@ -950,7 +950,7 @@ pub trait IntTensorOps { /// # Arguments /// /// * `tensor` - The tensor. - /// * `chunks` - The number of chunks to be produced + /// * `chunks` - The number of chunks to be produced. /// * `times` - The dimension along which the tensor will be split. /// /// # Returns @@ -960,6 +960,41 @@ pub trait IntTensorOps { chunk::(tensor, chunks, dim) } + /// Split the tensor along the given dimension into chunks of `split_size`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_size` - The size of a single chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn int_split(tensor: IntTensor, split_size: usize, dim: usize) -> Vec> { + split::(tensor, split_size, dim) + } + + /// Split the tensor along the given dimension into chunks with sizes in + /// `dim` according to `split_sizes`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_sizes` - Vector of sizes for each chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn int_split_with_sizes( + tensor: IntTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec> { + split_with_sizes::(tensor, split_sizes, dim) + } + /// Creates a new int tensor with random values. /// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 037b25e0b2..4b9df13a49 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -1198,12 +1198,12 @@ pub trait QTensorOps { /// # Arguments /// /// * `tensor` - The tensor. - /// * `chunks` - The number of chunks to be produced + /// * `chunks` - The number of chunks to be produced. /// * `times` - The dimension along which the tensor will be split. /// /// # Returns /// - /// A vector of tensors + /// A vector of tensors. fn q_chunk(tensor: QuantizedTensor, chunks: usize, dim: usize) -> Vec> { let scheme = *tensor.scheme(); @@ -1216,6 +1216,61 @@ pub trait QTensorOps { .collect() } + /// Split the tensor along the given dimension into chunks of `split_size`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_size` - The size of a single chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn q_split( + tensor: QuantizedTensor, + split_size: usize, + dim: usize, + ) -> Vec> { + let scheme = *tensor.scheme(); + + let tensor_f = Self::dequantize(tensor); + let out_f = B::float_split(tensor_f, split_size, dim); + + out_f + .into_iter() + .map(|tensor| Self::quantize_dynamic(tensor, &scheme)) + .collect() + } + + /// Split the tensor along the given dimension into chunks with sizes in + /// `dim` according to `split_sizes`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_sizes` - Vector of sizes for each chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn q_split_with_sizes( + tensor: QuantizedTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec> { + let scheme = *tensor.scheme(); + + let tensor_f = Self::dequantize(tensor); + let out_f = B::float_split_with_sizes(tensor_f, split_sizes, dim); + + out_f + .into_iter() + .map(|tensor| Self::quantize_dynamic(tensor, &scheme)) + .collect() + } + /// Tests if any element in the `tensor` evaluates to True. /// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 1226dc1c30..e5f551dd76 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -4,8 +4,10 @@ use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, In use crate::backend::BackendBridge; use crate::tensor::cast::ToElement; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData}; -use crate::{tensor::api::chunk, tensor::api::narrow}; -use crate::{FloatDType, TensorPrimitive}; +use crate::{ + tensor::api::chunk, tensor::api::narrow, tensor::api::split, tensor::api::split_with_sizes, + FloatDType, TensorPrimitive, +}; use alloc::vec::Vec; use core::future::Future; use core::ops::Range; @@ -1191,6 +1193,47 @@ pub trait FloatTensorOps { .collect() } + /// Split the tensor along the given dimension into chunks of `split_size`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_size` - The size of a single chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn float_split(tensor: FloatTensor, split_size: usize, dim: usize) -> Vec> { + split::(TensorPrimitive::Float(tensor), split_size, dim) + .into_iter() + .map(|t| t.tensor()) + .collect() + } + + /// Split the tensor along the given dimension into chunks with sizes in + /// `dim` according to `split_sizes`. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `split_sizes` - Vector of sizes for each chunk. + /// * `times` - The dimension along which the tensor will be split. + /// + /// # Returns + /// + /// A vector of tensors. + fn float_split_with_sizes( + tensor: FloatTensor, + split_sizes: Vec, + dim: usize, + ) -> Vec> { + split_with_sizes::(TensorPrimitive::Float(tensor), split_sizes, dim) + .into_iter() + .map(|t| t.tensor()) + .collect() + } + /// Tests if any element in the float `tensor` evaluates to True. /// /// # Arguments diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index ca36834597..a3f2c82b72 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -107,6 +107,7 @@ macro_rules! testgen_quantization { burn_tensor::testgen_q_sin!(); burn_tensor::testgen_q_slice!(); burn_tensor::testgen_q_sort_argsort!(); + // burn_tensor::testgen_q_split!(); burn_tensor::testgen_q_sqrt!(); burn_tensor::testgen_q_stack!(); burn_tensor::testgen_q_sub!(); @@ -220,6 +221,7 @@ macro_rules! testgen_with_float_param { burn_tensor::testgen_floor!(); burn_tensor::testgen_ceil!(); burn_tensor::testgen_select!(); + burn_tensor::testgen_split!(); burn_tensor::testgen_prod!(); // test stats diff --git a/crates/burn-tensor/src/tests/ops/chunk.rs b/crates/burn-tensor/src/tests/ops/chunk.rs index f0338915b0..660bba8ff1 100644 --- a/crates/burn-tensor/src/tests/ops/chunk.rs +++ b/crates/burn-tensor/src/tests/ops/chunk.rs @@ -4,6 +4,7 @@ mod tests { use alloc::vec::Vec; use burn_tensor::{Int, Shape, Tensor, TensorData}; + #[test] fn test_chunk_evenly_divisible() { let tensors: Vec> = Tensor::arange(0..12, &Default::default()).chunk(6, 0); diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index 240c3fb5ee..b1096e0216 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -57,6 +57,7 @@ mod sign; mod sin; mod slice; mod sort_argsort; +mod split; mod sqrt; mod squeeze; mod stack; diff --git a/crates/burn-tensor/src/tests/ops/split.rs b/crates/burn-tensor/src/tests/ops/split.rs new file mode 100644 index 0000000000..093b652c55 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/split.rs @@ -0,0 +1,210 @@ +#[burn_tensor_testgen::testgen(split)] +mod tests { + use super::*; + use alloc::vec::Vec; + use burn_tensor::{Int, Shape, Tensor, TensorData}; + + #[test] + fn test_split_evenly_divisible() { + let device = Default::default(); + let tensors = + TestTensor::<2>::from_data([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], &device); + + let split_tensors = tensors.split(2, 0); + assert_eq!(split_tensors.len(), 3); + + let expected = vec![ + TensorData::from([[0, 1], [2, 3]]), + TensorData::from([[4, 5], [6, 7]]), + TensorData::from([[8, 9], [10, 11]]), + ]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } + + #[test] + fn test_split_not_evenly_divisible() { + let device = Default::default(); + let tensors = TestTensor::<2>::from_data([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], &device); + + let split_tensors = tensors.split(2, 0); + assert_eq!(split_tensors.len(), 3); + + let expected = vec![ + TensorData::from([[0, 1], [2, 3]]), + TensorData::from([[4, 5], [6, 7]]), + TensorData::from([[8, 9]]), + ]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } + + #[test] + fn test_split_along_dim1() { + let device = Default::default(); + let tensors = TestTensor::<2>::from_data([[0, 1, 2], [3, 4, 5]], &device); + + let split_tensors = tensors.split(2, 1); + assert_eq!(split_tensors.len(), 2); + + let expected = vec![ + TensorData::from([[0, 1], [3, 4]]), + TensorData::from([[2], [5]]), + ]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } + + #[test] + fn test_split_split_size_larger_than_tensor_size() { + let device = Default::default(); + let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device); + + let split_tensors = tensors.split(10, 0); + assert_eq!(split_tensors.len(), 1); + + let expected = vec![TensorData::from([0, 1, 2, 3, 4])]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } + + #[test] + fn test_split_with_zero_split_size_zero_tensor_size() { + let device = Default::default(); + let empty_array: [i32; 0] = []; + let tensors = TestTensor::<1>::from_data(empty_array, &device); + + let split_tensors = tensors.split(0, 0); + assert_eq!(split_tensors.len(), 0); + } + + #[test] + fn test_split_zero_sized_tensor() { + let device = Default::default(); + let empty_array: [i32; 0] = []; + let tensors = TestTensor::<1>::from_data(empty_array, &device); + + let split_tensors = tensors.split(1, 0); + assert_eq!(split_tensors.len(), 0); + } + + #[test] + #[should_panic( + expected = "split_size must be greater than 0 unless the tensor size along the dimension is 0." + )] + fn test_split_with_zero_split_size_non_zero_tensor() { + let device = Default::default(); + let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device); + + let _split_tensors = tensors.split(0, 0); + } + + #[test] + #[should_panic(expected = "Given dimension is greater than or equal to the tensor rank.")] + fn test_split_invalid_dim() { + let device = Default::default(); + let tensors = TestTensor::<1>::from_data([0, 1, 2], &device); + + let _split_tensors = tensors.split(1, 2); + } + + #[test] + fn test_split_3d_tensor_along_dim0() { + let device = Default::default(); + let tensors = TestTensor::<3>::from_data( + [ + [[0, 1], [2, 3]], + [[4, 5], [6, 7]], + [[8, 9], [10, 11]], + [[12, 13], [14, 15]], + ], + &device, + ); + + let split_tensors = tensors.split(2, 0); + assert_eq!(split_tensors.len(), 2); + + let expected = vec![ + TensorData::from([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), + TensorData::from([[[8, 9], [10, 11]], [[12, 13], [14, 15]]]), + ]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } + + #[test] + fn test_split_3d_tensor_along_dim1() { + let device = Default::default(); + let tensors = TestTensor::<3>::from_data( + [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], + &device, + ); + + let split_tensors = tensors.split(2, 1); + assert_eq!(split_tensors.len(), 2); + + let expected = vec![ + TensorData::from([[[0, 1], [2, 3]], [[6, 7], [8, 9]]]), + TensorData::from([[[4, 5]], [[10, 11]]]), + ]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } + + #[test] + fn test_split_with_sizes() { + let device = Default::default(); + let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device); + + let split_tensors = tensors.split_with_sizes(vec![2, 3, 1], 0); + assert_eq!(split_tensors.len(), 3); + + let expected = vec![ + TensorData::from([0, 1]), + TensorData::from([2, 3, 4]), + TensorData::from([5]), + ]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } + + #[test] + #[should_panic( + expected = "The sum of split_sizes must equal the tensor size along the specified dimension." + )] + fn test_split_with_sizes_invalid_sum() { + let device = Default::default(); + let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device); + + let _split_tensors = tensors.split_with_sizes(vec![2, 2, 1], 0); + } + + #[test] + fn test_split_with_sizes_zero_length() { + let device = Default::default(); + let tensors = TestTensor::<1>::from_data([0, 1, 2], &device); + + let split_tensors = tensors.split_with_sizes(vec![0, 1, 2], 0); + assert_eq!(split_tensors.len(), 2); + + let expected = vec![TensorData::from([0]), TensorData::from([1, 2])]; + + for (index, tensor) in split_tensors.iter().enumerate() { + tensor.to_data().assert_eq(&expected[index], false); + } + } +}