Skip to content

Commit

Permalink
Feat/erf (tracel-ai#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 24, 2022
1 parent 6625b4c commit a84df25
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 1 deletion.
7 changes: 6 additions & 1 deletion burn-tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ no-default-features = true
[features]
default = ["tch", "ndarray"]
tch = ["dep:tch"]
ndarray = ["dep:ndarray"]
ndarray = ["dep:ndarray", "dep:libm"]
doc = ["dep:tch", "tch/doc-only", "dep:ndarray"]

[dependencies]
Expand All @@ -34,7 +34,12 @@ half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work
# Backends
tch = { version = "0.8", optional = true }
lazy_static = "1.4"
# NdArray
ndarray = { version = "0.15", optional = true }
libm = { version = "0.2", optional = true }

# Autodiff
nanoid = "0.4"
Expand Down
59 changes: 59 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/erf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use crate::tensor::backend::Backend;
use crate::{
execute_ops,
graph::ops::{UnaryOps, UnaryOpsNodeState},
register_ops,
tensor::{backend::autodiff::ADTensor, ops::*},
ElementConversion,
};

// Registers the backward (gradient) op for `erf` on the autodiff backend.
// The derivative is d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2), applied
// element-wise and chained with the gradient flowing from the output.
register_ops!(
    ops UnaryOps,
    name ADTensorErfOps,
    partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
        let value = state.input.value();
        // -x^2: the exponent of the Gaussian term.
        let exponent = value.powf(2.0.to_elem()).neg();
        // 2 * exp(-x^2)
        let numerator = exponent.exp().mul_scalar(&2.0.to_elem());
        // sqrt(pi)
        let denominator = std::f64::consts::PI.sqrt().to_elem();
        let value = numerator.div_scalar(&denominator);
        // Chain rule: scale the incoming output gradient by erf'(x).
        state.output.grad().mul(&value)
    },
);

/// Forward `erf` for the autodiff backend: delegates to the inner backend's
/// implementation and records the operation on the graph so the backward pass
/// runs `ADTensorErfOps`.
impl<B: Backend, const D: usize> TensorOpsErf<B::Elem, D> for ADTensor<D, B> {
    fn erf(&self) -> Self {
        execute_ops!(
            input self.node.clone(),
            out TensorOpsErf::erf(&self.tensor()),
            ops ADTensorErfOps::<B, D>::new(),
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};

    /// Verifies that gradients propagate through `erf` inside a small
    /// matmul graph.
    #[test]
    fn should_diff_erf() {
        let lhs = TestADTensor::from_data(Data::<f64, 2>::from([[0.0, 1.0], [3.0, 4.0]]));
        let rhs = TestADTensor::from_data(Data::<f64, 2>::from([[6.0, 7.0], [9.0, 10.0]]));

        let hidden = lhs.matmul(&rhs.erf());
        let output = hidden.matmul(&rhs);
        let grads = output.backward();

        let lhs_grad = lhs.grad(&grads).unwrap();
        let rhs_grad = rhs.grad(&grads).unwrap();

        // erf saturates to ~1 for the rhs values used here, so the expected
        // gradients reduce to exact constants.
        lhs_grad
            .to_data()
            .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3);
        rhs_grad
            .to_data()
            .assert_approx_eq(&Data::from([[8.0, 8.0], [8.0, 8.0]]), 3);
    }
}
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod creation;
mod detach;
mod device;
mod div;
mod erf;
mod exp;
mod index;
mod log;
Expand Down
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
+ TensorOpsArg<Self, D>
+ TensorOpsCat<Self::Elem, D>
+ TensorOpsLog<Self::Elem, D>
+ TensorOpsErf<Self::Elem, D>
+ TensorOpsPow<Self::Elem, D>
+ TensorOpsMask<Self, D>
+ TensorOpsMapComparison<Self, D>
Expand Down
19 changes: 19 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/erf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::{
tensor::{backend::ndarray::NdArrayTensor, ops::*},
ElementConversion, NdArrayElement,
};

/// Element-wise error function for the ndarray backend.
///
/// Each element is promoted to `f64`, evaluated with `libm::erf`, and
/// converted back to the backend element type; the shape is preserved.
impl<E, const D: usize> TensorOpsErf<E, D> for NdArrayTensor<E, D>
where
    E: NdArrayElement,
{
    fn erf(&self) -> Self {
        Self {
            // NOTE(review): `to_f64().unwrap()` assumes every element is
            // representable as f64 — confirm for all NdArrayElement impls.
            array: self
                .array
                .mapv(|value| libm::erf(value.to_f64().unwrap()).to_elem())
                .into_shared(),
            shape: self.shape,
        }
    }
}
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod creation;
mod detach;
mod device;
mod div;
mod erf;
mod exp;
mod index;
mod log;
Expand Down
21 changes: 21 additions & 0 deletions burn-tensor/src/tensor/backend/tch/ops/erf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use crate::{
tensor::{backend::tch::TchTensor, ops::*},
TchElement,
};

/// Element-wise error function for the tch (LibTorch) backend, delegating
/// directly to `Tensor::erf`; shape and kind are carried over unchanged.
impl<E, const D: usize> TensorOpsErf<E, D> for TchTensor<E, D>
where
    E: TchElement,
{
    fn erf(&self) -> Self {
        Self {
            tensor: self.tensor.erf(),
            shape: self.shape,
            kind: self.kind.clone(),
        }
    }
}
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/tch/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod creation;
mod detach;
mod device;
mod div;
mod erf;
mod exp;
mod index;
mod log;
Expand Down
7 changes: 7 additions & 0 deletions burn-tensor/src/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ where
Self::new(self.value.log())
}

/// Applies the [error function](https://en.wikipedia.org/wiki/Error_function)
/// element-wise.
///
/// `y = erf(x)`
pub fn erf(&self) -> Self {
    let result = self.value.erf();
    Self::new(result)
}

/// Applies element wise power operation.
///
/// `y = x^a`
Expand Down
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ pub trait TensorOpsDetach<E, const D: usize> {
fn detach(self) -> Self;
}

/// Element-wise [error function](https://en.wikipedia.org/wiki/Error_function),
/// `y = erf(x)`, implemented by each backend's tensor primitive.
pub trait TensorOpsErf<E, const D: usize> {
    /// Returns a new tensor with `erf` applied to every element.
    fn erf(&self) -> Self;
}

pub trait Zeros<T> {
fn zeros(&self) -> T;
}
Expand Down
13 changes: 13 additions & 0 deletions burn-tensor/tests/tensor/ops/erf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use super::super::TestBackend;
use burn_tensor::{Data, Tensor};

/// Checks `erf` against known reference values on the active test backend.
#[test]
fn should_support_erf_ops() {
    let input =
        Tensor::<TestBackend, 2>::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));

    let actual = input.erf().into_data();

    // erf(0) = 0, erf(1) ≈ 0.8427, erf(2) ≈ 0.9953; for x >= 3 the function
    // saturates to ~1 at this precision.
    let expected = Data::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]);
    expected.assert_approx_eq(&actual, 3);
}
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod add;
mod aggregation;
mod arg;
mod div;
mod erf;
mod exp;
mod index;
mod map_comparison;
Expand Down

0 comments on commit a84df25

Please sign in to comment.