Feat/variance (tracel-ai#33)
nathanielsimard authored Sep 11, 2022
1 parent 34ad49d commit bffc543
Showing 24 changed files with 234 additions and 25 deletions.
2 changes: 1 addition & 1 deletion burn-tensor/src/lib.rs
@@ -7,6 +7,6 @@ pub(crate) mod graph;
pub use graph::grad::Gradients;

mod tensor;
pub use tensor::*;

pub use half::f16;
pub use tensor::*;
9 changes: 9 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/detach.rs
@@ -0,0 +1,9 @@
use crate::tensor::backend::backend::Backend;
use crate::tensor::{backend::autodiff::ADTensor, ops::*};

impl<B: Backend, P, const D: usize> TensorOpsDetach<P, D> for ADTensor<D, B> {
    fn detach(self) -> Self {
        let tensor = self.tensor();
        Self::from_tensor(tensor.detach())
    }
}
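
The autodiff implementation above rebuilds the ADTensor from its inner value, so the operations that produced it are dropped from the graph. Below is a minimal sketch of the intended effect on gradients, assuming an autodiff backend; the function and its assertions are illustrative and not part of the commit.

use burn_tensor::{backend::ADBackend, Tensor};

fn detach_example<B: ADBackend>(a: Tensor<B, 2>, b: Tensor<B, 2>) {
    // The matmul producing `c` happens before the detach, so it is not recorded.
    let c = a.matmul(&b).detach();
    // Only this matmul is part of the graph traversed by backward().
    let d = c.matmul(&b);
    let grads = d.backward();

    assert!(a.grad(&grads).is_none()); // `a` sits behind the detach point
    assert!(b.grad(&grads).is_some()); // `b` is still used after the detach
}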
2 changes: 2 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -3,6 +3,7 @@ mod aggregation;
mod arg;
mod cat;
mod creation;
mod detach;
mod device;
mod div;
mod exp;
@@ -13,6 +14,7 @@ mod mask;
mod matmul;
mod mul;
mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
61 changes: 61 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/pow.rs
@@ -0,0 +1,61 @@
use crate::tensor::backend::backend::Backend;
use crate::ElementConversion;
use crate::{
    execute_ops,
    graph::ops::{UnaryOps, UnaryOpsNodeState},
    register_ops,
    tensor::{backend::autodiff::ADTensor, ops::*},
};

register_ops!(
    ops UnaryOps,
    name ADTensorPowOps state f32,
    partial |
        value: &f32,
        state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
    | {
        let value = state.input
            .value()
            .powf(value - 1.0)
            .mul_scalar(&value.clone().to_elem());
        state.output.grad().mul(&value)
    },
);

impl<B: Backend, const D: usize> TensorOpsPow<B::Elem, D> for ADTensor<D, B> {
    fn powf(&self, value: f32) -> Self {
        execute_ops!(
            input self.node.clone(),
            out TensorOpsPow::powf(&self.tensor(), value),
            ops ADTensorPowOps::<B, D>::new(value.clone()),
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};

    #[test]
    fn should_diff_powf() {
        let data_1 = Data::<f64, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
        let data_2 = Data::<f64, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

        let tensor_1 = TestADTensor::from_data(data_1.clone());
        let tensor_2 = TestADTensor::from_data(data_2.clone());

        let tensor_3 = tensor_1.matmul(&tensor_2.powf(0.4));
        let tensor_4 = tensor_3.matmul(&tensor_2);
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3);
    }
}
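
The `partial` closure above applies the scalar power rule, d/dx x^a = a * x^(a - 1), to the upstream gradient. The following standalone numerical sanity check of that rule is plain Rust with no burn dependency and is not part of the commit.

fn main() {
    let (a, x, eps) = (0.4_f64, 3.0_f64, 1e-6_f64);

    // Analytic derivative from the power rule, as used in ADTensorPowOps.
    let analytic = a * x.powf(a - 1.0);
    // Central finite-difference approximation of d/dx x^a at the same point.
    let numeric = ((x + eps).powf(a) - (x - eps).powf(a)) / (2.0 * eps);

    assert!((analytic - numeric).abs() < 1e-6);
}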
2 changes: 2 additions & 0 deletions burn-tensor/src/tensor/backend/backend.rs
@@ -19,6 +19,7 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
        + TensorOpsNeg<Self::Elem, D>
        + TensorOpsAdd<Self::Elem, D>
        + TensorOpsSub<Self::Elem, D>
        + TensorOpsDetach<Self::Elem, D>
        + Zeros<Self::TensorPrimitive<D>>
        + Ones<Self::TensorPrimitive<D>>
        + TensorOpsReshape<Self, D>
@@ -30,6 +31,7 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
        + TensorOpsArg<Self, D>
        + TensorOpsCat<Self::Elem, D>
        + TensorOpsLog<Self::Elem, D>
        + TensorOpsPow<Self::Elem, D>
        + TensorOpsMask<Self, D>
        + TensorOpsMapComparison<Self, D>
        + ReLU<Self::Elem, D>
13 changes: 13 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/detach.rs
@@ -0,0 +1,13 @@
use crate::{
    tensor::{backend::ndarray::NdArrayTensor, ops::*},
    NdArrayElement,
};

impl<E, const D: usize> TensorOpsDetach<E, D> for NdArrayTensor<E, D>
where
    E: NdArrayElement,
{
    fn detach(self) -> Self {
        self
    }
}
2 changes: 2 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -3,6 +3,7 @@ mod aggregation;
mod arg;
mod cat;
mod creation;
mod detach;
mod device;
mod div;
mod exp;
@@ -13,6 +14,7 @@ mod mask;
mod matmul;
mod mul;
mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
16 changes: 16 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/pow.rs
@@ -0,0 +1,16 @@
use crate::{
    tensor::{backend::ndarray::NdArrayTensor, ops::*},
    NdArrayElement,
};

impl<E, const D: usize> TensorOpsPow<E, D> for NdArrayTensor<E, D>
where
    E: NdArrayElement,
{
    fn powf(&self, value: f32) -> Self {
        let array = self.array.mapv(|a| a.pow_elem(value)).into_shared();
        let shape = self.shape.clone();

        Self { array, shape }
    }
}
13 changes: 13 additions & 0 deletions burn-tensor/src/tensor/backend/tch/ops/detach.rs
@@ -0,0 +1,13 @@
use crate::{
    tensor::{backend::tch::TchTensor, ops::*},
    TchElement,
};

impl<E, const D: usize> TensorOpsDetach<E, D> for TchTensor<E, D>
where
    E: TchElement,
{
    fn detach(self) -> Self {
        self
    }
}
2 changes: 2 additions & 0 deletions burn-tensor/src/tensor/backend/tch/ops/mod.rs
@@ -3,6 +3,7 @@ mod aggregation;
mod arg;
mod cat;
mod creation;
mod detach;
mod device;
mod div;
mod exp;
@@ -13,6 +14,7 @@ mod mask;
mod matmul;
mod mul;
mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
21 changes: 21 additions & 0 deletions burn-tensor/src/tensor/backend/tch/ops/pow.rs
@@ -0,0 +1,21 @@
use crate::{
    tensor::{backend::tch::TchTensor, ops::*},
    TchElement,
};

impl<E, const D: usize> TensorOpsPow<E, D> for TchTensor<E, D>
where
    E: TchElement,
{
    fn powf(&self, value: f32) -> Self {
        let tensor = self.tensor.pow_tensor_scalar(value as f64);
        let kind = self.kind.clone();
        let shape = self.shape.clone();

        Self {
            tensor,
            shape,
            kind,
        }
    }
}
8 changes: 8 additions & 0 deletions burn-tensor/src/tensor/element.rs
@@ -28,6 +28,7 @@ pub(crate) trait TchElement: Element + tch::kind::Element {}
pub(crate) trait ExpElement {
    fn exp_elem(self) -> Self;
    fn log_elem(self) -> Self;
    fn pow_elem(self, value: f32) -> Self;
}

pub trait ElementConversion {
@@ -233,6 +234,9 @@ mod ndarray_elem {
            fn log_elem(self) -> Self {
                $elem::ln(self)
            }
            fn pow_elem(self, value: f32) -> Self {
                $elem::powf(self, value.into())
            }
        }
    };
    ($elem:ident, $tmp:ident) => {
@@ -245,6 +249,10 @@
                let tmp = $tmp::ln(self as $tmp);
                tmp as $elem
            }
            fn pow_elem(self, value: f32) -> Self {
                let tmp = $tmp::powf(self as $tmp, value as $tmp);
                tmp as $elem
            }
        }
    };
}
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/mod.rs
@@ -1,4 +1,5 @@
pub(crate) mod ops;
pub(crate) mod stats;

mod bool_tensor;
mod data;
6 changes: 5 additions & 1 deletion burn-tensor/src/tensor/ops/base.rs
@@ -101,13 +101,17 @@ pub trait TensorOpsCat<E, const D: usize> {
}

pub trait TensorOpsPow<E, const D: usize> {
    fn pow(&self, value: &E) -> Self;
    fn powf(&self, value: f32) -> Self;
}

pub trait TensorOpsLog<E, const D: usize> {
    fn log(&self) -> Self;
}

pub trait TensorOpsDetach<E, const D: usize> {
    fn detach(self) -> Self;
}

pub trait Zeros<T> {
    fn zeros(&self) -> T;
}
17 changes: 17 additions & 0 deletions burn-tensor/src/tensor/stats/mod.rs
@@ -0,0 +1,17 @@
use crate::{backend::Backend, ElementConversion, Tensor};

pub fn var<B: Backend, const D: usize>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    let mean = tensor.mean_dim(dim);
    var_with_mean(tensor, &mean, dim)
}

pub fn var_with_mean<B: Backend, const D: usize>(
    tensor: &Tensor<B, D>,
    mean: &Tensor<B, D>,
    dim: usize,
) -> Tensor<B, D> {
    let n = tensor.shape().dims[dim] - 1;
    let n = (n as f32).to_elem();

    tensor.sub(&mean).powf(2.0).sum_dim(dim).div_scalar(&n)
}
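
`var_with_mean` divides by n - 1, i.e. it computes the sample variance with Bessel's correction. A standalone restatement of the same formula on a plain slice, not part of the commit:

// sum((x - mean)^2) / (n - 1): the quantity var_with_mean computes per slice along `dim`.
fn sample_variance(values: &[f64]) -> f64 {
    let n = values.len() as f64;
    let mean = values.iter().sum::<f64>() / n;
    values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0)
}

fn main() {
    // For [1.0, 2.0, 3.0]: mean = 2.0, squared deviations sum to 2.0, so the variance is 1.0.
    assert_eq!(sample_variance(&[1.0, 2.0, 3.0]), 1.0);
}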
32 changes: 28 additions & 4 deletions burn-tensor/src/tensor/tensor.rs
@@ -3,6 +3,7 @@ use crate::tensor::backend::ADBackend;
use crate::tensor::backend::Backend;
use crate::tensor::ops::activation::*;
use crate::tensor::ops::*;
use crate::tensor::stats;
use crate::tensor::{Data, Distribution, Shape};
use crate::BoolTensor;
use crate::Element;
@@ -55,6 +56,13 @@ where
        Self::new(self.value.log())
    }

    /// Apply element wise power operation.
    ///
    /// `y = x^a`
    pub fn powf(&self, value: f32) -> Self {
        Self::new(self.value.powf(value))
    }

    /// Returns the shape of the current tensor.
    pub fn shape(&self) -> &Shape<D> {
        self.value.shape()
@@ -235,6 +243,18 @@ where
        Self::new(self.value.sum_dim(dim))
    }

    /// Calculate the variance along the given dimension.
    pub fn var(&self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension and also returns the mean.
    pub fn var_mean(&self, dim: usize) -> (Self, Self) {
        let mean = self.mean_dim(dim);
        let var = stats::var_with_mean(self, &mean, dim);
        (var, mean)
    }

    /// Apply element wise equal comparison and returns a boolean tensor.
    ///
    /// # Panics
@@ -445,6 +465,14 @@ where
        Self::new(value)
    }

    /// Detach the current tensor from the autodiff graph.
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(self.value.detach())
    }

    /// Unsqueeze the current tensor. Create new dimensions to fit the given size.
    ///
    /// # Panics
@@ -601,8 +629,4 @@ impl<const D: usize, B: ADBackend> Tensor<B, D> {
    pub fn from_inner(inner: Tensor<B::InnerBackend, D>) -> Self {
        Self::new(B::from_inner(inner.value))
    }

    pub fn detach(&self) -> Self {
        Self::from_inner(self.inner())
    }
}
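
A minimal usage sketch of the new variance methods on `Tensor`; the function itself is hypothetical and not part of the commit, only `var` and `var_mean` come from this change:

use burn_tensor::{backend::Backend, Tensor};

fn row_stats<B: Backend>(x: &Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
    // Returns (variance, mean), both reduced along dimension 1 while keeping rank 2,
    // and reusing a single mean computation internally.
    x.var_mean(1)
}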
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/mod.rs
@@ -13,3 +13,4 @@ pub type TestADBackend = burn_tensor::backend::TchADBackend<f32>;
mod activation;
mod grad;
mod ops;
mod stats;
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/ops/mod.rs
@@ -9,5 +9,6 @@ mod mask;
mod matmul;
mod mul;
mod neg;
mod powf;
mod reshape;
mod sub;
13 changes: 13 additions & 0 deletions burn-tensor/tests/tensor/ops/powf.rs
@@ -0,0 +1,13 @@
use super::super::TestBackend;
use burn_tensor::{Data, Tensor};

#[test]
fn should_support_powf_ops() {
    let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    let tensor = Tensor::<TestBackend, 2>::from_data(data);

    let data_actual = tensor.powf(0.71).into_data();

    let data_expected = Data::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]);
    data_expected.assert_approx_eq(&data_actual, 3);
}
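
A quick standalone recomputation of one expected entry above, 2.0^0.71 ≈ 1.6358; plain Rust, not part of the commit:

fn main() {
    // x^0.71 = exp(0.71 * ln(x)); this checks the third entry of the first row.
    assert!((2.0_f32.powf(0.71) - 1.6358).abs() < 1e-3);
}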
13 changes: 13 additions & 0 deletions burn-tensor/tests/tensor/stats/basic.rs
@@ -0,0 +1,13 @@
use super::super::TestBackend;
use burn_tensor::{Data, Tensor};

#[test]
fn test_var() {
    let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
    let tensor = Tensor::<TestBackend, 2>::from_data(data);

    let data_actual = tensor.var(1).into_data();

    let data_expected = Data::from([[2.4892], [15.3333]]);
    data_expected.assert_approx_eq(&data_actual, 3);
}
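
The expected values above follow from the n - 1 denominator used in stats::var_with_mean; a hand recomputation, not part of the commit:

fn main() {
    // Row [0.5, 1.8, 0.2, -2.0]: mean = 0.125, squared deviations sum to 7.4675.
    assert!((7.4675_f64 / 3.0 - 2.4892).abs() < 1e-3);
    // Row [3.0, -4.0, 5.0, 0.0]: mean = 1.0, squared deviations sum to 46.0.
    assert!((46.0_f64 / 3.0 - 15.3333).abs() < 1e-3);
}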
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/stats/mod.rs
@@ -0,0 +1 @@
mod basic;