Skip to content

Commit

Permalink
feat: repeat (tracel-ai#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Oct 24, 2022
1 parent a78886d commit 0c4c657
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 2 deletions.
12 changes: 11 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::unary_ops_wrapper;
use crate::{
backend::{autodiff::ADBackendDecorator, Backend},
backend::{
autodiff::{ADBackendDecorator, ADTensor},
Backend,
},
graph::ops::{UnaryOps, UnaryOpsNodeState},
ops::TensorOps,
Data, Shape,
Expand Down Expand Up @@ -76,4 +79,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

unary_ops_wrapper(input, output, ops)
}

/// Create a tensor with uninitialized values.
///
/// Delegates allocation to the inner backend, then wraps the raw tensor in
/// an autodiff tensor so it can participate in the computation graph.
fn empty<const D: usize>(
    shape: Shape<D>,
    device: <ADBackendDecorator<B> as Backend>::Device,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
    let inner = B::empty(shape, device);
    ADTensor::from_tensor(inner)
}
}
7 changes: 7 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,11 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
) -> NdArrayTensor<E, D> {
tensor.clone()
}

/// Create an "empty" tensor of the given shape.
///
/// NOTE(review): this backend has no true uninitialized allocation here, so
/// the tensor is zero-filled; callers must not rely on the initial values.
fn empty<const D: usize>(
    shape: Shape<D>,
    device: <NdArrayBackend<E> as Backend>::Device,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
    <NdArrayBackend<E>>::zeros(shape, device)
}
}
17 changes: 16 additions & 1 deletion burn-tensor/src/tensor/backend/tch/tensor_ops.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{TchBackend, TchDevice, TchTensor};
use super::{TchBackend, TchDevice, TchKind, TchTensor};
use crate::{backend::Backend, ops::TensorOps, Data, Shape, TchElement};

impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
Expand Down Expand Up @@ -57,4 +57,19 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
shape: tensor.shape,
}
}

/// Allocate a tensor of the given shape on the given device without
/// initializing its values, via `tch::Tensor::empty`.
fn empty<const D: usize>(
    shape: Shape<D>,
    device: <TchBackend<E> as Backend>::Device,
) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
    let kind = TchKind::new();
    // tch expects dimensions as i64.
    let dims = shape.dims.map(|d| d as i64);
    let tensor = tch::Tensor::empty(&dims, (kind.kind(), device.into()));

    TchTensor { kind, tensor, shape }
}
}
9 changes: 9 additions & 0 deletions burn-tensor/src/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,15 @@ where
self.reshape(shape)
}

/// Repeat the tensor along the given dimension.
///
/// # Panics
///
/// If the selected dimension has more than one item.
pub fn repeat(&self, dim: usize, times: usize) -> Self {
    Self::new(B::repeat(&self.value, dim, times))
}

pub(crate) fn relu(&self) -> Self {
Self::new(self.value.relu())
}
Expand Down
29 changes: 29 additions & 0 deletions burn-tensor/src/tensor/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,35 @@ pub trait TensorOps<B: Backend> {
let data = Data::new(value, shape);
<B::IntegerBackend as Backend>::from_data(data, device)
}
/// Create a tensor of the given shape on the given device without
/// initializing its values.
fn empty<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D>;
/// Repeat the tensor `times` times along dimension `dim`.
///
/// The input must have exactly one item along `dim`; the output has
/// `times` items along that dimension, each a copy of the input.
///
/// # Panics
///
/// If the selected dimension has more than one item.
fn repeat<const D: usize>(
    tensor: &B::TensorPrimitive<D>,
    dim: usize,
    times: usize,
) -> B::TensorPrimitive<D> {
    let mut shape = *B::shape(tensor);
    if shape.dims[dim] != 1 {
        panic!("Can only repeat dimension with dim=1");
    }
    shape.dims[dim] = times;

    // Ranges selecting the full extent of every dimension of the OUTPUT
    // shape (built after `dims[dim]` is set to `times`).
    let indexes_select_all: [std::ops::Range<usize>; D] =
        std::array::from_fn(|i| 0..shape.dims[i]);

    let mut tensor_output = B::empty(shape, B::device(tensor));
    for i in 0..times {
        let mut indexes = indexes_select_all.clone();
        // Write the (size-1 along `dim`) input into the i-th slot.
        indexes[dim] = i..i + 1;
        tensor_output = tensor_output.index_assign(indexes, tensor);
    }

    tensor_output
}
}

pub trait TensorOpsAdd<E, const D: usize>: std::ops::Add<Self, Output = Self>
Expand Down
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod matmul;
mod mul;
mod neg;
mod powf;
mod repeat;
mod reshape;
mod sub;
mod transpose;
18 changes: 18 additions & 0 deletions burn-tensor/tests/tensor/ops/repeat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use super::super::TestBackend;
use burn_tensor::{Data, Tensor};

#[test]
fn should_support_repeat_ops() {
    let tensor = Tensor::<TestBackend, 2>::from_data(Data::from([[0.0, 1.0, 2.0]]));

    let data_actual = tensor.repeat(0, 4).into_data();

    // Repeating a single row 4 times along dim 0 yields 4 identical rows.
    let data_expected = Data::from([[0.0, 1.0, 2.0]; 4]);
    assert_eq!(data_expected, data_actual);
}

0 comments on commit 0c4c657

Please sign in to comment.