Refactor/reshape (tracel-ai#72)

nathanielsimard authored Nov 6, 2022
1 parent e654129 commit d369388
Showing 16 changed files with 114 additions and 173 deletions.
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -13,7 +13,6 @@ mod mask;
mod module;
mod pow;
mod precision;
mod reshape;
mod tensor;

mod macros;
103 changes: 0 additions & 103 deletions burn-tensor/src/tensor/backend/autodiff/ops/reshape.rs

This file was deleted.

47 changes: 46 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
@@ -5,7 +5,7 @@ use crate::{
Backend,
},
graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
ops::{Ones, TensorOps},
ops::{Ones, TensorOps, TensorOpsAggregation},
Data, Shape,
};

@@ -476,4 +476,49 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

unary_ops_wrapper(tensor.node.clone(), output, ops)
}

fn reshape<const D1: usize, const D2: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D1>,
shape: Shape<D2>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D2> {
#[derive(new, Debug)]
struct ReshapeBackward<B: Backend, const D1: usize, const D2: usize> {
shape: Shape<D1>,
_b: B,
}

impl<B: Backend, const D1: usize, const D2: usize>
UnaryOps<B::TensorPrimitive<D1>, B::TensorPrimitive<D2>>
for ReshapeBackward<B, D1, D2>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D1>, B::TensorPrimitive<D2>>,
) -> B::TensorPrimitive<D1> {
let mut grad = state.output.grad();
let value = state.output.value();

let shape_grad = *B::shape(&grad);
let shape_value = *B::shape(&value);

if shape_value == shape_grad {
return B::reshape(&grad, self.shape);
}

for i in 0..D2 {
if shape_value.dims[i] == 1 && shape_grad.dims[i] != 1 {
grad = grad.sum_dim(i);
}
}

B::reshape(&grad, self.shape)
}
}

let shape_old = B::shape(tensor.tensor_ref());
let output = B::reshape(tensor.tensor_ref(), shape);
let ops = ReshapeBackward::<B, D1, D2>::new(*shape_old, B::default());

unary_ops_wrapper(tensor.node.clone(), output, ops)
}
}
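The interesting part of the new autodiff implementation is the partial gradient: if a later operation broadcast the reshaped output, the incoming gradient can be larger than the output value, so every dimension where the value has size 1 but the gradient does not is summed before reshaping back. A minimal stand-alone sketch of that reduction in plain Rust (not the burn API), assuming a [2, 3] tensor reshaped to [1, 6] whose gradient arrives broadcast to [4, 6]:

// Hypothetical walk-through of ReshapeBackward::partial, outside burn:
// value shape is [1, 6], incoming grad shape is [4, 6], original shape is [2, 3].
fn main() {
    let grad: Vec<Vec<f32>> = vec![vec![1.0; 6]; 4]; // gradient broadcast to [4, 6]
    let value_shape = [1usize, 6];                   // shape of the reshaped output

    // Dim 0 was broadcast (value dim == 1, grad dim != 1), so sum over it.
    let mut summed = vec![0.0f32; 6];
    if value_shape[0] == 1 && grad.len() != 1 {
        for row in &grad {
            for (acc, g) in summed.iter_mut().zip(row) {
                *acc += *g;
            }
        }
    }

    // Reshape the reduced [1, 6] gradient back to the original [2, 3] shape.
    let grad_input: Vec<Vec<f32>> = summed.chunks(3).map(|c| c.to_vec()).collect();
    assert_eq!(grad_input, vec![vec![4.0; 3], vec![4.0; 3]]);
}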
3 changes: 1 addition & 2 deletions burn-tensor/src/tensor/backend/base.rs
@@ -1,6 +1,6 @@
use crate::ops::activation::*;
use crate::ops::*;
use crate::tensor::ops::{TensorOpsIndex, TensorOpsReshape};
use crate::tensor::ops::TensorOpsIndex;
use crate::tensor::Element;
use crate::tensor::{Data, Distribution, Shape};
use crate::Gradients;
@@ -25,7 +25,6 @@ pub trait Backend:
+ TensorOpsDetach<Self::Elem, D>
+ Zeros<Self::TensorPrimitive<D>>
+ Ones<Self::TensorPrimitive<D>>
+ TensorOpsReshape<Self, D>
+ TensorOpsPrecision<Self, D>
+ TensorOpsIndex<Self::Elem, D>
+ TensorOpsAggregation<Self, D>
20 changes: 10 additions & 10 deletions burn-tensor/src/tensor/backend/ndarray/module_ops.rs
@@ -13,16 +13,15 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {

let mut tensors = Vec::with_capacity(batch_size * seq_length);

for index in indexes
.reshape(Shape::new([batch_size * seq_length]))
for index in NdArrayBackend::reshape(indexes, Shape::new([batch_size * seq_length]))
.array
.iter()
{
let index = *index as usize;
tensors.push(weights.index([index..index + 1, 0..d_model]));
}
let embedding = TensorOpsCat::cat(tensors.iter().collect(), 0);
embedding.reshape(Shape::new([batch_size, seq_length, d_model]))
NdArrayBackend::reshape(&embedding, Shape::new([batch_size, seq_length, d_model]))
}

fn embedding_backward(
@@ -34,13 +33,14 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
let [_n_embedding, d_model] = weights.shape.dims;

let mut weights_grad = weights.zeros();
let output = output.reshape(Shape::new([batch_size * seq_length, d_model]));

for (index_output, index) in indexes
.reshape(Shape::new([batch_size * seq_length]))
.array
.iter()
.enumerate()
let output =
NdArrayBackend::reshape(output, Shape::new([batch_size * seq_length, d_model]));

for (index_output, index) in
NdArrayBackend::reshape(indexes, Shape::new([batch_size * seq_length]))
.array
.iter()
.enumerate()
{
let index = *index as usize;

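For readability, here is the shape bookkeeping the embedding forward pass above performs, sketched with plain Rust vectors rather than NdArrayTensor (the sizes are illustrative, not from the commit): indexes of shape [batch_size, seq_length] are flattened, each index gathers one [1, d_model] row from weights, the rows are concatenated, and the final reshape restores [batch_size, seq_length, d_model].

fn main() {
    let (batch_size, seq_length, d_model) = (2usize, 3, 4);
    let weights = vec![vec![0.5f32; d_model]; 10];          // [n_embedding, d_model]
    let indexes = vec![vec![1usize, 4, 7], vec![0, 2, 9]];  // [batch_size, seq_length]

    // Flatten indexes to [batch_size * seq_length], gather one row per index.
    let flat: Vec<usize> = indexes.into_iter().flatten().collect();
    let gathered: Vec<Vec<f32>> = flat.iter().map(|&i| weights[i].clone()).collect(); // [6, 4]

    // "Reshape" back to [batch_size, seq_length, d_model] by regrouping rows.
    let embedding: Vec<Vec<Vec<f32>>> =
        gathered.chunks(seq_length).map(|c| c.to_vec()).collect();
    assert_eq!(
        (embedding.len(), embedding[0].len(), embedding[0][0].len()),
        (batch_size, seq_length, d_model)
    );
}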
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/backend/ndarray/ops/aggregation.rs
@@ -18,7 +18,7 @@ macro_rules! keepdim {
let tensor: NdArrayTensor<E, $D> = mean_dim(&$self, $dim);
let mut shape = $self.shape.clone();
shape.dims[$dim] = 1;
tensor.reshape(shape)
NdArrayBackend::reshape(&tensor, shape)
}};
(
$D:expr,
@@ -29,7 +29,7 @@
let tensor: NdArrayTensor<E, $D> = sum_dim(&$self, $dim);
let mut shape = $self.shape.clone();
shape.dims[$dim] = 1;
tensor.reshape(shape)
NdArrayBackend::reshape(&tensor, shape)
}};
}

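The two keepdim! arms differ only in the reduction (mean_dim vs. sum_dim); both collapse the reduced axis and then use the backend reshape to put it back with size 1 so the rank is preserved. A concrete illustration in plain Rust (not the macro itself), summing over dim 1 of a [2, 3] tensor:

fn main() {
    let x = [[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]; // shape [2, 3]

    // sum_dim collapses dim 1 ...
    let summed: Vec<f32> = x.iter().map(|row| row.iter().sum()).collect(); // shape [2]

    // ... and the reshape restores a size-1 axis, giving shape [2, 1].
    let kept: Vec<[f32; 1]> = summed.iter().map(|&v| [v]).collect();
    assert_eq!(kept, vec![[6.0], [15.0]]);
}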
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -11,4 +11,3 @@ mod map_comparison;
mod mask;
mod pow;
mod precision;
mod reshape;
26 changes: 0 additions & 26 deletions burn-tensor/src/tensor/backend/ndarray/ops/reshape.rs

This file was deleted.

18 changes: 17 additions & 1 deletion burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
@@ -2,8 +2,9 @@ use super::{BatchMatrix, NdArrayBackend, NdArrayTensor};
use crate::{
backend::{Backend, NdArrayDevice},
ops::TensorOps,
Data, ElementConversion, NdArrayElement, Shape,
to_nd_array_tensor, Data, ElementConversion, NdArrayElement, Shape,
};
use ndarray::Dim;

impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn shape<const D: usize>(
@@ -195,4 +196,19 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {

NdArrayTensor { array, shape }
}

fn reshape<const D1: usize, const D2: usize>(
tensor: &NdArrayTensor<E, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<E, D2> {
match D2 {
1 => to_nd_array_tensor!(1, shape, tensor.array),
2 => to_nd_array_tensor!(2, shape, tensor.array),
3 => to_nd_array_tensor!(3, shape, tensor.array),
4 => to_nd_array_tensor!(4, shape, tensor.array),
5 => to_nd_array_tensor!(5, shape, tensor.array),
6 => to_nd_array_tensor!(6, shape, tensor.array),
_ => panic!("NdArrayTensor support only 6 dimensions."),
}
}
}
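The match over D2 exists because ndarray encodes the rank in the array's type, so the target dimension type must be chosen per arm; the to_nd_array_tensor! macro presumably performs that per-rank conversion. A rough, hedged sketch of the underlying ndarray operation (plain ndarray outside burn, assuming ndarray as a dependency):

use ndarray::{ArrayD, IxDyn};

fn main() {
    // A dynamically-ranked array of 24 elements ...
    let data: ArrayD<f32> = ArrayD::zeros(IxDyn(&[2, 3, 4]));

    // ... reshaped into a statically-ranked 2-D array; the rank is fixed at
    // compile time by the dimension type, hence the match over D2 above.
    let reshaped = data
        .into_shape((4, 6))
        .expect("element count must match");
    assert_eq!(reshaped.ndim(), 2);
}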
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/tch/ops/mod.rs
@@ -11,4 +11,3 @@ mod map_comparison;
mod mask;
mod pow;
mod precision;
mod reshape;
19 changes: 0 additions & 19 deletions burn-tensor/src/tensor/backend/tch/ops/reshape.rs

This file was deleted.

12 changes: 11 additions & 1 deletion burn-tensor/src/tensor/backend/tch/tensor_ops.rs
@@ -1,4 +1,4 @@
use super::{TchBackend, TchDevice, TchKind, TchTensor};
use super::{TchBackend, TchDevice, TchKind, TchShape, TchTensor};
use crate::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape, TchElement};
use std::ops::{Add, Div, Mul, Sub};

@@ -136,6 +136,16 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
to_tensor(tensor)
}

fn reshape<const D1: usize, const D2: usize>(
tensor: &TchTensor<E, D1>,
shape: Shape<D2>,
) -> TchTensor<E, D2> {
let shape_tch: TchShape<D2> = shape.into();
let tensor = tensor.tensor.reshape(&shape_tch.dims);

to_tensor(tensor)
}
}

fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {
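On the tch backend the work is delegated to LibTorch: TchShape presumably just converts Shape<D2> into the i64 sizes that tch expects. A hedged sketch of the underlying tch-rs call, outside burn (requires a LibTorch installation to actually run):

use tch::Tensor;

fn main() {
    let t = Tensor::of_slice(&[1f32, 2.0, 3.0, 4.0, 5.0, 6.0]); // shape [6]
    let r = t.reshape(&[2, 3]);                                  // shape [2, 3]
    assert_eq!(r.size(), vec![2, 3]);
}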
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/base.rs
@@ -44,7 +44,7 @@
///
/// If the tensor can not be reshape to the given shape.
pub fn reshape<const D2: usize, S: Into<Shape<D2>>>(&self, shape: S) -> Tensor<B, D2> {
Tensor::new(self.value.reshape(shape.into()))
Tensor::new(B::reshape(&self.value, shape.into()))
}

/// Returns a new tensor on the given device.
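The user-facing signature is unchanged; only the dispatch moved from a method on the primitive to the backend's TensorOps. A hedged usage sketch relying only on items visible in this diff (Tensor, Backend, Shape); per the doc comment above, it panics if len does not match the element count:

// Flatten any 2-D tensor; D2 = 1 is inferred from the Shape<1> argument.
fn flatten_2d<B: Backend>(x: Tensor<B, 2>, len: usize) -> Tensor<B, 1> {
    x.reshape(Shape::new([len]))
}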
8 changes: 4 additions & 4 deletions burn-tensor/src/tensor/ops/base.rs
@@ -111,10 +111,10 @@ pub trait TensorOps<B: Backend> {
dim1: usize,
dim2: usize,
) -> B::TensorPrimitive<D>;
}

pub trait TensorOpsReshape<B: Backend, const D: usize> {
fn reshape<const D2: usize>(&self, shape: Shape<D2>) -> B::TensorPrimitive<D2>;
fn reshape<const D1: usize, const D2: usize>(
tensor: &B::TensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::TensorPrimitive<D2>;
}

pub trait TensorOpsIndex<E, const D1: usize> {
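This signature change is the heart of the refactor: reshape used to be a &self method on a dedicated TensorOpsReshape<B, D> trait implemented per primitive, and is now an associated function on the backend's TensorOps trait, which is why every call site in this diff switches from tensor.reshape(shape) to B::reshape(&tensor, shape). A hedged sketch of what that means for code generic over a backend (the helper name is illustrative):

fn flatten<B: Backend, const D: usize>(
    tensor: &B::TensorPrimitive<D>,
    len: usize,
) -> B::TensorPrimitive<1> {
    // Before this commit: tensor.reshape(Shape::new([len])) via TensorOpsReshape.
    // After: the backend itself provides the op.
    B::reshape(tensor, Shape::new([len]))
}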
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/grad/mod.rs
@@ -5,6 +5,7 @@ mod div;
mod matmul;
mod mul;
mod neg;
mod reshape;
mod softmax;
mod sub;
mod transpose;
(The diff for the remaining changed file was not loaded on the page.)
