Refactor/ad backend decorator (tracel-ai#57)
nathanielsimard authored Oct 14, 2022
1 parent 31d512e commit 72e4433
Showing 8 changed files with 270 additions and 424 deletions.
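The heart of the change: instead of a `define_impl!` macro that stamps out a separate autodiff backend per concrete backend (`ADBackendNdArray`, `ADBackendTch`), a single generic `ADBackendDecorator<B>` now implements `Backend` once, for any inner backend `B`, by delegating and wrapping. A minimal self-contained sketch of that decorator pattern (toy trait and stand-in names for illustration; the real traits live in `burn-tensor`):

```rust
// Toy sketch of the refactor's core idea: one generic decorator impl
// replaces N macro-generated impls. Not the actual burn-tensor API.
trait Backend {
    fn name() -> String;
}

// Stand-in for a concrete inner backend such as NdArrayBackend.
struct NdArray;

impl Backend for NdArray {
    fn name() -> String {
        "ndarray".to_string()
    }
}

// The decorator is generic over any Backend, so it needs no macro.
struct ADBackendDecorator<B> {
    _b: B,
}

impl<B: Backend> Backend for ADBackendDecorator<B> {
    fn name() -> String {
        // Delegate to the inner backend, then decorate the result,
        // mirroring the `format!("autodiff<{}>", B::name())` in the diff.
        format!("autodiff<{}>", B::name())
    }
}

fn main() {
    assert_eq!(ADBackendDecorator::<NdArray>::name(), "autodiff<ndarray>");
}
```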
231 changes: 85 additions & 146 deletions burn-tensor/src/tensor/backend/autodiff/backend.rs
@@ -3,154 +3,93 @@
 use crate::graph::grad::Gradients;
 use crate::tensor::backend::{ADBackend, Backend};
 use crate::tensor::{Data, Distribution, Shape};
 
-#[cfg(feature = "ndarray")]
-use crate::NdArrayElement;
-
-#[cfg(feature = "tch")]
-use crate::TchElement;
-
-macro_rules! define_impl {
-    (
-        name: $name:ident,
-        backend: $backend:ty,
-        element: $element:ident
-    ) => {
-        #[derive(Clone, Copy, Debug, Default)]
-        pub struct $name<E> {
-            _b: $backend,
-        }
-
-        impl<E: $element> Backend for $name<E> {
-            type Device = <$backend as Backend>::Device;
-            type Elem = E;
-            type FullPrecisionElem = f32;
-            type IntegerBackend = <$backend as Backend>::IntegerBackend;
-            type FullPrecisionBackend = $name<<$backend as Backend>::FullPrecisionElem>;
-            type TensorPrimitive<const D: usize> = ADTensor<D, $backend>;
-            type BoolTensorPrimitive<const D: usize> =
-                <$backend as Backend>::BoolTensorPrimitive<D>;
-
-            fn from_data<const D: usize>(
-                data: Data<Self::Elem, D>,
-                device: Self::Device,
-            ) -> Self::TensorPrimitive<D> {
-                let tensor = <$backend as Backend>::from_data(data, device);
-                ADTensor::from_tensor(tensor)
-            }
-
-            fn from_data_bool<const D: usize>(
-                data: Data<bool, D>,
-                device: Self::Device,
-            ) -> Self::BoolTensorPrimitive<D> {
-                <$backend as Backend>::from_data_bool(data, device)
-            }
-
-            fn random<const D: usize>(
-                shape: Shape<D>,
-                distribution: Distribution<Self::Elem>,
-                device: Self::Device,
-            ) -> Self::TensorPrimitive<D> {
-                Self::from_inner(<$backend as Backend>::random(shape, distribution, device))
-            }
-
-            fn ad_enabled() -> bool {
-                true
-            }
-
-            fn zeros<const D: usize>(
-                shape: Shape<D>,
-                device: Self::Device,
-            ) -> Self::TensorPrimitive<D> {
-                Self::from_inner(<$backend as Backend>::zeros(shape, device))
-            }
-
-            fn ones<const D: usize>(
-                shape: Shape<D>,
-                device: Self::Device,
-            ) -> Self::TensorPrimitive<D> {
-                Self::from_inner(<$backend as Backend>::ones(shape, device))
-            }
-
-            fn name() -> String {
-                format!("autodiff<{}>", <$backend as Backend>::name())
-            }
-
-            fn seed(seed: u64) {
-                <$backend as Backend>::seed(seed)
-            }
-        }
-
-        impl<E: $element> ADBackend for $name<E> {
-            type InnerBackend = $backend;
-
-            fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Gradients {
-                tensor.backward()
-            }
-            fn grad<const D: usize>(
-                tensor: &Self::TensorPrimitive<D>,
-                grads: &Gradients,
-            ) -> Option<<$backend as Backend>::TensorPrimitive<D>> {
-                grads.wrt(tensor).map(|grad| grad.clone())
-            }
-
-            fn inner<const D: usize>(
-                tensor: &Self::TensorPrimitive<D>,
-            ) -> <Self::InnerBackend as Backend>::TensorPrimitive<D> {
-                tensor.tensor()
-            }
-
-            fn from_inner<const D: usize>(
-                tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>,
-            ) -> Self::TensorPrimitive<D> {
-                ADTensor::from_tensor(tensor)
-            }
-        }
-    };
-}
+#[derive(Clone, Copy, Debug, Default)]
+pub struct ADBackendDecorator<B> {
+    _b: B,
+}
 
-#[cfg(feature = "ndarray")]
-define_impl!(
-    name: ADBackendNdArray,
-    backend: crate::tensor::backend::ndarray::NdArrayBackend<E>,
-    element: NdArrayElement
-);
-#[cfg(feature = "tch")]
-define_impl!(
-    name: ADBackendTch,
-    backend: crate::tensor::backend::tch::TchBackend<E>,
-    element: TchElement
-);
-
-#[macro_export]
-macro_rules! register_ndarray {
-    () => {
-        #[cfg(feature = "ndarray")]
-        mod ndarray_impl {
-            use super::*;
-            use $crate::NdArrayElement;
-
-            define_impl!(
-                $crate::tensor::backend::autodiff::ADBackendNdArray::<E>,
-                $crate::tensor::backend::ndarray::NdArrayBackend::<E>,
-                NdArrayElement
-            );
-        }
-    };
-}
+impl<B: Backend> Backend for ADBackendDecorator<B> {
+    type Device = B::Device;
+    type Elem = B::Elem;
+    type FullPrecisionElem = B::FullPrecisionElem;
+    type IntegerBackend = B::IntegerBackend;
+    type FullPrecisionBackend = ADBackendDecorator<B::FullPrecisionBackend>;
+    type TensorPrimitive<const D: usize> = ADTensor<D, B>;
+    type BoolTensorPrimitive<const D: usize> = B::BoolTensorPrimitive<D>;
+
+    fn from_data<const D: usize>(
+        data: Data<Self::Elem, D>,
+        device: Self::Device,
+    ) -> Self::TensorPrimitive<D> {
+        let tensor = B::from_data(data, device);
+        ADTensor::from_tensor(tensor)
+    }
+
+    fn from_data_bool<const D: usize>(
+        data: Data<bool, D>,
+        device: Self::Device,
+    ) -> Self::BoolTensorPrimitive<D> {
+        B::from_data_bool(data, device)
+    }
+
+    fn random<const D: usize>(
+        shape: Shape<D>,
+        distribution: Distribution<Self::Elem>,
+        device: Self::Device,
+    ) -> Self::TensorPrimitive<D> {
+        Self::from_inner(B::random(shape, distribution, device))
+    }
+
+    fn ad_enabled() -> bool {
+        true
+    }
+
+    fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
+        Self::from_inner(B::zeros(shape, device))
+    }
+
+    fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
+        Self::from_inner(B::ones(shape, device))
+    }
+
+    fn name() -> String {
+        format!("autodiff<{}>", B::name())
+    }
+
+    fn seed(seed: u64) {
+        B::seed(seed)
+    }
+}

-#[macro_export]
-macro_rules! register_tch {
-    () => {
-        #[cfg(feature = "tch")]
-        mod tch_impl {
-            use super::*;
-            use $crate::TchElement;
-
-            define_impl!(
-                $crate::tensor::backend::autodiff::ADBackendTch::<E>,
-                $crate::tensor::backend::tch::TchBackend::<E>,
-                TchElement
-            );
-        }
-    };
-}
+impl<B: Backend> ADBackend for ADBackendDecorator<B> {
+    type InnerBackend = B;
+
+    fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Gradients {
+        tensor.backward()
+    }
+    fn grad<const D: usize>(
+        tensor: &Self::TensorPrimitive<D>,
+        grads: &Gradients,
+    ) -> Option<B::TensorPrimitive<D>> {
+        grads.wrt(tensor).cloned()
+    }
+
+    fn inner<const D: usize>(
+        tensor: &Self::TensorPrimitive<D>,
+    ) -> <Self::InnerBackend as Backend>::TensorPrimitive<D> {
+        tensor.tensor()
+    }
+
+    fn from_inner<const D: usize>(
+        tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>,
+    ) -> Self::TensorPrimitive<D> {
+        ADTensor::from_tensor(tensor)
+    }
+}

+#[cfg(feature = "ndarray")]
+pub type ADBackendNdArray<E> =
+    ADBackendDecorator<crate::tensor::backend::ndarray::NdArrayBackend<E>>;
+
+#[cfg(feature = "tch")]
+pub type ADBackendTch<E> = ADBackendDecorator<crate::tensor::backend::tch::TchBackend<E>>;
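A small cleanup rides along in `grad`: `grads.wrt(tensor).map(|grad| grad.clone())` becomes `grads.wrt(tensor).cloned()`. `Option::cloned` converts an `Option<&T>` into an `Option<T>` for any `T: Clone`, which is exactly what the hand-written closure did:

```rust
// Stand-alone check that Option::cloned matches the removed map-clone.
fn main() {
    let grad = vec![0.1f32, 0.2];
    let looked_up: Option<&Vec<f32>> = Some(&grad);
    assert_eq!(looked_up.map(|g| g.clone()), looked_up.cloned());
}
```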
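The two trailing `pub type` aliases keep the old public names as shorthands for the decorated backends, so feature-gated downstream code can stay untouched. A hypothetical usage snippet (the external paths are assumed from the `crate::…` paths above and may differ from the crate's actual re-exports; `f32` is an arbitrary element choice):

```rust
// Hypothetical downstream code, assuming burn-tensor's "ndarray" feature
// and the module paths shown in the diff.
use burn_tensor::tensor::backend::autodiff::ADBackendNdArray;
use burn_tensor::tensor::backend::Backend;

// The alias is ADBackendDecorator<NdArrayBackend<f32>> spelled short.
type AD = ADBackendNdArray<f32>;

fn main() {
    // Per the impl above, this prints "autodiff<ndarray>".
    println!("{}", AD::name());
}
```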
82 changes: 35 additions & 47 deletions burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs
@@ -1,3 +1,4 @@
+use crate::backend::autodiff::ADBackendDecorator;
 use crate::tensor::ElementConversion;
 use crate::Tensor;
 use crate::{backend::Backend, tensor::ops::*};
@@ -97,54 +98,41 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit…
     }
 }
 
-macro_rules! define_impl {
-    (
-        $backend:ty,
-        $backend_inner:ty,
-        $element:ident
-    ) => {
-        impl<E: $element, const D: usize> TensorOpsAggregation<$backend, D>
-            for <$backend as Backend>::TensorPrimitive<D>
-        {
-            fn mean(&self) -> <$backend as Backend>::TensorPrimitive<1> {
-                execute_ops!(
-                    input self.node.clone(),
-                    out TensorOpsAggregation::mean(&self.tensor()),
-                    ops ADTensorOpsMean::<$backend_inner, D>::new(self.shape.clone()),
-                )
-            }
-
-            fn sum(&self) -> <$backend as Backend>::TensorPrimitive<1> {
-                execute_ops!(
-                    input self.node.clone(),
-                    out TensorOpsAggregation::sum(&self.tensor()),
-                    ops ADTensorOpsSum::<$backend_inner, D>::new(self.shape.clone()),
-                )
-            }
-
-            fn mean_dim(&self, dim: usize) -> <$backend as Backend>::TensorPrimitive<D> {
-                execute_ops!(
-                    input self.node.clone(),
-                    out TensorOpsAggregation::mean_dim(&self.tensor(), dim),
-                    ops ADTensorOpsMeanDim::<$backend_inner, D>::new((self.shape.clone(), dim)),
-                )
-
-            }
-
-            fn sum_dim(&self, dim: usize) -> <$backend as Backend>::TensorPrimitive<D> {
-                execute_ops!(
-                    input self.node.clone(),
-                    out TensorOpsAggregation::sum_dim(&self.tensor(), dim),
-                    ops ADTensorOpsSumDim::<$backend_inner, D>::new((self.shape.clone(), dim)),
-                )
-
-            }
-        }
-    };
-}
-
-crate::register_tch!();
-crate::register_ndarray!();
+impl<B: Backend, const D: usize> TensorOpsAggregation<ADBackendDecorator<B>, D>
+    for <ADBackendDecorator<B> as Backend>::TensorPrimitive<D>
+{
+    fn mean(&self) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<1> {
+        execute_ops!(
+            input self.node.clone(),
+            out TensorOpsAggregation::mean(&self.tensor()),
+            ops ADTensorOpsMean::<B, D>::new(self.shape),
+        )
+    }
+
+    fn sum(&self) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<1> {
+        execute_ops!(
+            input self.node.clone(),
+            out TensorOpsAggregation::sum(&self.tensor()),
+            ops ADTensorOpsSum::<B, D>::new(self.shape),
+        )
+    }
+
+    fn mean_dim(&self, dim: usize) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        execute_ops!(
+            input self.node.clone(),
+            out TensorOpsAggregation::mean_dim(&self.tensor(), dim),
+            ops ADTensorOpsMeanDim::<B, D>::new((self.shape, dim)),
+        )
+    }
+
+    fn sum_dim(&self, dim: usize) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        execute_ops!(
+            input self.node.clone(),
+            out TensorOpsAggregation::sum_dim(&self.tensor(), dim),
+            ops ADTensorOpsSumDim::<B, D>::new((self.shape, dim)),
+        )
+    }
+}
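Note that every `self.shape.clone()` from the macro version is now a plain `self.shape`. That works if `Shape<D>` wraps a fixed-size `[usize; D]` and derives `Copy`, so passing it by value is a cheap bitwise copy; a sketch under that assumption:

```rust
// Assumes Shape<D> is a small Copy struct over [usize; D], which is what
// passing `self.shape` by value in the new impl relies on.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Shape<const D: usize> {
    dims: [usize; D],
}

fn num_elements<const D: usize>(shape: Shape<D>) -> usize {
    shape.dims.iter().product()
}

fn main() {
    let shape = Shape { dims: [2, 3, 4] };
    assert_eq!(num_elements(shape), 24); // passes a copy, no clone() call
    assert_eq!(shape.dims, [2, 3, 4]); // still usable afterwards: Copy
}
```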
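Each method still goes through `execute_ops!`, which, judging from its arguments here, registers the input node, the eagerly computed output, and a backward-ops object on the autodiff graph. A toy tape-style illustration of that pattern for `mean` (names invented; this is not the `execute_ops!` implementation):

```rust
// Toy reverse-mode recording: compute the forward value eagerly and stash
// a closure mapping the output gradient back to input gradients.
// For mean over n elements, d(mean)/d(x_i) = 1/n.
struct Recorded {
    output: f32,
    backward: Box<dyn Fn(f32) -> Vec<f32>>,
}

fn mean_op(input: Vec<f32>) -> Recorded {
    let n = input.len();
    let output = input.iter().sum::<f32>() / n as f32;
    Recorded {
        output,
        backward: Box::new(move |grad_out| vec![grad_out / n as f32; n]),
    }
}

fn main() {
    let rec = mean_op(vec![1.0, 2.0, 3.0]);
    assert_eq!(rec.output, 2.0);
    assert_eq!((rec.backward)(1.0), vec![1.0 / 3.0; 3]);
}
```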

 #[cfg(test)]
 mod tests {
 …
42 changes: 16 additions & 26 deletions burn-tensor/src/tensor/backend/autodiff/ops/arg.rs
@@ -1,31 +1,21 @@
+use crate::backend::autodiff::ADBackendDecorator;
 use crate::backend::Backend;
 use crate::tensor::ops::*;
 
-macro_rules! define_impl {
-    (
-        $backend:ty,
-        $backend_inner:ty,
-        $element:ident
-    ) => {
-        impl<E: $element, const D: usize> TensorOpsArg<$backend, D>
-            for <$backend as Backend>::TensorPrimitive<D>
-        {
-            fn argmax(
-                &self,
-                dim: usize,
-            ) -> <<$backend as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
-                TensorOpsArg::argmax(&self.tensor(), dim)
-            }
+impl<B: Backend, const D: usize> TensorOpsArg<ADBackendDecorator<B>, D>
+    for <ADBackendDecorator<B> as Backend>::TensorPrimitive<D>
+{
+    fn argmax(
+        &self,
+        dim: usize,
+    ) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
+        TensorOpsArg::argmax(&self.tensor(), dim)
+    }
 
-            fn argmin(
-                &self,
-                dim: usize,
-            ) -> <<$backend as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
-                TensorOpsArg::argmin(&self.tensor(), dim)
-            }
-        }
-    };
-}
+    fn argmin(
+        &self,
+        dim: usize,
+    ) -> <<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
+        TensorOpsArg::argmin(&self.tensor(), dim)
+    }
+}
 
-crate::register_tch!();
-crate::register_ndarray!();
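`argmax` and `argmin` return index tensors from the `IntegerBackend`, and integer indices carry no gradient, so the decorator simply unwraps the tracked tensor with `.tensor()` and forwards the call without recording anything. A toy sketch of that forwarding (types invented for illustration):

```rust
// Toy model: non-differentiable ops unwrap to the inner value and forward,
// recording nothing on the graph. Not the burn-tensor types.
struct InnerTensor(Vec<f32>);

struct ADTensor {
    inner: InnerTensor, // the real type also carries graph bookkeeping
}

impl ADTensor {
    fn tensor(&self) -> &InnerTensor {
        &self.inner
    }

    // Forwarded op: returns a plain index, so no backward step is needed.
    fn argmax(&self) -> usize {
        let values = &self.tensor().0;
        let mut best = 0;
        for (i, v) in values.iter().enumerate() {
            if *v > values[best] {
                best = i;
            }
        }
        best
    }
}

fn main() {
    let t = ADTensor {
        inner: InnerTensor(vec![1.0, 4.0, 2.0]),
    };
    assert_eq!(t.argmax(), 1);
}
```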
