Feat/layer norm (tracel-ai#35)
nathanielsimard authored Sep 16, 2022
1 parent 8c21cf1 commit 48e0fbd
Showing 10 changed files with 198 additions and 46 deletions.
burn-tensor/src/lib.rs (2 changes: 0 additions & 2 deletions)

@@ -1,5 +1,3 @@
-#![feature(generic_associated_types)]
-
 #[macro_use]
 extern crate derive_new;
 
burn-tensor/src/tensor/backend/tch/ops/add.rs (23 changes: 9 additions & 14 deletions)

@@ -1,9 +1,10 @@
-use crate::tensor::{backend::tch::TchTensor, ops::*, Data};
+use crate::{
+    tensor::{backend::tch::TchTensor, ops::*},
+    TchElement,
+};
 use std::ops::Add;
 
-impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorOpsAdd<P, D>
-    for TchTensor<P, D>
-{
+impl<P: TchElement, const D: usize> TensorOpsAdd<P, D> for TchTensor<P, D> {
     fn add(&self, other: &Self) -> Self {
         let tensor = (&self.tensor).add(&other.tensor);
         let kind = self.kind.clone();
@@ -16,10 +17,8 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
         }
     }
 
     fn add_scalar(&self, other: &P) -> Self {
-        let elems: [P; D] = [*other; D];
-        let data = Data::from(elems);
-        let other = TchTensor::from_data(data, self.tensor.device());
-        let tensor = (&self.tensor).add(&other.tensor);
+        let other: f64 = (other.clone()).to_elem();
+        let tensor = (&self.tensor).add(other);
         let kind = self.kind.clone();
         let shape = self.shape;
@@ -31,19 +30,15 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
         }
     }
 }
 
-impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Add<Self>
-    for TchTensor<P, D>
-{
+impl<P: TchElement, const D: usize> std::ops::Add<Self> for TchTensor<P, D> {
     type Output = Self;
 
     fn add(self, rhs: Self) -> Self::Output {
         TensorOpsAdd::add(&self, &rhs)
     }
 }
 
-impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Add<P>
-    for TchTensor<P, D>
-{
+impl<P: TchElement, const D: usize> std::ops::Add<P> for TchTensor<P, D> {
     type Output = Self;
 
     fn add(self, rhs: P) -> Self::Output {
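
A note on the add_scalar rewrite above: the old path materialized the scalar as a temporary tensor (building a [P; D] array, wrapping it in Data, and uploading it with TchTensor::from_data) before doing a tensor-tensor add; the new path converts the element to f64 and lets libtorch broadcast the scalar directly. A minimal sketch of the public API this backs; the module paths and backend alias are assumptions, not code from this commit:

use burn_tensor::tensor::{backend::tch::TchBackend, Shape, Tensor};

fn main() {
    type B = TchBackend<f32>;
    // A 2x3 tensor filled with ones.
    let x = Tensor::<B, 2>::ones(Shape::new([2, 3]));
    // The scalar is broadcast over every element by libtorch;
    // no temporary scalar tensor is allocated anymore.
    let y = x.add_scalar(&2.0);
    // Every element of `y` is now 3.0.
}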
burn-tensor/src/tensor/backend/tch/ops/div.rs (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@ use std::ops::Div;
 
 impl<P: Element + tch::kind::Element, const D: usize> TensorOpsDiv<P, D> for TchTensor<P, D> {
     fn div(&self, other: &Self) -> Self {
-        let tensor = (&self.tensor) / &other.tensor;
+        let tensor = (&self.tensor).div(&other.tensor);
         let shape = self.shape.higher(&other.shape);
         let kind = self.kind.clone();
 
burn-tensor/src/tensor/base.rs (62 changes: 37 additions & 25 deletions)

@@ -42,21 +42,21 @@
         self.value.device()
     }
 
-    /// Apply element wise exponential operation.
+    /// Applies element wise exponential operation.
     ///
     /// `y = e^x`
     pub fn exp(&self) -> Self {
         Self::new(self.value.exp())
     }
 
-    /// Apply element wise natural log operation *ln*.
+    /// Applies element wise natural log operation *ln*.
     ///
     /// `y = log(x)`
     pub fn log(&self) -> Self {
         Self::new(self.value.log())
     }
 
-    /// Apply element wise power operation.
+    /// Applies element wise power operation.
     ///
     /// `y = x^a`
     pub fn powf(&self, value: f32) -> Self {
@@ -132,35 +132,35 @@
         tensor.index_assign(ranges, &Tensor::ones(Shape::new([1; D])))
     }
 
-    /// Apply element wise addition operation.
+    /// Applies element wise addition operation.
     ///
     /// `y = x2 + x1`
     pub fn add(&self, other: &Self) -> Self {
         Self::new(self.value.add(&other.value))
     }
 
-    /// Apply element wise addition operation with a scalar.
+    /// Applies element wise addition operation with a scalar.
     ///
     /// `y = x + s`
     pub fn add_scalar(&self, other: &B::Elem) -> Self {
         Self::new(self.value.add_scalar(other))
     }
 
-    /// Apply element wise substraction operation.
+    /// Applies element wise substraction operation.
     ///
     /// `y = x2 - x1`
     pub fn sub(&self, other: &Self) -> Self {
         Self::new(self.value.sub(&other.value))
     }
 
-    /// Apply element wise substraction operation with a scalar.
+    /// Applies element wise substraction operation with a scalar.
     ///
     /// `y = x - s`
     pub fn sub_scalar(&self, other: &B::Elem) -> Self {
         Self::new(self.value.sub_scalar(other))
     }
 
-    /// Apply the transpose operation.
+    /// Applies the transpose operation.
     ///
     /// On matrix and higher dimension tensor, it swap the last two dimensions.
     ///
@@ -171,7 +171,7 @@ where
         Self::new(self.value.transpose())
     }
 
-    /// Apply the matrix multiplication operation.
+    /// Applies the matrix multiplication operation.
     ///
     /// `C = AB`
    ///
@@ -189,28 +189,28 @@
         Self::new(self.value.neg())
     }
 
-    /// Apply element wise multiplication operation.
+    /// Applies element wise multiplication operation.
     ///
     /// `y = x2 * x1`
     pub fn mul(&self, other: &Self) -> Self {
         Self::new(self.value.mul(&other.value))
     }
 
-    /// Apply element wise multiplication operation with scalar.
+    /// Applies element wise multiplication operation with scalar.
     ///
     /// `y = x2 * x1`
     pub fn mul_scalar(&self, other: &B::Elem) -> Self {
         Self::new(self.value.mul_scalar(other))
     }
 
-    /// Apply element wise division operation.
+    /// Applies element wise division operation.
     ///
     /// `y = x2 / x1`
     pub fn div(&self, other: &Self) -> Self {
         Self::new(self.value.div(&other.value))
     }
 
-    /// Apply element wise division operation with scalar.
+    /// Applies element wise division operation with scalar.
     ///
     /// `y = x2 / x1`
     pub fn div_scalar(&self, other: &B::Elem) -> Self {
@@ -242,14 +242,26 @@ where
         stats::var(self, dim)
     }
 
+    /// Calculate the variance along the given dimension without applying the Bessel’s correction.
+    pub fn var_bias(&self, dim: usize) -> Self {
+        stats::var_bias(self, dim)
+    }
+
     /// Calculate the variance along the given dimension and also returns the mean.
     pub fn var_mean(&self, dim: usize) -> (Self, Self) {
         let mean = self.mean_dim(dim);
         let var = stats::var_with_mean(self, &mean, dim);
         (var, mean)
     }
 
-    /// Apply element wise equal comparison and returns a boolean tensor.
+    /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean.
+    pub fn var_mean_bias(&self, dim: usize) -> (Self, Self) {
+        let mean = self.mean_dim(dim);
+        let var = stats::var_with_mean_bias(self, &mean, dim);
+        (var, mean)
+    }
+
+    /// Applies element wise equal comparison and returns a boolean tensor.
     ///
     /// # Panics
     ///
@@ -258,7 +270,7 @@
         BoolTensor::new(self.value.equal(&other.value))
     }
 
-    /// Apply element wise greater comparison and returns a boolean tensor.
+    /// Applies element wise greater comparison and returns a boolean tensor.
     ///
     /// # Panics
     ///
@@ -267,7 +279,7 @@
         BoolTensor::new(self.value.greater(&other.value))
     }
 
-    /// Apply element wise greater-equal comparison and returns a boolean tensor.
+    /// Applies element wise greater-equal comparison and returns a boolean tensor.
     ///
     /// # Panics
     ///
@@ -276,7 +288,7 @@
         BoolTensor::new(self.value.greater_equal(&other.value))
     }
 
-    /// Apply element wise lower comparison and returns a boolean tensor.
+    /// Applies element wise lower comparison and returns a boolean tensor.
     ///
     /// # Panics
    ///
@@ -285,7 +297,7 @@
         BoolTensor::new(self.value.lower(&other.value))
     }
 
-    /// Apply element wise lower-equal comparison and returns a boolean tensor.
+    /// Applies element wise lower-equal comparison and returns a boolean tensor.
     ///
     /// # Panics
     ///
@@ -294,27 +306,27 @@
         BoolTensor::new(self.value.lower_equal(&other.value))
     }
 
-    /// Apply element wise equal comparison and returns a boolean tensor.
+    /// Applies element wise equal comparison and returns a boolean tensor.
     pub fn equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.equal_scalar(other))
     }
 
-    /// Apply element wise greater comparison and returns a boolean tensor.
+    /// Applies element wise greater comparison and returns a boolean tensor.
     pub fn greater_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.greater_scalar(other))
    }
 
-    /// Apply element wise greater-equal comparison and returns a boolean tensor.
+    /// Applies element wise greater-equal comparison and returns a boolean tensor.
     pub fn greater_equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.greater_equal_scalar(other))
     }
 
-    /// Apply element wise lower comparison and returns a boolean tensor.
+    /// Applies element wise lower comparison and returns a boolean tensor.
     pub fn lower_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.lower_scalar(other))
     }
 
-    /// Apply element wise lower-equal comparison and returns a boolean tensor.
+    /// Applies element wise lower-equal comparison and returns a boolean tensor.
     pub fn lower_equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.lower_equal_scalar(other))
     }
@@ -407,7 +419,7 @@
         Tensor::new(value)
     }
 
-    /// Apply the argmax function along the given dimension and returns an integer tensor.
+    /// Applies the argmax function along the given dimension and returns an integer tensor.
     ///
     /// # Example
     ///
@@ -426,7 +438,7 @@
         Tensor::new(self.value.argmax(dim))
     }
 
-    /// Apply the argmin function along the given dimension and returns an integer tensor.
+    /// Applies the argmin function along the given dimension and returns an integer tensor.
     ///
     /// # Example
    ///
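
The var_bias and var_mean_bias methods added above compute the population (biased) variance, dividing by N rather than N - 1, which is the statistic layer normalization conventionally uses. The LayerNorm module itself lives in one of the files not rendered on this page, so the forward pass below is only an illustrative sketch; every name in it is hypothetical:

use burn_tensor::tensor::{backend::Backend, Tensor};

// Hypothetical layer-norm forward pass (without the learned scale and shift),
// assuming `to_elem` converts an f32 into the backend element type, as in
// stats/mod.rs below.
fn layer_norm<B: Backend, const D: usize>(input: &Tensor<B, D>, epsilon: f32) -> Tensor<B, D> {
    let dim = D - 1; // normalize over the last dimension
    // Biased variance and mean along `dim`.
    let (var, mean) = input.var_mean_bias(dim);
    // std = sqrt(var + eps); powf(0.5) stands in for sqrt because powf is
    // part of the API shown above.
    let std = var.add_scalar(&epsilon.to_elem()).powf(0.5);
    input.sub(&mean).div(&std)
}

A real module would presumably also apply learned scale and shift tensors (gamma and beta) after this normalization.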
burn-tensor/src/tensor/stats/mod.rs (24 changes: 22 additions & 2 deletions)

@@ -10,8 +10,28 @@ pub fn var_with_mean<B: Backend, const D: usize>(
     mean: &Tensor<B, D>,
     dim: usize,
 ) -> Tensor<B, D> {
-    let n = tensor.shape().dims[dim] - 1;
-    let n = (n as f32).to_elem();
+    var_with_mean_n(tensor, mean, dim, tensor.shape().dims[dim] - 1)
+}
+
+pub fn var_bias<B: Backend, const D: usize>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
+    let mean = tensor.mean_dim(dim);
+    var_with_mean_bias(tensor, &mean, dim)
+}
+
+pub fn var_with_mean_bias<B: Backend, const D: usize>(
+    tensor: &Tensor<B, D>,
+    mean: &Tensor<B, D>,
+    dim: usize,
+) -> Tensor<B, D> {
+    var_with_mean_n(tensor, mean, dim, tensor.shape().dims[dim])
+}
+
+pub fn var_with_mean_n<B: Backend, const D: usize>(
+    tensor: &Tensor<B, D>,
+    mean: &Tensor<B, D>,
+    dim: usize,
+    n: usize,
+) -> Tensor<B, D> {
+    let n = (n as f32).to_elem();
     tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(&n)
 }
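
For reference, every variant is now routed through var_with_mean_n, and the two statistics differ only in the divisor, i.e. whether Bessel's correction is applied. In LaTeX, with N the size of the reduced dimension and \bar{x} the mean along it:

\mathrm{var}(x) = \frac{1}{N - 1} \sum_{i=1}^{N} \left( x_i - \bar{x} \right)^2
\qquad
\mathrm{var\_bias}(x) = \frac{1}{N} \sum_{i=1}^{N} \left( x_i - \bar{x} \right)^2

var and var_with_mean pass n = N - 1 (unbiased sample variance); var_bias and var_with_mean_bias pass n = N (biased population variance).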
burn/src/lib.rs (4 changes: 3 additions & 1 deletion)

@@ -11,4 +11,6 @@ pub mod train;
 pub(crate) mod macros;
 
 #[cfg(test)]
-pub type TestBackend = crate::tensor::backend::NdArrayBackend<f32>;
+pub type TestBackend = crate::tensor::backend::TchBackend<f32>;
+#[cfg(test)]
+pub type TestADBackend = crate::tensor::backend::TchADBackend<f32>;
burn/src/macros.rs (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 macro_rules! config {
     ($item:item) => {
-        #[derive(new, serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
+        #[derive(new, serde::Serialize, serde::Deserialize, Clone, Debug)]
         $item
     };
 }
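
A note on the config! change above: the derive list drops PartialEq and Eq. One plausible reason, inferred here rather than stated in the diff, is that configs may now carry floating-point fields, and f64 implements PartialEq but not Eq, so the old derive list would no longer compile. A hypothetical use of the macro:

// Hypothetical config struct; not part of the rendered diff.
config!(
    pub struct LayerNormConfig {
        d_model: usize,
        // Floats are PartialEq but not Eq, so `Eq` cannot be derived here.
        epsilon: f64,
    }
);

Since the macro keeps derive(new) (from the derive_new crate seen in burn-tensor/src/lib.rs), such a struct would get a generated constructor, e.g. LayerNormConfig::new(512, 1e-5).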
(3 more changed files not rendered on this page)
