Repeat operation (tracel-ai#2090)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
mepatrick73 authored Aug 3, 2024
1 parent bb13729 commit f7639bd
Showing 40 changed files with 478 additions and 174 deletions.
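The API change itself ("changing repeat api to match pytorch equivalent" in the commit message above) is mechanical for callers: every old `repeat(dim, times)` call becomes `repeat_dim(dim, times)`, and `repeat` is reintroduced with one repetition count per dimension. A minimal before/after sketch, assuming a float tensor on an arbitrary backend (an editor's illustration, not code from this diff; the exact type of the `sizes` argument may vary between burn versions):

```rust
use burn::tensor::{backend::Backend, Tensor};

fn migrate<B: Backend>(device: &B::Device) {
    let x = Tensor::<B, 2>::ones([1, 8], device);

    // Before this commit: repeat took a single (dim, times) pair.
    // let y = x.clone().repeat(0, 4);      // shape [4, 8]

    // After this commit the same call is spelled repeat_dim:
    let y = x.clone().repeat_dim(0, 4);     // shape [4, 8]

    // and `repeat` now takes one repetition count per dimension,
    // mirroring PyTorch's Tensor.repeat.
    let z = x.repeat(&[4, 2]);              // shape [4, 16]
    let _ = (y, z);
}
```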
69 changes: 35 additions & 34 deletions burn-book/src/building-blocks/tensor.md
@@ -131,40 +131,41 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t

Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])`|
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |

### Numeric Operations

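To make the two new `repeat` rows in the table above concrete, here is a short sketch on `Int` and `Bool` tensors (the sentence above the table notes these ops exist for all kinds). This is an editor's illustration that assumes the API listed in the table; the corresponding right-hand-column PyTorch calls are shown as comments:

```rust
use burn::tensor::{backend::Backend, Bool, Int, Tensor};

fn repeat_all_kinds<B: Backend>(device: &B::Device) {
    let ints = Tensor::<B, 2, Int>::zeros([2, 3], device);
    let flags = Tensor::<B, 2, Bool>::empty([2, 3], device);

    // tensor.repeat_dim(dim, times)
    // PyTorch: tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])
    let ints_repeated = ints.repeat_dim(1, 2);    // [2, 3] -> [2, 6]

    // tensor.repeat(sizes)
    // PyTorch: tensor.repeat(sizes)
    let flags_repeated = flags.repeat(&[3, 1]);   // [2, 3] -> [6, 3]

    let _ = (ints_repeated, flags_repeated);
}
```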
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -132,11 +132,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_expand(tensor, shape)
}

fn bool_repeat<const D: usize>(
fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<B, D>,
dim: usize,
times: usize,
) -> BoolTensor<B, D> {
B::bool_repeat(tensor, dim, times)
B::bool_repeat_dim(tensor, dim, times)
}
}
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -162,12 +162,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_mean_dim(tensor, dim)
}

fn int_repeat<const D: usize>(
fn int_repeat_dim<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
times: usize,
) -> IntTensor<B, D> {
B::int_repeat(tensor, dim, times)
B::int_repeat_dim(tensor, dim, times)
}

fn int_greater<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> BoolTensor<B, D> {
10 changes: 6 additions & 4 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -2418,7 +2418,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
B::float_argsort(tensor.primitive, dim, descending)
}

fn float_repeat<const D: usize>(
fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
@@ -2437,7 +2437,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
let out = B::float_repeat(tensor, self.dim, self.times);
let out = B::float_repeat_dim(tensor, self.dim, self.times);
states.save(out_node, out)
}
}
@@ -2467,9 +2467,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.stateful()
{
OpsKind::Tracked(prep) => {
prep.finish(dim, B::float_repeat(tensor.primitive, dim, times))
prep.finish(dim, B::float_repeat_dim(tensor.primitive, dim, times))
}
OpsKind::UnTracked(prep) => {
prep.finish(B::float_repeat_dim(tensor.primitive, dim, times))
}
OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
}
}

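For intuition about what the tracked branch above must undo in the backward pass: `repeat_dim(dim, times)` concatenates `times` copies of the input along `dim`, so the gradient with respect to the input is the sum of the output gradient's `times` blocks along that dimension. A minimal sketch of that reduction using only public tensor ops (an editor's illustration, not the backward code burn-autodiff actually uses):

```rust
use burn::tensor::{backend::Backend, Tensor};

// Collapse the gradient of `x.repeat_dim(dim, times)` back to the shape of `x`
// by summing the `times` repeated blocks along `dim`.
fn repeat_dim_grad<B: Backend, const D: usize>(
    grad_out: Tensor<B, D>,
    dim: usize,
    times: usize,
) -> Tensor<B, D> {
    grad_out
        .chunk(times, dim)                  // split the repeated copies apart
        .into_iter()
        .reduce(|acc, block| acc + block)   // and accumulate them
        .expect("times is at least 1")
}
```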
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -47,7 +47,7 @@ mod permute;
mod pow;
mod recip;
mod relu;
mod repeat;
mod repeat_dim;
mod reshape;
mod select;
mod sigmoid;
@@ -133,6 +133,6 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sign!();
burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_sort!();
burn_autodiff::testgen_ad_repeat!();
burn_autodiff::testgen_ad_repeat_dim!();
};
}
@@ -1,4 +1,4 @@
#[burn_tensor_testgen::testgen(ad_repeat)]
#[burn_tensor_testgen::testgen(ad_repeat_dim)]
mod tests {
use super::*;
use burn_tensor::{activation, TensorData};
@@ -12,7 +12,7 @@ mod tests {
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();

let tensor_3 = tensor_2.clone().repeat(1, 3);
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);

let tensor_3 = tensor_1.matmul(tensor_3);
let grads = tensor_3.backward();
2 changes: 1 addition & 1 deletion crates/burn-candle/src/lib.rs
@@ -94,7 +94,7 @@ mod tests {
// burn_tensor::testgen_powf!();

burn_tensor::testgen_random!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!();
burn_tensor::testgen_sin!();
2 changes: 1 addition & 1 deletion crates/burn-core/src/nn/attention/mask.rs
@@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
}

mask = mask.repeat(0, batch_size);
mask = mask.repeat_dim(0, batch_size);

mask.equal_elem(1_i64.elem::<i64>())
}
6 changes: 3 additions & 3 deletions crates/burn-core/src/nn/loss/cross_entropy.rs
@@ -152,7 +152,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
* weights
.clone()
.reshape([1, nr_classes])
.repeat(0, batch_size);
.repeat_dim(0, batch_size);
let weights = weights.clone().gather(0, targets);
let tensor = Self::apply_mask_2d(tensor, mask);
tensor.sum().neg() / weights.sum()
@@ -224,7 +224,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
if let Some(mask) = mask {
let [batch_size, nr_classes] = tensor.dims();
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0);
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
}

tensor
@@ -312,7 +312,7 @@ mod tests {
* targets_logits
* Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
.unsqueeze()
.repeat(0, 4);
.repeat_dim(0, 4);
let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
}
4 changes: 2 additions & 2 deletions crates/burn-core/src/nn/rope_encoding.rs
@@ -58,7 +58,7 @@ impl RotaryEncodingConfig {
.float()
.unsqueeze()
.transpose()
.repeat(1, self.d_model / 2)
.repeat_dim(1, self.d_model / 2)
* theta_i.unsqueeze();

// Convert frequency values to complex numbers (polar form)
@@ -71,7 +71,7 @@ impl RotaryEncodingConfig {
.reshape([self.max_sequence_length, 2, self.d_model / 2])
.transpose()
.unsqueeze_dim::<4>(2)
.repeat(2, 2)
.repeat_dim(2, 2)
.reshape([self.max_sequence_length, self.d_model, 2]);

RotaryEncoding {
18 changes: 9 additions & 9 deletions crates/burn-fusion/src/ops/boolean.rs
@@ -14,7 +14,7 @@ use burn_tensor::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
HandleContainer, OperationDescription, PermuteOperationDescription,
RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
},
Device, Shape,
@@ -575,22 +575,22 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out
}

fn bool_repeat<const D: usize>(
fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<Self, D>,
dim: usize,
times: usize,
) -> BoolTensor<Self, D> {
#[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription,
struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatDimOperationDescription,
_b: PhantomData<B>,
}

impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_bool_tensor::<B, D>(&self.desc.tensor);

let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times);
let output = B::bool_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);

handles.register_bool_tensor::<B, D>(&self.desc.out.id, output);
}
@@ -601,16 +601,16 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
shape[dim] *= times;
let out = tensor.client.tensor_uninitialized(shape, DType::Bool);

let desc = RepeatOperationDescription {
let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(),
dim,
times,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())),
RepeatOps::<B, D>::new(desc),
OperationDescription::BaseBool(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatDimOps::<B, D>::new(desc),
);

out
16 changes: 8 additions & 8 deletions crates/burn-fusion/src/ops/float.rs
@@ -1696,22 +1696,22 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}

fn float_repeat<const D: usize>(
fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
) -> FloatTensor<Self, D> {
#[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription,
struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatDimOperationDescription,
_b: PhantomData<B>,
}

impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_float_tensor::<B, D>(&self.desc.tensor);

let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times);
let output = B::float_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);

handles.register_float_tensor::<B, D>(&self.desc.out.id, output);
}
@@ -1724,16 +1724,16 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(shape, B::FloatElem::dtype());

let desc = RepeatOperationDescription {
let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(),
dim,
times,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())),
RepeatOps::<B, D>::new(desc),
OperationDescription::BaseFloat(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatDimOps::<B, D>::new(desc),
);

out
16 changes: 8 additions & 8 deletions crates/burn-fusion/src/ops/int.rs
@@ -1755,22 +1755,22 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}

fn int_repeat<const D: usize>(
fn int_repeat_dim<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
times: usize,
) -> IntTensor<Self, D> {
#[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription,
struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatDimOperationDescription,
_b: PhantomData<B>,
}

impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_int_tensor::<B, D>(&self.desc.tensor);

let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times);
let output = B::int_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);

handles.register_int_tensor::<B, D>(&self.desc.out.id, output);
}
@@ -1783,16 +1783,16 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(shape, B::IntElem::dtype());

let desc = RepeatOperationDescription {
let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(),
dim,
times,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())),
RepeatOps::<B, D>::new(desc),
OperationDescription::BaseInt(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatDimOps::<B, D>::new(desc),
);

out