Refactor/tensor ops (tracel-ai#58)

nathanielsimard authored Oct 16, 2022
1 parent 72e4433 commit 0e1b0ac
Showing 23 changed files with 232 additions and 100 deletions.
3 changes: 3 additions & 0 deletions burn-tensor/src/graph/node/state.rs
@@ -12,6 +12,9 @@ where
     pub fn value(&self) -> Out {
         self.value.clone()
     }
+    pub fn value_ref(&self) -> &Out {
+        &self.value
+    }
 }
 
 #[derive(Debug, Clone)]
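The new `value_ref` accessor lets read-only callers avoid cloning the stored value. A minimal, self-contained sketch of the idea (toy `NodeState` type, not the actual burn-tensor definition):

```rust
// Toy illustration: a state wrapper with an owned accessor (clones) and a
// borrowed accessor (no copy). Names are illustrative only.
#[derive(Debug, Clone)]
struct NodeState<Out> {
    value: Out,
}

impl<Out: Clone> NodeState<Out> {
    fn value(&self) -> Out {
        self.value.clone() // copies the whole stored value
    }

    fn value_ref(&self) -> &Out {
        &self.value // borrow only; no allocation, no copy
    }
}

fn main() {
    let state = NodeState { value: vec![1.0_f32; 1024] };
    // Read-only consumers can borrow instead of cloning:
    let len = state.value_ref().len();
    assert_eq!(len, state.value().len());
}
```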
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/autodiff/activation/relu.rs
@@ -23,7 +23,7 @@ impl<B: Backend, const D: usize> ReLU<B::Elem, D> for ADTensor<D, B> {
     fn relu(&self) -> Self {
         execute_ops!(
             input self.node.clone(),
-            out ReLU::relu(&self.tensor()),
+            out self.tensor_ref().relu(),
             ops ADReLU::<B, D>::new(),
         )
     }
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/autodiff/mod.rs
@@ -2,6 +2,7 @@ mod activation;
 mod backend;
 mod ops;
 mod tensor;
+mod tensor_ops;
 
 pub use backend::*;
 pub use ops::*;
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/add.rs
@@ -35,7 +35,7 @@ impl<B: Backend, const D: usize> TensorOpsAdd<B::Elem, D> for ADTensor<D, B> {
     fn add_scalar(&self, other: &B::Elem) -> Self {
         execute_ops!(
             input self.node.clone(),
-            out TensorOpsAdd::add_scalar(&self.tensor(), other),
+            out TensorOpsAdd::add_scalar(self.tensor_ref(), other),
             ops ADTensorAddScalarOps::<B, D>::new(*other),
         )
     }
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/backend/autodiff/ops/cat.rs
@@ -44,7 +44,7 @@ impl<const D: usize, B: Backend> BackwardRecordedOps<B::TensorPrimitive<D>>
 {
     fn backward_step(&self, state: &BackwardNodeState<B::TensorPrimitive<D>>) {
         let grad = state.grad();
-        let indexes: Vec<_> = grad.shape().dims.iter().map(|v| 0..*v).collect();
+        let indexes: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
         let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
 
         self.nodes.iter().enumerate().for_each(|(i, node)| {
@@ -76,7 +76,7 @@ impl<B: Backend, const D: usize> TensorOpsCat<B::Elem, D> for ADTensor<D, B> {
 
         let out = TensorOpsCat::cat(tensors_inner_ref, dim);
 
-        let shape = *out.shape();
+        let shape = *B::shape(&out);
         let state = crate::graph::node::ForwardNodeState::new(out);
 
         let ops = ForwardCatOps::<D, B>::new(nodes, dim);
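The `indexes` construction above turns each dimension size into a full `0..dim` range and then converts the `Vec` into a fixed-size array with `try_into`. A standalone sketch of that conversion, outside of any backend (hypothetical helper name):

```rust
use std::ops::Range;

// Hypothetical helper: build one 0..dim range per dimension, then convert the
// Vec into a fixed-size array. `try_into` can only fail on a length mismatch,
// which cannot happen here.
fn full_ranges<const D: usize>(dims: [usize; D]) -> [Range<usize>; D] {
    let ranges: Vec<Range<usize>> = dims.iter().map(|&d| 0..d).collect();
    ranges.try_into().unwrap()
}

fn main() {
    let r = full_ranges([2, 3]);
    assert_eq!(r, [0..2, 0..3]);
}
```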
26 changes: 24 additions & 2 deletions burn-tensor/src/tensor/backend/autodiff/ops/macros.rs
@@ -117,7 +117,7 @@ macro_rules! execute_ops {
         ops $ops:expr,
     ) => {{
         let callback = || {
-            let shape = $crate::tensor::ops::TensorOpsUtilities::shape(&$out).clone();
+            let shape = B::shape(&$out).clone();
             let state = $crate::graph::node::ForwardNodeState::new($out);
 
             let ops = std::sync::Arc::new($ops);
@@ -137,7 +137,7 @@
         ops $ops:expr,
     ) => {{
         let callback = || {
-            let shape = $crate::tensor::ops::TensorOpsUtilities::shape(&$out).clone();
+            let shape = B::shape(&$out).clone();
             let state = $crate::graph::node::ForwardNodeState::new($out);
 
             let ops = std::sync::Arc::new($ops);
@@ -151,6 +151,28 @@
         };
         callback()
     }};
+    (
+        input $input:expr,
+        out $out:expr,
+        ops $ops:expr,
+        shape $shape:expr,
+    ) => {{
+        let callback = || {
+            let shape = $shape;
+            let state = $crate::graph::node::ForwardNodeState::new($out);
+
+            let ops = std::sync::Arc::new($ops);
+            let ops = $crate::graph::ops::ForwardUnaryRecordedOps::new($input, ops.clone());
+            let ops = std::sync::Arc::new(ops);
+
+            let node = $crate::graph::node::ForwardNode::from_unary(&$input, state, ops);
+            let node = std::sync::Arc::new(node);
+
+            $crate::tensor::backend::autodiff::ADTensor { node, shape }
+        };
+        callback()
+    }};
+
     (
         init $out:expr
     ) => {{
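The new `shape $shape:expr` arm lets the caller pass a precomputed shape when it cannot be derived from `$out` with `B::shape`, for example when `$out` lives on a different backend, as in `to_full_precision` below. A reduced sketch of the same macro pattern, with toy arms that are not the real `execute_ops!`:

```rust
// Toy macro: one arm derives `len` from the output itself, the added arm
// accepts an explicit, caller-provided `len` instead.
macro_rules! wrap {
    (out $out:expr,) => {{
        let len = $out.len(); // derived from the output
        (len, $out)
    }};
    (out $out:expr, len $len:expr,) => {{
        let len = $len; // caller already knows the length
        (len, $out)
    }};
}

fn main() {
    let a = wrap!(out vec![1, 2, 3],);
    let b = wrap!(out vec![4, 5], len 2usize,);
    assert_eq!(a.0, 3);
    assert_eq!(b.0, 2);
}
```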
13 changes: 9 additions & 4 deletions burn-tensor/src/tensor/backend/autodiff/ops/precision.rs
@@ -1,7 +1,7 @@
 use crate::backend::autodiff::ADBackendDecorator;
 use crate::backend::Backend;
-use crate::ops::TensorOpsPrecision;
-use crate::{define_ops, execute_ops};
+use crate::ops::{TensorOps, TensorOpsPrecision};
+use crate::{define_ops, execute_ops, Shape};
 use crate::{
     graph::ops::{UnaryOps, UnaryOpsNodeState},
     tensor::backend::autodiff::ADTensor,
@@ -50,18 +50,23 @@ impl<B: Backend, const D: usize> TensorOpsPrecision<ADBackendDecorator<B>, D>
     for <ADBackendDecorator<B> as Backend>::TensorPrimitive<D>
 {
     fn to_full_precision(&self) -> ADTensor<D, <B as Backend>::FullPrecisionBackend> {
+        let out = TensorOpsPrecision::to_full_precision(&self.tensor());
+        let shape: Shape<D> =
+            *<B::FullPrecisionBackend as TensorOps<B::FullPrecisionBackend>>::shape(&out);
+
         execute_ops!(
             input self.node.clone(),
-            out TensorOpsPrecision::to_full_precision(&self.tensor()),
+            out out,
             ops ADTensorToPrecisionOps::<B, D>::new(),
+            shape shape,
         )
     }
 
     fn from_full_precision(
         tensor_full: ADTensor<D, <B as Backend>::FullPrecisionBackend>,
     ) -> ADTensor<D, B> {
         let tensor = <B as Backend>::TensorPrimitive::from_full_precision(tensor_full.tensor());
-        let shape = *crate::tensor::ops::TensorOpsUtilities::shape(&tensor);
+        let shape = *B::shape(&tensor);
         let state = crate::graph::node::ForwardNodeState::new(tensor);
 
         let ops = std::sync::Arc::new(ADTensorFromFullPrecisionOps::<B, D>::new());
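The fully qualified call `<B::FullPrecisionBackend as TensorOps<...>>::shape(&out)` is needed because `out` belongs to the full-precision backend rather than to `B`, so the right `TensorOps` implementation has to be named explicitly. A small standalone sketch of the same disambiguation (toy traits, unrelated to burn):

```rust
// Two traits expose an associated function with the same name; fully
// qualified syntax selects which implementation is meant.
trait Meters {
    fn length(v: &f64) -> f64;
}
trait Feet {
    fn length(v: &f64) -> f64;
}

struct Units;

impl Meters for Units {
    fn length(v: &f64) -> f64 {
        *v
    }
}
impl Feet for Units {
    fn length(v: &f64) -> f64 {
        *v * 3.28084
    }
}

fn main() {
    let d = 2.0_f64;
    let m = <Units as Meters>::length(&d); // explicitly use the Meters impl
    let f = <Units as Feet>::length(&d);   // explicitly use the Feet impl
    assert_eq!(m, 2.0);
    assert!((f - 6.56168).abs() < 1e-9);
}
```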
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/backend/autodiff/ops/reshape.rs
@@ -34,8 +34,8 @@ impl<B: Backend, const D1: usize, const D2: usize>
         let mut grad = state.output.grad();
         let value = state.output.value();
 
-        let shape_grad = *grad.shape();
-        let shape_value = *value.shape();
+        let shape_grad = *B::shape(&grad);
+        let shape_value = *B::shape(&value);
 
         if shape_value == shape_grad {
             return grad.reshape(self.shape);
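For context, the backward step above relies on reshape being a pure metadata change: the gradient flowing back only needs to be put back into the input's shape. A deliberately simplified sketch of that idea with flat vectors (hypothetical names, not burn's types):

```rust
// Toy model of the reshape backward rule: element order is untouched, only
// the shape metadata is restored to the input's shape.
fn reshape_backward(grad: Vec<f32>, input_shape: [usize; 2]) -> (Vec<f32>, [usize; 2]) {
    (grad, input_shape)
}

fn main() {
    // Forward reshaped [2, 3] into [3, 2]; backward sends the gradient back as [2, 3].
    let grad = vec![1.0_f32; 6];
    let (g, shape) = reshape_backward(grad, [2, 3]);
    assert_eq!(shape, [2, 3]);
    assert_eq!(g.len(), 6);
}
```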
21 changes: 6 additions & 15 deletions burn-tensor/src/tensor/backend/autodiff/tensor/base.rs
@@ -1,7 +1,7 @@
 use crate::{
     execute_ops,
     graph::node::ForwardNodeRef,
-    tensor::{backend::Backend, ops::TensorOpsUtilities, Data, Shape},
+    tensor::{backend::Backend, Shape},
 };
 
 #[derive(Debug, Clone)]
@@ -10,26 +10,13 @@ pub struct ADTensor<const D: usize, B: Backend> {
     pub shape: Shape<D>,
 }
 
-impl<B: Backend, const D: usize> TensorOpsUtilities<B::Elem, D> for ADTensor<D, B> {
-    fn shape(&self) -> &Shape<D> {
-        &self.shape
-    }
-
-    fn into_data(self) -> Data<B::Elem, D> {
-        self.tensor().into_data()
-    }
-    fn to_data(&self) -> Data<B::Elem, D> {
-        self.tensor().to_data()
-    }
-}
-
 impl<B: Backend, const D: usize> ADTensor<D, B> {
     pub fn from_tensor(tensor: B::TensorPrimitive<D>) -> Self {
         let node = execute_ops!(
             init tensor.clone()
         );
 
-        let shape = *tensor.shape();
+        let shape = *B::shape(&tensor);
         Self { node, shape }
     }
 }
@@ -38,6 +25,10 @@ impl<B: Backend, const D: usize> ADTensor<D, B> {
     pub fn tensor(&self) -> B::TensorPrimitive<D> {
         self.node.state.value()
     }
+
+    pub fn tensor_ref(&self) -> &B::TensorPrimitive<D> {
+        self.node.state.value_ref()
+    }
 }
 
 #[cfg(test)]
40 changes: 40 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/tensor_ops.rs
@@ -0,0 +1,40 @@
+use super::ADBackendDecorator;
+use crate::{backend::Backend, ops::TensorOps, Data, Shape};
+
+impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
+    fn shape<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> &Shape<D> {
+        B::shape(tensor.tensor_ref())
+    }
+
+    fn to_data<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> Data<<ADBackendDecorator<B> as Backend>::Elem, D> {
+        B::to_data(tensor.tensor_ref())
+    }
+
+    fn into_data<const D: usize>(
+        tensor: <ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> Data<<ADBackendDecorator<B> as Backend>::Elem, D> {
+        B::into_data(tensor.tensor())
+    }
+
+    fn bool_shape<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D>,
+    ) -> &Shape<D> {
+        B::bool_shape(tensor)
+    }
+
+    fn bool_to_data<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D>,
+    ) -> Data<bool, D> {
+        B::bool_to_data(tensor)
+    }
+
+    fn bool_into_data<const D: usize>(
+        tensor: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D>,
+    ) -> Data<bool, D> {
+        B::bool_into_data(tensor)
+    }
+}
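This file is the decorator half of the refactor: the autodiff backend implements `TensorOps` by forwarding every query to the backend it wraps. A reduced sketch of that delegation with toy types (not the real `ADBackendDecorator`):

```rust
// Toy decorator: the AD tensor wraps an inner tensor and forwards read-only
// queries to it, adding nothing of its own.
#[derive(Debug, Clone)]
struct InnerTensor {
    data: Vec<f32>,
    dims: Vec<usize>,
}

#[derive(Debug, Clone)]
struct AdTensor {
    inner: InnerTensor, // the wrapped (decorated) tensor
}

impl AdTensor {
    fn shape(&self) -> &[usize] {
        &self.inner.dims // delegate to the inner tensor
    }

    fn to_data(&self) -> Vec<f32> {
        self.inner.data.clone() // delegate to the inner tensor
    }
}

fn main() {
    let inner = InnerTensor { data: vec![0.0; 6], dims: vec![2, 3] };
    let t = AdTensor { inner };
    assert_eq!(t.shape(), &[2, 3]);
    assert_eq!(t.to_data().len(), 6);
}
```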
14 changes: 5 additions & 9 deletions burn-tensor/src/tensor/backend/base.rs
@@ -5,14 +5,15 @@ use crate::tensor::Element;
 use crate::tensor::{Data, Distribution, Shape};
 use crate::Gradients;
 
-pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'static {
+pub trait Backend:
+    TensorOps<Self> + Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'static
+{
     type Device: Copy + Clone + Default + std::fmt::Debug + Send + Sync;
     type Elem: Element;
     type FullPrecisionElem: Element;
     type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
     type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
-    type TensorPrimitive<const D: usize>: TensorOpsUtilities<Self::Elem, D>
-        + TensorOpsMatmul<Self::Elem, D>
+    type TensorPrimitive<const D: usize>: TensorOpsMatmul<Self::Elem, D>
         + TensorOpsTranspose<Self::Elem, D>
         + TensorOpsMul<Self::Elem, D>
         + TensorOpsDiv<Self::Elem, D>
@@ -44,12 +45,7 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
         + 'static
         + std::fmt::Debug;
 
-    type BoolTensorPrimitive<const D: usize>: TensorOpsUtilities<bool, D>
-        + Clone
-        + Send
-        + Sync
-        + 'static
-        + std::fmt::Debug;
+    type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + std::fmt::Debug;
 
     fn from_data<const D: usize>(
         data: Data<Self::Elem, D>,
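The key change here is that `Backend` now requires `TensorOps<Self>`, so shape and data access move off the per-tensor `TensorOpsUtilities` trait and onto the backend. A minimal sketch of that trait layout with simplified names (not burn's exact definitions):

```rust
// Simplified layout: utilities live on `TensorOps<B>`, and `Backend` requires
// `TensorOps<Self>`, so generic code can write `B::shape(&tensor)`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Shape<const D: usize> {
    dims: [usize; D],
}

trait TensorOps<B: Backend> {
    fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
}

trait Backend: TensorOps<Self> + Sized {
    type TensorPrimitive<const D: usize>;
}

// A toy backend whose "tensor" is a flat Vec plus its shape.
struct VecBackend;

struct VecTensor<const D: usize> {
    data: Vec<f32>,
    shape: Shape<D>,
}

impl TensorOps<VecBackend> for VecBackend {
    fn shape<const D: usize>(tensor: &VecTensor<D>) -> Shape<D> {
        tensor.shape
    }
}

impl Backend for VecBackend {
    type TensorPrimitive<const D: usize> = VecTensor<D>;
}

// Generic code no longer needs a utility bound on the tensor type itself:
fn numel<B: Backend, const D: usize>(tensor: &B::TensorPrimitive<D>) -> usize {
    B::shape(tensor).dims.iter().product()
}

fn main() {
    let t = VecTensor { data: vec![0.0_f32; 6], shape: Shape { dims: [2, 3] } };
    assert_eq!(numel::<VecBackend, 2>(&t), 6);
    assert_eq!(t.data.len(), 6);
}
```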
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/ndarray/mod.rs
@@ -3,6 +3,7 @@ mod backend;
 mod ops;
 mod shape;
 mod tensor;
+mod tensor_ops;
 
 pub use backend::*;
 pub use shape::*;
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/add.rs
@@ -8,7 +8,7 @@ where
     fn add(&self, other: &Self) -> Self {
         let array = self.array.clone() + other.array.clone();
         let array = array.into_shared();
-        let shape = self.shape.higher(other.shape());
+        let shape = self.shape.higher(&other.shape);
 
         Self { array, shape }
     }
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/arg.rs
@@ -24,7 +24,7 @@ fn arg<E: NdArrayElement, F, const D: usize>(
 where
     F: Fn(&f64, &f64) -> Ordering,
 {
-    let mut data = tensor.to_data();
+    let mut data = <NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::to_data::<D>(tensor);
     let batch_size = tensor.shape.dims[dim];
     let mut start = 0;
     let mut end = tensor.shape.dims[dim];
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/cat.rs
@@ -6,7 +6,7 @@ use ndarray::{Axis, IxDyn};
 
 impl<P: NdArrayElement, const D: usize> TensorOpsCat<P, D> for NdArrayTensor<P, D> {
     fn cat(tensors: Vec<&Self>, dim: usize) -> Self {
-        let mut shape = *tensors.get(0).unwrap().shape();
+        let mut shape = tensors.get(0).unwrap().shape;
         shape.dims[dim] = tensors.len();
 
         let arrays: Vec<ndarray::ArrayView<P, IxDyn>> =
31 changes: 15 additions & 16 deletions burn-tensor/src/tensor/backend/ndarray/tensor.rs
@@ -1,4 +1,4 @@
-use crate::tensor::{ops::TensorOpsUtilities, Data, Shape};
+use crate::tensor::{Data, Shape};
 use ndarray::{s, ArcArray, Array, Axis, Dim, Ix2, Ix3, IxDyn};
 
 #[derive(Debug, Clone)]
@@ -7,22 +7,21 @@ pub struct NdArrayTensor<E, const D: usize> {
     pub shape: Shape<D>,
 }
 
-impl<E, const D: usize> TensorOpsUtilities<E, D> for NdArrayTensor<E, D>
-where
-    E: Default + Clone,
-{
-    fn shape(&self) -> &Shape<D> {
-        &self.shape
-    }
+#[cfg(test)]
+mod utils {
+    use crate::{backend::NdArrayBackend, ops::TensorOps, NdArrayElement};
 
-    fn into_data(self) -> Data<E, D> {
-        let values = self.array.into_iter().collect();
-        Data::new(values, self.shape)
+    use super::*;
+    impl<E, const D: usize> NdArrayTensor<E, D>
+    where
+        E: Default + Clone,
+    {
+        pub(crate) fn into_data(self) -> Data<E, D>
+        where
+            E: NdArrayElement,
+        {
+            <NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::into_data::<D>(self)
+        }
     }
-
-    fn to_data(&self) -> Data<E, D> {
-        let values = self.array.clone().into_iter().collect();
-        Data::new(values, self.shape)
-    }
 }
 
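The `#[cfg(test)] mod utils` block above keeps the old `into_data` convenience available to the test suite only, while production code goes through `TensorOps`. A small standalone sketch of that test-only helper pattern (toy type, hypothetical names):

```rust
// The method defined inside the #[cfg(test)] module is compiled only for
// `cargo test`; regular builds never see it.
#[derive(Debug, Clone)]
pub struct Meters(pub f64);

fn main() {
    println!("{:?}", Meters(1.0));
}

#[cfg(test)]
mod test_utils {
    use super::*;

    impl Meters {
        pub(crate) fn as_feet(&self) -> f64 {
            self.0 * 3.28084
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn converts() {
        assert!((Meters(1.0).as_feet() - 3.28084).abs() < 1e-9);
    }
}
```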
44 changes: 44 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
@@ -0,0 +1,44 @@
+use super::NdArrayBackend;
+use crate::{backend::Backend, ops::TensorOps, Data, NdArrayElement, Shape};
+
+impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
+    fn shape<const D: usize>(
+        tensor: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+    ) -> &Shape<D> {
+        &tensor.shape
+    }
+
+    fn to_data<const D: usize>(
+        tensor: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+    ) -> Data<<NdArrayBackend<E> as Backend>::Elem, D> {
+        let values = tensor.array.iter().map(Clone::clone).collect();
+        Data::new(values, tensor.shape)
+    }
+
+    fn into_data<const D: usize>(
+        tensor: <NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+    ) -> Data<<NdArrayBackend<E> as Backend>::Elem, D> {
+        let values = tensor.array.into_iter().collect();
+        Data::new(values, tensor.shape)
+    }
+
+    fn bool_shape<const D: usize>(
+        tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
+    ) -> &Shape<D> {
+        &tensor.shape
+    }
+
+    fn bool_to_data<const D: usize>(
+        tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
+    ) -> Data<bool, D> {
+        let values = tensor.array.iter().map(Clone::clone).collect();
+        Data::new(values, tensor.shape)
+    }
+
+    fn bool_into_data<const D: usize>(
+        tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
+    ) -> Data<bool, D> {
+        let values = tensor.array.into_iter().collect();
+        Data::new(values, tensor.shape)
+    }
+}
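As a reference point, the `to_data` and `into_data` functions above differ only in whether the values are cloned out of the array or moved out of it. A reduced sketch with a plain `Vec` instead of an ndarray array (toy types, no ndarray dependency):

```rust
// Toy flattening: `to_data` borrows and clones, `into_data` consumes and moves.
#[derive(Debug, Clone, PartialEq)]
struct Data<E, const D: usize> {
    value: Vec<E>,
    dims: [usize; D],
}

#[derive(Debug, Clone)]
struct SimpleTensor<E, const D: usize> {
    storage: Vec<E>,
    dims: [usize; D],
}

fn to_data<E: Clone, const D: usize>(t: &SimpleTensor<E, D>) -> Data<E, D> {
    Data {
        value: t.storage.iter().map(Clone::clone).collect(),
        dims: t.dims,
    }
}

fn into_data<E, const D: usize>(t: SimpleTensor<E, D>) -> Data<E, D> {
    Data {
        value: t.storage, // ownership moves, nothing is copied
        dims: t.dims,
    }
}

fn main() {
    let t = SimpleTensor { storage: vec![1, 2, 3, 4, 5, 6], dims: [2, 3] };
    let borrowed = to_data(&t); // `t` is still usable afterwards
    let owned = into_data(t);   // `t` is consumed here
    assert_eq!(borrowed, owned);
}
```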