Fix/named tensor (tracel-ai#115)
nathanielsimard authored Nov 26, 2022
1 parent e0e787f commit acb14ad
Showing 14 changed files with 143 additions and 138 deletions.
182 changes: 73 additions & 109 deletions burn-autodiff/src/tests/multithread.rs
@@ -4,118 +4,82 @@ mod tests {
use burn_tensor::Data;

#[test]
fn should_diff_mean() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);

let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.mean().unsqueeze());
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[3.5, 9.5], [3.5, 9.5]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[-0.75, -0.75], [3.0, 3.0]]), 5);
}
fn should_behave_the_same_with_multithread() {
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);

#[test]
fn should_diff_sum_1() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);

let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.sum().unsqueeze());
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[14.0, 38.0], [14.0, 38.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[-3.0, -3.0], [12.0, 12.0]]), 5);
}
let with_move = || {
let tensor_1 = TestADTensor::from_data(data_1.clone());
let tensor_2 = TestADTensor::from_data(data_2.clone());

#[test]
fn should_diff_sum_2() {
let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]);

let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);

let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_3.sum_dim(1);
let tensor_5 = tensor_4.mul(&tensor_3);

let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[494.0, 722.0], [2990.0, 4370.0]]), 3);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[690.0, 690.0], [958.0, 958.0]]), 3);
}
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_3.matmul(&tensor_2);
let tensor_5 = tensor_4.matmul(&tensor_3);

#[test]
fn should_diff_mean_dim() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);

let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.mean_dim(1).unsqueeze());
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[4.0, 36.0], [3.0, -17.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[9.0, 9.0], [35.5, 35.5]]), 5);
}
// Task 1
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
let tensor_5_cloned = tensor_5.clone();

#[test]
fn should_diff_sum_dim() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);

let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_1.mul(&tensor_3.sum_dim(1).unsqueeze());
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[8.0, 72.0], [6.0, -34.0]]), 5);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[18.0, 18.0], [71.0, 71.0]]), 5);
let first_call = move || {
let tensor_6_1 = tensor_5_cloned.matmul(&tensor_2_cloned);
tensor_6_1.matmul(&tensor_1_cloned)
};

// Task 2
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
let tensor_5_cloned = tensor_5;

let second_call = move || {
let tensor_6_2 = tensor_5_cloned.matmul(&tensor_1_cloned);
tensor_6_2.matmul(&tensor_2_cloned)
};

let tensor_7_1_handle = std::thread::spawn(first_call);
let tensor_7_2_handle = std::thread::spawn(second_call);

let tensor_7_1 = tensor_7_1_handle.join().unwrap();
let tensor_7_2 = tensor_7_2_handle.join().unwrap();
let tensor_8 = tensor_7_1.matmul(&tensor_7_2);

let grads = tensor_8.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

(grad_1, grad_2)
};
let without_move = || {
let tensor_1 = TestADTensor::from_data(data_1.clone());
let tensor_2 = TestADTensor::from_data(data_2.clone());

let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_3.matmul(&tensor_2);
let tensor_5 = tensor_4.matmul(&tensor_3);

// Task 1
let tensor_6_1 = tensor_5.matmul(&tensor_2);
let tensor_7_1 = tensor_6_1.matmul(&tensor_1);

// Task 2
let tensor_6_2 = tensor_5.matmul(&tensor_1);
let tensor_7_2 = tensor_6_2.matmul(&tensor_2);

let tensor_8 = tensor_7_1.matmul(&tensor_7_2);

let grads = tensor_8.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

(grad_1, grad_2)
};

let (grad_1, grad_2) = without_move();
let (grad_1_moved, grad_2_moved) = with_move();

assert_eq!(grad_1.to_data(), grad_1_moved.to_data());
assert_eq!(grad_2.to_data(), grad_2_moved.to_data());
}
}
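
The replacement test above checks that gradients come out identical whether the two intermediate products are computed sequentially (`without_move`) or on separate threads with cloned tensors moved into the spawned closures (`with_move`). A minimal standalone sketch of just that spawn/join structure, using plain floats rather than burn tensors (the names below are illustrative, not part of the crate):

use std::thread;

fn main() {
    let (a, b) = (3.0_f64, 4.0_f64);

    // Sequential reference: both "tasks" computed in order, then combined.
    let sequential = (a * b) + (b - a);

    // Threaded variant: each closure takes ownership of its own copies
    // (`move`), mirroring how the test clones tensors before spawning.
    let first = thread::spawn(move || a * b);
    let second = thread::spawn(move || b - a);
    let threaded = first.join().unwrap() + second.join().unwrap();

    assert_eq!(sequential, threaded);
}
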
25 changes: 24 additions & 1 deletion burn-derive/src/module/base.rs
@@ -20,6 +20,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let load_fn = param.gen_load_fn();
let inner_fn = param.gen_inner_fn();
let detach_fn = param.gen_detach_fn();
let generics_names_except_backend = generics_names_except_backend(&ast.generics);

let gen = quote! {
impl #generics burn::module::Module for #name #generics_ty #generics_where {
@@ -41,7 +42,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {

impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
type ADBackend=B;
type InnerModule=#name<B::InnerBackend>;
type InnerModule=#name<B::InnerBackend, #generics_names_except_backend>;

#inner_fn
}
@@ -53,3 +54,25 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {

gen.into()
}

fn generics_names_except_backend(generics: &syn::Generics) -> proc_macro2::TokenStream {
let mut named = quote! {};

generics.params.iter().for_each(|param| {
match param {
syn::GenericParam::Type(ty) => {
if ty.ident != "B" {
let ident = &ty.ident;
named.extend(quote! { #ident, });
}
}
syn::GenericParam::Lifetime(_) => panic!("Lifetime not supported in module"),
syn::GenericParam::Const(c) => {
let ident = &c.ident;
named.extend(quote! { #ident, });
}
};
});

named
}
2 changes: 1 addition & 1 deletion burn-tensor/Cargo.toml
@@ -17,7 +17,7 @@ edition = "2021"
[features]
default = []
export_tests = ["burn-tensor-testgen"]
named_tensor = []
experimental-named-tensor = []

[dependencies]
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/mod.rs
@@ -18,7 +18,7 @@ pub mod backend;
pub mod loss;
pub mod module;

#[cfg(feature = "named_tensor")]
#[cfg(feature = "experimental-named-tensor")]
mod named;
#[cfg(feature = "named_tensor")]
#[cfg(feature = "experimental-named-tensor")]
pub use named::*;
17 changes: 17 additions & 0 deletions burn-tensor/src/tensor/named/base.rs
@@ -2,10 +2,27 @@ use crate::backend::Backend;
use crate::{Distribution, NamedDims, Shape, Tensor};

/// A tensor with named dimensions.
#[derive(Debug, Clone)]
pub struct NamedTensor<B: Backend, D: NamedDims<B>> {
pub(crate) tensor: D::Tensor,
}

impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<NamedTensor<B, ND>>
for Tensor<B, D>
{
fn from(nt: NamedTensor<B, ND>) -> Self {
nt.tensor
}
}

impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<Tensor<B, D>>
for NamedTensor<B, ND>
{
fn from(tensor: Tensor<B, D>) -> Self {
Self::from_tensor(tensor)
}
}

impl<B: Backend, const D: usize, ND: NamedDims<B>> std::fmt::Display for NamedTensor<B, ND>
where
ND: NamedDims<B, Tensor = Tensor<B, D>>,
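
The two `From` impls added above make the conversion between `Tensor` and `NamedTensor` explicit in both directions. A minimal sketch of how generic code might use them, assuming the `experimental-named-tensor` feature is enabled (the `Batch`/`DModel` dimension names are illustrative, not part of the crate):

use burn_tensor::backend::Backend;
use burn_tensor::{Dim, NamedDim, NamedTensor, Tensor};

NamedDim!(Batch);
NamedDim!(DModel);

// Attach names to an unnamed 2-D tensor, then strip them again; both
// directions now go through `Into` thanks to the new `From` impls.
fn round_trip<B: Backend>(tensor: Tensor<B, 2>) -> Tensor<B, 2> {
    let named: NamedTensor<B, (Batch, DModel)> = tensor.into();
    named.into()
}
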
5 changes: 3 additions & 2 deletions burn-tensor/src/tensor/named/dims.rs
@@ -1,18 +1,19 @@
use crate::backend::Backend;
use crate::Tensor;

pub trait Dim {
pub trait Dim: std::fmt::Debug {
fn to_string() -> String;
}

pub trait NamedDims<B: Backend> {
pub trait NamedDims<B: Backend>: std::fmt::Debug {
type Tensor;
fn to_string() -> String;
}

#[macro_export]
macro_rules! NamedDim {
($name:ident) => {
#[derive(Debug)]
pub struct $name;
impl Dim for $name {
fn to_string() -> String {
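
The `Debug` bounds on `Dim` and `NamedDims`, together with the `#[derive(Debug)]` inside the macro, are what allow `NamedTensor` in base.rs to derive `Debug` and `Clone`. A self-contained sketch of roughly what `NamedDim!(Height);` expands to after this change (the `Height` name and the `to_string` body are assumptions, since the diff truncates the macro body):

pub trait Dim: std::fmt::Debug {
    fn to_string() -> String;
}

#[derive(Debug)]
pub struct Height;

impl Dim for Height {
    fn to_string() -> String {
        // Assumed: the macro presumably stringifies the identifier.
        "Height".to_string()
    }
}

fn main() {
    // The derived `Debug` is what lets containers of dimension markers
    // derive `Debug` themselves.
    println!("{:?} ({})", Height, Height::to_string());
}
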
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/named/mod.rs
@@ -1,7 +1,7 @@
mod base;
mod dims;
mod matmul;
mod permut;
mod swap_dims;

pub use base::*;
pub use dims::*;
burn-tensor/src/tensor/named/permut.rs → burn-tensor/src/tensor/named/swap_dims.rs
@@ -1,50 +1,50 @@
use crate::backend::Backend;
use crate::{Dim, NamedDims, NamedTensor, Tensor};

pub trait Permut<N, const D1: usize, const D2: usize> {
fn permut(&self) -> N;
pub trait SwapDims<N, const D1: usize, const D2: usize> {
fn swap_dims(&self) -> N;
}

impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
/// Permut two dimensions.
pub fn permut<ND2, const D1: usize, const D2: usize>(&self) -> NamedTensor<B, ND2>
/// Swap two dimensions.
pub fn swap_dims<ND2, const D1: usize, const D2: usize>(&self) -> NamedTensor<B, ND2>
where
ND2: NamedDims<B, Tensor = Tensor<B, D>>,
Self: Permut<NamedTensor<B, ND2>, D1, D2>,
Self: SwapDims<NamedTensor<B, ND2>, D1, D2>,
{
Permut::permut(self)
SwapDims::swap_dims(self)
}
}

macro_rules! generate_permut {
(2 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim> Permut<NamedTensor<B, $output>, $dim1, $dim2>
impl<B: Backend, D1: Dim, D2: Dim> SwapDims<NamedTensor<B, $output>, $dim1, $dim2>
for NamedTensor<B, (D1, D2)>
{
fn permut(&self) -> NamedTensor<B, $output> {
fn swap_dims(&self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};

(3 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim> Permut<NamedTensor<B, $output>, $dim1, $dim2>
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim> SwapDims<NamedTensor<B, $output>, $dim1, $dim2>
for NamedTensor<B, (D1, D2, D3)>
{
fn permut(&self) -> NamedTensor<B, $output> {
fn swap_dims(&self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};

(4 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim, D4: Dim>
Permut<NamedTensor<B, $output>, $dim1, $dim2> for NamedTensor<B, (D1, D2, D3, D4)>
SwapDims<NamedTensor<B, $output>, $dim1, $dim2> for NamedTensor<B, (D1, D2, D3, D4)>
{
fn permut(&self) -> NamedTensor<B, $output> {
fn swap_dims(&self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
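
The rename here is mechanical: the `Permut` trait and `permut` method become `SwapDims`/`swap_dims`, matching the underlying `Tensor::swap_dims` call. A hedged usage sketch, assuming the `experimental-named-tensor` feature and that the macro is instantiated for the 2-D `(0, 1)` swap in the truncated part of this file (dimension names are illustrative):

use burn_tensor::backend::Backend;
use burn_tensor::{Dim, NamedDim, NamedTensor};

NamedDim!(Batch);
NamedDim!(DModel);

// Swapping the two dimensions also swaps the names in the type, so using
// the result with mismatched names is a compile-time error rather than a
// runtime shape mismatch.
fn transpose<B: Backend>(
    x: NamedTensor<B, (Batch, DModel)>,
) -> NamedTensor<B, (DModel, Batch)> {
    x.swap_dims::<_, 0, 1>()
}
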
3 changes: 2 additions & 1 deletion burn/Cargo.toml
@@ -11,7 +11,8 @@ license = "MIT/Apache-2.0"
edition = "2021"

[features]
named_tensor = ["burn-tensor/named_tensor"]
default = []
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

[dependencies]
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
4 changes: 2 additions & 2 deletions burn/src/optim/momentum.rs
@@ -1,11 +1,11 @@
use burn_tensor::backend::Gradients;
use crate as burn;

use super::{load_state_gradients, register_state_gradients};
use crate as burn;
use crate::config::Config;
use crate::module::{ParamId, StateNamed};
use crate::tensor::backend::ADBackend;
use crate::tensor::{ElementConversion, Tensor};
use burn_tensor::backend::Gradients;

/// Configuration to create momentum [Momentum](Momentum).
#[derive(Config)]