Refactor/burn tensor (tracel-ai#28)
nathanielsimard authored Sep 3, 2022
1 parent 1ff11e5 commit 548a1be
Showing 59 changed files with 397 additions and 503 deletions.
Empty file removed: burn-derive/src/device.rs
burn-derive/src/lib.rs (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,7 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
#load_fn
}

- impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::back::ad::Backend, {
+ impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
type ADBackend=B;
type InnerModule=#name<B::InnerBackend>;

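For code written against the derive, the visible effect is just the trait path in the generated `where` clause. A minimal hand-written sketch of the same bound, using only the paths shown in this hunk (the helper name is hypothetical):

```rust
// Hypothetical helper, not part of this commit; it only type-checks once the
// renamed traits are in place.
fn assert_ad_module<M>()
where
    M: burn::module::ADModule,
    // previously spelled burn::tensor::back::ad::Backend
    M::ADBackend: burn::tensor::backend::ADBackend,
{
}
```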
burn-derive/src/param.rs (6 changes: 3 additions & 3 deletions)
@@ -51,7 +51,7 @@ impl Param {
quote! {
fn update_params<O: burn::optim::Optimizer<Backend = B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
where
- B: burn::tensor::back::ad::Backend {
+ B: burn::tensor::backend::ADBackend {
#body
}
}
@@ -134,7 +134,7 @@ impl Param {
}

quote! {
- fn state(&self) -> burn::module::State<<Self::Backend as burn::tensor::back::Backend>::Elem>
+ fn state(&self) -> burn::module::State<<Self::Backend as burn::tensor::backend::Backend>::Elem>
{
#body
burn::module::State::StateNamed(state)
@@ -159,7 +159,7 @@ impl Param {
});
}
quote! {
- fn load(&mut self, state: &burn::module::State<<Self::Backend as burn::tensor::back::Backend>::Elem>) -> Result<(), burn::module::LoadingError>
+ fn load(&mut self, state: &burn::module::State<<Self::Backend as burn::tensor::backend::Backend>::Elem>) -> Result<(), burn::module::LoadingError>
{
#body
Ok(())
burn-tensor/examples/autodiff_simple.rs (18 changes: 9 additions & 9 deletions)
@@ -1,8 +1,8 @@
- use burn_tensor::{af, back, Data, Distribution, Shape, Tensor};
+ use burn_tensor::{activation, backend, Data, Distribution, Shape, Tensor};

- fn loss<B: back::Backend>(x: &Tensor<B, 2>, y: &Tensor<B, 2>) -> Tensor<B, 2> {
+ fn loss<B: backend::Backend>(x: &Tensor<B, 2>, y: &Tensor<B, 2>) -> Tensor<B, 2> {
let z = x.matmul(y);
- let z = af::relu(&z);
+ let z = activation::relu(&z);

println!("fn name : loss");
println!("backend : {}", B::name());
@@ -11,7 +11,7 @@ fn loss<B: back::Backend>(x: &Tensor<B, 2>, y: &Tensor<B, 2>) -> Tensor<B, 2> {
z
}

- fn run_ad<B: back::ad::Backend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) {
+ fn run_ad<B: backend::ADBackend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) {
println!("---------- Ad Enabled -----------");
let x: Tensor<B, 2> = Tensor::from_data(x);
let y: Tensor<B, 2> = Tensor::from_data(y);
@@ -26,7 +26,7 @@ fn run_ad<B: back::ad::Backend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) {
println!("")
}

- fn run<B: back::Backend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) {
+ fn run<B: backend::Backend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) {
println!("---------- Ad Disabled ----------");
loss::<B>(&Tensor::from_data(x.clone()), &Tensor::from_data(y.clone()));
println!("---------------------------------");
@@ -40,13 +40,13 @@ fn main() {

#[cfg(feature = "ndarray")]
{
- run::<back::NdArray<f32>>(x.clone(), y.clone());
- run_ad::<back::ad::NdArray<f32>>(x.clone(), y.clone());
+ run::<backend::NdArrayBackend<f32>>(x.clone(), y.clone());
+ run_ad::<backend::NdArrayADBackend<f32>>(x.clone(), y.clone());
}

#[cfg(feature = "tch")]
{
- run::<back::Tch<f32>>(x.clone(), y.clone());
- run_ad::<back::ad::Tch<f32>>(x.clone(), y.clone());
+ run::<backend::TchBackend<f32>>(x.clone(), y.clone());
+ run_ad::<backend::TchADBackend<f32>>(x.clone(), y.clone());
}
}
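Condensing the example above into one function shows the renamed paths in a single place; everything here is taken from the diff except the function name, which is an assumption:

```rust
use burn_tensor::{activation, backend, Data, Tensor};

// `back::Backend` is now `backend::Backend`; `af::relu` is now `activation::relu`.
fn relu_of_product<B: backend::Backend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) -> Tensor<B, 2> {
    let x: Tensor<B, 2> = Tensor::from_data(x);
    let y: Tensor<B, 2> = Tensor::from_data(y);
    activation::relu(&x.matmul(&y))
}
```

With the concrete backends renamed as in `main` above, this would be instantiated as, for example, `relu_of_product::<backend::NdArrayBackend<f32>>(x, y)` behind the `ndarray` feature, or with `backend::TchBackend<f32>` behind `tch`.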
@@ -1,5 +1,5 @@
- use super::Tensor;
- use crate::back::Backend;
+ use crate::backend::Backend;
+ use crate::Tensor;
use crate::{ElementPrecision, Precision};

pub fn relu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
burn-tensor/src/tensor/activation/mod.rs (3 changes: 3 additions & 0 deletions)
@@ -0,0 +1,3 @@
+ mod activation;
+
+ pub use activation::*;
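The new module is a thin façade: a private `activation` submodule holds the functions and `pub use activation::*;` surfaces them, which is why the example above can call `activation::relu`. A standalone sketch of the same re-export pattern, with a scalar stand-in rather than burn's tensor types:

```rust
// Illustration only; `relu` here is a scalar stand-in for the tensor version.
mod activation {
    pub fn relu(x: f64) -> f64 {
        x.max(0.0)
    }
}

pub use activation::*;

fn main() {
    // reachable both through the module path and through the glob re-export
    assert_eq!(activation::relu(-1.0), 0.0);
    assert_eq!(relu(2.0), 2.0);
}
```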
burn-tensor/src/tensor/api/ad.rs (0 additions & 59 deletions). This file was deleted.
burn-tensor/src/tensor/api/back.rs (0 additions & 23 deletions). This file was deleted.
burn-tensor/src/tensor/api/mod.rs (0 additions & 11 deletions). This file was deleted.
burn-tensor/src/tensor/api/ops.rs (0 additions & 95 deletions). This file was deleted.
burn-tensor/src/tensor/backend/autodiff/activation/relu.rs (2 changes: 1 addition & 1 deletion)
@@ -1,8 +1,8 @@
use crate::tensor::backend::backend::Backend;
use crate::{
activation::ReLU,
execute_ops,
graph::ops::{UnaryOps, UnaryOpsNodeState},
- ops::activation::*,
+ ops::*,
register_ops,
tensor::backend::autodiff::ADTensor,
burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs (6 changes: 3 additions & 3 deletions)
@@ -1,6 +1,6 @@
- use crate::tensor::api::Tensor;
- use crate::tensor::tensor_trait::ElementConversion;
- use crate::{back::Backend, tensor::ops::*};
+ use crate::tensor::ElementConversion;
+ use crate::Tensor;
+ use crate::{backend::Backend, tensor::ops::*};
use crate::{
define_ops, execute_ops,
graph::ops::{UnaryOps, UnaryOpsNodeState},