refactor/div-ops (tracel-ai#68)
nathanielsimard authored Nov 5, 2022
1 parent ee61e84 commit 10d1c13
Showing 15 changed files with 211 additions and 236 deletions.
121 changes: 0 additions & 121 deletions burn-tensor/src/tensor/backend/autodiff/ops/div.rs

This file was deleted.

2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/erf.rs
@@ -15,7 +15,7 @@ register_ops!(
let exponent = value.powf(2.0.to_elem()).neg();
let numerator = B::mul_scalar(&exponent.exp(), &2.0.to_elem());
let denominator = std::f64::consts::PI.sqrt().to_elem();
let value = numerator.div_scalar(&denominator);
let value = B::div_scalar(&numerator, &denominator);

B::mul(&state.output.grad(), &value)
},
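For reference, the erf hunk above only swaps the method call `numerator.div_scalar(&denominator)` for the backend function `B::div_scalar(&numerator, &denominator)`; the quantity it builds is unchanged and is the standard derivative of the error function,

\[
\frac{d}{dx}\,\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}}\, e^{-x^{2}},
\]

which the backward pass then multiplies by the incoming output gradient.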
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/log.rs
@@ -11,7 +11,7 @@ register_ops!(
name ADTensorLogOps,
partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
let value = state.input.value();
let value = value.ones().div(&value);
let value = B::div(&value.ones(), &value);
B::mul(&state.output.grad(), &value)
},
);
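Likewise for the log hunk above: only the call style changes (tensor method `div` → backend function `B::div`); the gradient is still

\[
\frac{d}{dx}\,\ln x = \frac{1}{x},
\]

computed as a tensor of ones with the input's shape divided element-wise by the input.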
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -4,7 +4,6 @@ mod base;
mod cat;
mod creation;
mod detach;
mod div;
mod erf;
mod exp;
mod index;
80 changes: 79 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
@@ -5,7 +5,7 @@ use crate::{
Backend,
},
graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
ops::{TensorOps, TensorOpsNeg},
ops::{Ones, TensorOps, TensorOpsNeg},
Data, Shape,
};

@@ -298,4 +298,82 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

unary_ops_wrapper(lhs.node.clone(), output, ops)
}

fn div<const D: usize>(
lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
rhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(Default, Debug)]
struct DivBackward<B: Backend, const D: usize> {
_b: B,
}

impl<B: Backend, const D: usize>
BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for DivBackward<B, D>
{
fn partial_left(
&self,
state: &BinaryOpsNodeState<
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
>,
) -> B::TensorPrimitive<D> {
let value = state.right.value();
let value = B::div(&value.ones(), &value);

B::mul(&state.output.grad(), &value)
}

fn partial_right(
&self,
state: &BinaryOpsNodeState<
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
>,
) -> B::TensorPrimitive<D> {
let value_left = state.left.value();
let value_right = state.right.value();
let value = B::div(&value_left.neg(), &B::mul(&value_right, &value_right));

B::mul(&state.output.grad(), &value)
}
}

let output = B::div(lhs.tensor_ref(), rhs.tensor_ref());
let ops = DivBackward::<B, D>::default();

binary_ops_wrapper(lhs.node.clone(), rhs.node.clone(), output, ops)
}

fn div_scalar<const D: usize>(
lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
rhs: &<ADBackendDecorator<B> as Backend>::Elem,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct DivScalarBackward<B: Backend, const D: usize> {
elem: B::Elem,
}

impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for DivScalarBackward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let value = state.input.value();
let tmp = B::div_scalar(&value.ones(), &self.elem);

B::mul(&state.output.grad(), &tmp)
}
}

let output = B::div_scalar(lhs.tensor_ref(), rhs);
let ops = DivScalarBackward::<B, D>::new(*rhs);

unary_ops_wrapper(lhs.node.clone(), output, ops)
}
}
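As a reference for the two backward implementations added above, `partial_left`, `partial_right`, and `DivScalarBackward::partial` encode the usual quotient-rule partials, each multiplied element-wise by the incoming output gradient:

\[
\frac{\partial}{\partial a}\!\left(\frac{a}{b}\right) = \frac{1}{b},
\qquad
\frac{\partial}{\partial b}\!\left(\frac{a}{b}\right) = -\frac{a}{b^{2}},
\qquad
\frac{\partial}{\partial a}\!\left(\frac{a}{c}\right) = \frac{1}{c} \quad (c \text{ a scalar}).
\]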
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/base.rs
@@ -24,7 +24,6 @@ pub trait Backend:
type TensorPrimitive<const D: usize>: TensorOpsMatmul<Self::Elem, D>
+ std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
+ TensorOpsTranspose<Self::Elem, D>
+ TensorOpsDiv<Self::Elem, D>
+ TensorOpsNeg<Self::Elem, D>
+ TensorOpsDetach<Self::Elem, D>
+ Zeros<Self::TensorPrimitive<D>>
44 changes: 0 additions & 44 deletions burn-tensor/src/tensor/backend/ndarray/ops/div.rs

This file was deleted.

1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -3,7 +3,6 @@ mod arg;
mod cat;
mod creation;
mod detach;
mod div;
mod erf;
mod exp;
mod index;
37 changes: 31 additions & 6 deletions burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
@@ -67,7 +67,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() + rhs.array.clone();
let array = &lhs.array + &rhs.array;
let array = array.into_shared();
let shape = lhs.shape.higher(&rhs.shape);

@@ -78,7 +78,8 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &E,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() + *rhs;
let array = &lhs.array + *rhs;
let array = array.into_shared();
let shape = lhs.shape;

NdArrayTensor { array, shape }
@@ -88,7 +89,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() - rhs.array.clone();
let array = &lhs.array - &rhs.array;
let array = array.into_shared();
let shape = lhs.shape.higher(&rhs.shape);

@@ -99,7 +100,8 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &E,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() - *rhs;
let array = &lhs.array - *rhs;
let array = array.into_shared();
let shape = lhs.shape;

NdArrayTensor { array, shape }
@@ -109,7 +111,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() * rhs.array.clone();
let array = &lhs.array * &rhs.array;
let array = array.into_shared();
let shape = lhs.shape.higher(&rhs.shape);

@@ -120,7 +122,30 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &E,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() * *rhs;
let array = &lhs.array * *rhs;
let array = array.into_shared();
let shape = lhs.shape;

NdArrayTensor { array, shape }
}

fn div<const D: usize>(
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = &lhs.array / &rhs.array;
let array = array.into_shared();
let shape = lhs.shape.higher(&rhs.shape);

NdArrayTensor { array, shape }
}

fn div_scalar<const D: usize>(
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &E,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = &lhs.array / *rhs;
let array = array.into_shared();
let shape = lhs.shape;

NdArrayTensor { array, shape }
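The ndarray hunks above also drop the `lhs.array.clone()` / `rhs.array.clone()` pattern in favor of arithmetic on borrowed arrays. A minimal standalone sketch of that design choice, assuming only the `ndarray` crate (not code from this repository):

    use ndarray::array;

    fn main() {
        let a = array![[2.0, 4.0], [6.0, 8.0]];
        let b = array![[1.0, 2.0], [3.0, 4.0]];

        // Element-wise division on references allocates a new array,
        // but neither `a` nor `b` is moved or cloned.
        let c = &a / &b;
        assert_eq!(c, array![[2.0, 2.0], [2.0, 2.0]]);

        // Scalar division on a reference behaves the same way.
        let halved = &a / 2.0;
        assert_eq!(halved, array![[1.0, 2.0], [3.0, 4.0]]);
    }

The `.into_shared()` calls in the diff then convert the freshly allocated result into ndarray's shared `ArcArray` storage held by `NdArrayTensor`.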