refactor: pow ops (tracel-ai#98)
nathanielsimard authored Nov 12, 2022
1 parent 8c050c2 commit ef01a4e
Showing 15 changed files with 71 additions and 108 deletions.
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/erf.rs
@@ -12,7 +12,7 @@ register_ops!(
     name ADTensorErfOps,
     partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
         let value = state.input.value();
-        let exponent = B::neg(&value.powf(2.0.to_elem()));
+        let exponent = B::neg(&B::powf(&value, 2.0));
         let numerator = B::mul_scalar(&B::exp(&exponent), &2.0.to_elem());
         let denominator = std::f64::consts::PI.sqrt().to_elem();
         let value = B::div_scalar(&numerator, &denominator);
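Side note on this hunk: only the call style changes (method-style value.powf(...) from the removed TensorOpsPow trait becomes the backend function B::powf(&value, 2.0)); the math is untouched. For reference, the partial above encodes the derivative of the error function:

    \frac{d}{dx}\,\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}}\, e^{-x^{2}}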
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -3,7 +3,6 @@ mod cat;
 mod creation;
 mod erf;
 mod module;
-mod pow;
 mod tensor;
 
 mod macros;
61 changes: 0 additions & 61 deletions burn-tensor/src/tensor/backend/autodiff/ops/pow.rs

This file was deleted.

31 changes: 31 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
@@ -983,4 +983,35 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 
         unary_ops_wrapper(tensor.node.clone(), output, ops)
     }
+
+    fn powf<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+        value: f32,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        #[derive(new, Debug)]
+        struct Backward<B: Backend, const D: usize> {
+            value: f32,
+            _b: B,
+        }
+
+        impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
+            for Backward<B, D>
+        {
+            fn partial(
+                &self,
+                state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
+            ) -> B::TensorPrimitive<D> {
+                let value = B::mul_scalar(
+                    &B::powf(&state.input.value(), self.value - 1.0),
+                    &self.value.clone().to_elem(),
+                );
+                B::mul(&state.output.grad(), &value)
+            }
+        }
+
+        let output = B::powf(tensor.tensor_ref(), value);
+        let ops = Backward::<B, D>::new(value, B::default());
+
+        unary_ops_wrapper(tensor.node.clone(), output, ops)
+    }
 }
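For context on the new autodiff implementation: Backward::partial is the scalar power rule. With y = x^a, the chain rule gives

    \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot a\, x^{a-1}

which is exactly B::mul(&state.output.grad(), &value), where value = a * x^(a-1) is built from B::powf(x, a - 1.0) and B::mul_scalar.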
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/base.rs
@@ -25,7 +25,6 @@ pub trait Backend:
     + Ones<Self::TensorPrimitive<D>>
     + TensorOpsCat<Self::Elem, D>
     + TensorOpsErf<Self::Elem, D>
-    + TensorOpsPow<Self::Elem, D>
     + ReLU<Self::Elem, D>
     + Clone
     + Send
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -1,4 +1,3 @@
 mod cat;
 mod creation;
 mod erf;
-mod pow;
16 changes: 0 additions & 16 deletions burn-tensor/src/tensor/backend/ndarray/ops/pow.rs

This file was deleted.

7 changes: 7 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
@@ -467,6 +467,13 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
 
         NdArrayTensor { array, shape }
     }
+
+    fn powf<const D: usize>(tensor: &NdArrayTensor<E, D>, value: f32) -> NdArrayTensor<E, D> {
+        let array = tensor.array.mapv(|a| a.pow_elem(value)).into_shared();
+        let shape = tensor.shape;
+
+        NdArrayTensor { array, shape }
+    }
 }
 
 fn to_slice_args<const D1: usize, const D2: usize>(
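The ndarray backend reduces powf to an elementwise map. A minimal standalone sketch of the same pattern, assuming the ndarray crate and with plain f32::powf standing in for burn's pow_elem element method:

    use ndarray::array;

    fn main() {
        // Elementwise power over an ndarray array, mirroring
        // tensor.array.mapv(|a| a.pow_elem(value)) above.
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let b = a.mapv(|x| x.powf(0.4));
        println!("{:?}", b);
    }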
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/tch/ops/mod.rs
@@ -1,4 +1,3 @@
 mod cat;
 mod creation;
 mod erf;
-mod pow;
21 changes: 0 additions & 21 deletions burn-tensor/src/tensor/backend/tch/ops/pow.rs

This file was deleted.

4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/backend/tch/tensor_ops.rs
@@ -374,6 +374,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
     fn log<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
         to_tensor(tensor.tensor.log())
     }
+
+    fn powf<const D: usize>(tensor: &TchTensor<E, D>, value: f32) -> TchTensor<E, D> {
+        to_tensor(tensor.tensor.pow_tensor_scalar(value as f64))
+    }
 }
 
 fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {
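Here pow_tensor_scalar maps to PyTorch's torch.pow(tensor, scalar) overload, and the value as f64 cast is needed because the binding takes the exponent as a double-precision scalar. A minimal sketch against tch directly, assuming the Tensor::of_slice constructor from tch releases of this era:

    // Elementwise t^0.4 on a raw tch tensor.
    let t = tch::Tensor::of_slice(&[1.0f64, 2.0, 3.0]);
    let p = t.pow_tensor_scalar(0.4);
    p.print();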
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/base.rs
@@ -82,7 +82,7 @@ where
     ///
     /// `y = x^a`
     pub fn powf(&self, value: f32) -> Self {
-        Self::new(self.value.powf(value))
+        Self::new(B::powf(&self.value, value))
     }
 
     /// Returns the shape of the current tensor.
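The public Tensor::powf signature is unchanged by this refactor; only the dispatch moves from the removed TensorOpsPow trait to the backend's TensorOps. A hypothetical usage sketch, with TestBackend standing in for any Backend implementation:

    let x = Tensor::<TestBackend, 2>::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]]));
    let y = x.powf(2.0); // y = x^a with a = 2.0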
5 changes: 1 addition & 4 deletions burn-tensor/src/tensor/ops/base.rs
@@ -194,16 +194,13 @@ pub trait TensorOps<B: Backend> {
     ) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
     fn exp<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
     fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
+    fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
 }
 
 pub trait TensorOpsCat<E, const D: usize> {
     fn cat(tensors: Vec<&Self>, dim: usize) -> Self;
 }
 
-pub trait TensorOpsPow<E, const D: usize> {
-    fn powf(&self, value: f32) -> Self;
-}
-
 pub trait TensorOpsErf<E, const D: usize> {
     fn erf(&self) -> Self;
 }
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/grad/mod.rs
@@ -9,6 +9,7 @@ mod mask;
 mod matmul;
 mod mul;
 mod neg;
+mod pow;
 mod reshape;
 mod softmax;
 mod sub;
25 changes: 25 additions & 0 deletions burn-tensor/tests/tensor/grad/pow.rs
@@ -0,0 +1,25 @@
+use crate::tensor::TestADTensor;
+use burn_tensor::Data;
+
+#[test]
+fn should_diff_powf() {
+    let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
+    let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+
+    let tensor_3 = tensor_1.matmul(&tensor_2.powf(0.4));
+    let tensor_4 = tensor_3.matmul(&tensor_2);
+    let grads = tensor_4.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    grad_1
+        .to_data()
+        .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3);
+    grad_2
+        .to_data()
+        .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3);
+}
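Sanity check on the asserted grad_1 values (assuming backward() seeds the output with ones, which these numbers imply): with tensor_4 = T_1 \cdot T_2^{0.4} \cdot T_2, the gradient with respect to T_1 is \mathbf{1} \cdot (T_2^{0.4}\, T_2)^{\mathsf{T}}. Numerically T_2^{0.4}\, T_2 \approx [[31.89, 36.11], [37.06, 41.98]], so every row of grad_1 is the pair of row sums (31.89 + 36.11, 37.06 + 41.98) = (68.0, 79.03), matching [[68.0, 79.0328], [68.0, 79.0328]] to the asserted precision.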
