Feat/inner module (tracel-ai#26)
nathanielsimard authored Aug 30, 2022
1 parent 68e0525 commit 674e078
Showing 15 changed files with 198 additions and 77 deletions.
11 changes: 9 additions & 2 deletions burn-derive/src/lib.rs
@@ -29,6 +29,7 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let state_fn = param.gen_state_fn();
let load_from_parent_fn = param.gen_load_from_parent_fn();
let load_fn = param.gen_load_fn();
let inner_fn = param.gen_inner_fn();

let gen = quote! {
impl #generics burn::module::Module for #name #generics_ty #generics_where {
@@ -44,11 +45,17 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
#load_fn
}

impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::back::ad::Backend, {
type ADBackend=B;
type InnerModule=#name<B::InnerBackend>;

#inner_fn
}

impl #generics std::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}
};

gen.into()
let tokens = gen.into();
tokens
}
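
The new `ADModule` impl is generated next to the existing `Module` impl: it fixes `ADBackend` to the module's backend parameter and points `InnerModule` at the same struct instantiated with `B::InnerBackend`. Below is a minimal sketch of that assembly step using `quote`/`proc-macro2` directly; the function name and the single-`B` generics handling are simplifications for illustration, not the macro's actual code.

```rust
// Hypothetical helper mirroring the quote! block in the diff: given the struct name
// and a pre-generated `inner` function, emit the ADModule impl for the autodiff backend.
use proc_macro2::{Ident, TokenStream};
use quote::quote;

fn gen_ad_module_impl(name: &Ident, inner_fn: TokenStream) -> TokenStream {
    quote! {
        impl<B> burn::module::ADModule for #name<B>
        where
            B: burn::tensor::back::ad::Backend,
        {
            type ADBackend = B;
            type InnerModule = #name<B::InnerBackend>;

            #inner_fn
        }
    }
}
```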
26 changes: 25 additions & 1 deletion burn-derive/src/param.rs
@@ -49,7 +49,7 @@ impl Param {
}

quote! {
fn update_params<O: burn::optim::Optimizer<B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
fn update_params<O: burn::optim::Optimizer<Backend = B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
where
B: burn::tensor::back::ad::Backend {
#body
@@ -98,6 +98,30 @@ impl Param {
.into()
}

pub fn gen_inner_fn(&self) -> TokenStream {
let mut body = quote! {};
let mut names = Vec::new();
for field in self.fields.iter() {
let name = field.ident();
names.push(name.clone());

body.extend(quote! {
let #name = self.#name.inner();
});
}

quote! {
fn inner(&self) -> Self::InnerModule {
#body

Self::InnerModule {
#(#names),*
}
}
}
.into()
}

pub fn gen_state_fn(&self) -> TokenStream {
let mut body = quote! {
let mut state = burn::module::State::new(self.name());
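`gen_inner_fn` follows the same pattern as the other generated methods: accumulate one `let field = self.field.inner();` statement per field, then rebuild the struct from the collected identifiers. A standalone sketch of that loop is below; the signature is hypothetical, since the real method reads the field identifiers from the parsed struct.

```rust
use proc_macro2::{Ident, TokenStream};
use quote::quote;

// Build the body of the generated `fn inner()` from a list of field identifiers.
fn gen_inner_fn(field_idents: &[Ident]) -> TokenStream {
    let mut body = quote! {};
    let mut names = Vec::new();

    for name in field_idents {
        names.push(name.clone());
        // Each parameter is mapped onto the inner (non-autodiff) backend.
        body.extend(quote! {
            let #name = self.#name.inner();
        });
    }

    quote! {
        fn inner(&self) -> Self::InnerModule {
            #body

            Self::InnerModule {
                #(#names),*
            }
        }
    }
}
```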
1 change: 1 addition & 0 deletions burn-tensor/Cargo.toml
@@ -28,6 +28,7 @@ half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work

# Backends
tch = { version = "0.8", optional = true }
lazy_static = "1.4"
ndarray = { version = "0.15", optional = true }

# Autodiff
6 changes: 3 additions & 3 deletions burn-tensor/src/tensor/api/af.rs
@@ -11,13 +11,13 @@ pub fn softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) ->
}

pub fn log_softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
let tensor_tmp = match Precision::Half == B::Elem::precision() {
true => {
let tensor_tmp = match B::Elem::precision() {
Precision::Half => {
let tensor_full = tensor.to_full_precision();
let tensor_tmp = tensor_full.exp().sum_dim(dim).log();
Tensor::from_full_precision(tensor_tmp)
}
false => tensor.exp().sum_dim(dim).log(),
_ => tensor.exp().sum_dim(dim).log(),
};

tensor.sub(&tensor_tmp)
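The rewrite switches from matching on a boolean to matching on the `Precision` enum itself; the computation is unchanged: log-softmax is `x_i - ln(Σ_j exp(x_j))`, and for half-precision backends the `exp`/`sum`/`log` part is promoted to full precision, presumably because `exp` overflows f16 (whose largest finite value is 65504, reached near x ≈ 11) very quickly. A backend-free sketch of the same quantity on a plain slice, as an illustration only:

```rust
// Minimal sketch of what log_softmax computes, independent of any backend:
// log_softmax(x)_i = x_i - ln(sum_j exp(x_j)).
fn log_softmax(xs: &[f32]) -> Vec<f32> {
    let log_sum_exp: f32 = xs.iter().map(|x| x.exp()).sum::<f32>().ln();
    xs.iter().map(|x| x - log_sum_exp).collect()
}

fn main() {
    let out = log_softmax(&[1.0, 2.0, 3.0]);
    // The outputs are log-probabilities; exponentiating them sums to 1.
    let total: f32 = out.iter().map(|x| x.exp()).sum();
    println!("{:?} (sum of probs = {})", out, total);
}
```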
10 changes: 10 additions & 0 deletions burn-tensor/src/tensor/backend/tch/tensor.rs
@@ -1,5 +1,11 @@
use crate::tensor::{ops::TensorOpsUtilities, Data, Element, Shape, TensorTrait};

lazy_static::lazy_static! {
static ref NO_GRAD: tch::NoGradGuard = {
tch::no_grad_guard()
};
}

#[derive(Debug, PartialEq)]
pub struct TchTensor<P: tch::kind::Element, const D: usize> {
pub kind: TchKind<P>,
@@ -65,6 +71,8 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
let shape_tch = TchShape::from(data.shape);
let kind = TchKind::new();
let tensor = tensor.reshape(&shape_tch.dims).to_kind(kind.kind());

lazy_static::initialize(&NO_GRAD);
let tensor = tensor.set_requires_grad(false);

Self {
@@ -81,6 +89,8 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
let device = tch::Device::Cpu;
let kind = TchKind::new();
let tensor = tch::Tensor::empty(&shape_tch.dims, (kind.kind(), device.clone()));

lazy_static::initialize(&NO_GRAD);
let tensor = tensor.set_requires_grad(false);

Self {
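The `NO_GRAD` static holds a `tch::NoGradGuard` for the lifetime of the process, and `lazy_static::initialize` forces it into existence before each tensor is constructed, presumably so that tensors created by this backend never record gradient history on the libtorch side before `set_requires_grad(false)` is applied. A self-contained sketch of the same lazy-initialize pattern with a stand-in guard type (the `Guard`/`acquire_guard` names are hypothetical):

```rust
use lazy_static::lazy_static; // lazy_static = "1.4"

// Stand-in for tch::NoGradGuard: acquiring it has a one-time side effect.
struct Guard;

fn acquire_guard() -> Guard {
    println!("guard acquired (this prints only once)");
    Guard
}

lazy_static! {
    static ref NO_GRAD: Guard = acquire_guard();
}

fn make_tensor(id: u32) {
    // Force the static before relying on its side effect, just as the diff does
    // before calling set_requires_grad(false).
    lazy_static::initialize(&NO_GRAD);
    println!("tensor {} created", id);
}

fn main() {
    make_tensor(1);
    make_tensor(2); // the guard is not re-acquired
}
```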
54 changes: 32 additions & 22 deletions examples/mnist.rs
@@ -1,5 +1,5 @@
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataloader::{DataLoaderBuilder, Detach};
use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
use burn::module::{Forward, Module, Param};
use burn::nn;
@@ -8,8 +8,8 @@ use burn::tensor::af::relu;
use burn::tensor::back::{ad, Backend};
use burn::tensor::losses::cross_entropy_with_logits;
use burn::tensor::{Data, ElementConversion, Shape, Tensor};
use burn::train::logger::{AsyncLogger, CLILogger};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric, Metric};
use burn::train::logger::{AsyncLogger, CLILogger, TextPlot};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
use std::sync::Arc;

@@ -118,7 +118,16 @@ struct MNISTBatch<B: Backend> {
targets: Tensor<B, 2>,
}

impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
impl<B: ad::Backend> Detach for MNISTBatch<B> {
fn detach(self) -> Self {
Self {
images: self.images.detach(),
targets: self.targets.detach(),
}
}
}

impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
let images = items
.iter()
@@ -133,8 +142,8 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
.map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
.collect();

let images = Tensor::cat(images, 0).to_device(self.device).detach();
let targets = Tensor::cat(targets, 0).to_device(self.device).detach();
let images = Tensor::cat(images, 0).to_device(self.device);
let targets = Tensor::cat(targets, 0).to_device(self.device);

MNISTBatch { images, targets }
}
@@ -143,18 +152,11 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
fn run<B: ad::Backend>(device: B::Device) {
let batch_size = 128;
let learning_rate = 5.5e-2;
let num_epochs = 100;
let num_epochs = 10;
let num_workers = 8;
let num_layers = 4;
let hidden_dim = 1024;
let hidden_dim = 3024;
let seed = 42;
let metrics = || -> Vec<Box<dyn Metric<ClassificationOutput<B>>>> {
vec![
Box::new(LossMetric::new()),
Box::new(AccuracyMetric::new()),
Box::new(CUDAMetric::new()),
]
};

let mut model: Model<B> = Model::new(784, hidden_dim, num_layers, 10);
model.to_device(device);
@@ -167,25 +169,34 @@ fn run<B: ad::Backend>(device: B::Device) {
);

let optim: SGDOptimizer<B> = SGDOptimizer::new(learning_rate);
let batcher = Arc::new(MNISTBatcher::<B> { device });
let dataloader_train = DataLoaderBuilder::new(batcher.clone())
let batcher_train = Arc::new(MNISTBatcher::<B> { device });
let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend> { device });
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(batch_size)
.shuffle(seed)
.num_workers(num_workers)
.build(Arc::new(MNISTDataset::train()));
let dataloader_test = DataLoaderBuilder::new(batcher.clone())
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(batch_size)
.num_workers(num_workers)
.build(Arc::new(MNISTDataset::test()));

let learner = ClassificationLearner::new(model);
let learner = ClassificationLearner::new(model, optim);

let logger_train = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
metrics(),
vec![
Box::new(TextPlot::new(LossMetric::new())),
Box::new(AccuracyMetric::new()),
Box::new(CUDAMetric::new()),
],
"Train".to_string(),
))));
let logger_valid = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
metrics(),
vec![
Box::new(TextPlot::new(LossMetric::new())),
Box::new(AccuracyMetric::new()),
Box::new(CUDAMetric::new()),
],
"Valid".to_string(),
))));

@@ -195,7 +206,6 @@ fn run<B: ad::Backend>(device: B::Device) {
logger_train,
logger_valid,
learner,
optim,
);

trainer.run(num_epochs);
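The example now builds two batchers instead of sharing one: training batches stay on the autodiff backend `B`, while validation batches are produced directly on `B::InnerBackend`, so the validation pass never carries an autodiff graph. A toy sketch of that shape; the traits and the batcher here are local stand-ins, not burn's types:

```rust
use std::marker::PhantomData;
use std::sync::Arc;

trait Backend: 'static {}
trait AdBackend: Backend {
    type InnerBackend: Backend;
}

// Stand-in batcher, generic over the backend its batches live on.
struct Batcher<B: Backend> {
    _backend: PhantomData<B>,
}

impl<B: Backend> Batcher<B> {
    fn new() -> Self {
        Self { _backend: PhantomData }
    }
}

fn run<B: AdBackend>() {
    // Training data is batched on the autodiff backend...
    let batcher_train = Arc::new(Batcher::<B>::new());
    // ...validation data on the plain inner backend, mirroring the example above.
    let batcher_valid = Arc::new(Batcher::<B::InnerBackend>::new());
    let _ = (batcher_train, batcher_valid);
}

struct Cpu;
impl Backend for Cpu {}
struct AdCpu;
impl Backend for AdCpu {}
impl AdBackend for AdCpu {
    type InnerBackend = Cpu;
}

fn main() {
    run::<AdCpu>();
}
```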
4 changes: 4 additions & 0 deletions src/data/dataloader/mod.rs
@@ -11,3 +11,7 @@ pub use builder::*;
pub use dataloader::*;
pub use multithread::*;
pub use strategy::*;

pub trait Detach {
fn detach(self) -> Self;
}
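
The new `Detach` trait gives the data-loading layer a uniform way to ask a batch to drop its autodiff history, which the MNIST example implements field by field. A self-contained sketch with a stand-in tensor type (`FakeTensor` is hypothetical; in the example the fields are burn tensors and `detach()` is the tensor's own method):

```rust
// The trait as added in the diff.
pub trait Detach {
    fn detach(self) -> Self;
}

// Stand-in tensor: "detaching" just clears the gradient-tracking flag.
struct FakeTensor {
    data: Vec<f32>,
    tracks_grad: bool,
}

impl FakeTensor {
    fn detach(self) -> Self {
        Self { tracks_grad: false, ..self }
    }
}

// A batch detaches by detaching every field, as MNISTBatch does above.
struct Batch {
    images: FakeTensor,
    targets: FakeTensor,
}

impl Detach for Batch {
    fn detach(self) -> Self {
        Self {
            images: self.images.detach(),
            targets: self.targets.detach(),
        }
    }
}
```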
14 changes: 12 additions & 2 deletions src/module/module.rs
@@ -72,8 +72,11 @@
pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
type Backend: back::Backend;

fn update_params<O: Optimizer<Self::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
where
fn update_params<O: Optimizer<Backend = Self::Backend>>(
&mut self,
grads: &Gradients,
optim: &mut O,
) where
Self::Backend: back::ad::Backend;
fn devices(&self) -> Vec<<Self::Backend as back::Backend>::Device>;
fn to_device(&mut self, device: <Self::Backend as back::Backend>::Device);
@@ -93,6 +96,13 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
fn num_params(&self) -> usize;
}

pub trait ADModule: Module + Send + Sync + std::fmt::Debug + std::fmt::Display {
type ADBackend: back::ad::Backend;
type InnerModule: Module<Backend = <Self::ADBackend as back::ad::Backend>::InnerBackend>;

fn inner(&self) -> Self::InnerModule;
}

pub trait Forward<In, Out> {
fn forward(&self, input: In) -> Out;
}
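
`ADModule` ties the autodiff backend and the inner module together: `InnerModule` is constrained to be a `Module` living on the autodiff backend's `InnerBackend`. A self-contained toy showing how that associated-type bound is satisfied; every name here is a simplified stand-in for burn's traits:

```rust
use std::marker::PhantomData;

trait Backend {}
trait AdBackend: Backend {
    type InnerBackend: Backend;
}

trait Module {
    type Backend: Backend;
}

// Same shape as the ADModule trait above, with shortened names.
trait AdModule: Module {
    type ADBackend: AdBackend;
    type InnerModule: Module<Backend = <Self::ADBackend as AdBackend>::InnerBackend>;

    fn inner(&self) -> Self::InnerModule;
}

struct Cpu;
impl Backend for Cpu {}

struct Autodiff;
impl Backend for Autodiff {}
impl AdBackend for Autodiff {
    type InnerBackend = Cpu;
}

// A module generic over its backend, like the derived modules.
struct Linear<B: Backend> {
    weights: Vec<f32>,
    _backend: PhantomData<B>,
}

impl<B: Backend> Module for Linear<B> {
    type Backend = B;
}

impl AdModule for Linear<Autodiff> {
    type ADBackend = Autodiff;
    // Must be a Module whose Backend is Autodiff::InnerBackend, i.e. Cpu.
    type InnerModule = Linear<Cpu>;

    fn inner(&self) -> Linear<Cpu> {
        Linear { weights: self.weights.clone(), _backend: PhantomData }
    }
}

fn main() {
    let model = Linear::<Autodiff> { weights: vec![0.5, -0.5], _backend: PhantomData };
    let _valid_model = model.inner(); // Linear<Cpu>
}
```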
53 changes: 46 additions & 7 deletions src/module/param.rs
@@ -1,4 +1,4 @@
use crate::module::{Module, State};
use crate::module::{ADModule, Module, State};
use crate::optim::Optimizer;
use crate::tensor::{back, Gradients, Tensor};
use serde::de::DeserializeOwned;
Expand Down Expand Up @@ -28,7 +28,7 @@ impl<const D: usize, B: back::Backend> Param<Tensor<B, D>> {
self.value.shape().num_elements()
}

pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
where
B: back::ad::Backend,
{
@@ -60,6 +60,13 @@ impl<const D: usize, B: back::Backend> Param<Tensor<B, D>> {
let data = state.get(name);
self.value = Tensor::from_data_device(data, self.value.device());
}

pub fn inner(&self) -> Param<Tensor<B::InnerBackend, D>>
where
B: back::ad::Backend,
{
Param::new(self.value.inner())
}
}

impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {
@@ -71,7 +78,7 @@ impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {
0
}

pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
where
B: back::ad::Backend,
{
@@ -118,15 +125,28 @@ impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {

self.value = value;
}

pub fn inner(&self) -> Param<Option<Tensor<B::InnerBackend, D>>>
where
B: back::ad::Backend,
{
match &self.value {
Some(tensor) => Param::new(Some(tensor.inner())),
None => Param::new(None),
}
}
}

impl<M: Module> Param<M> {
pub fn num_params(&self) -> usize {
self.value.num_params()
}

pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
where
pub fn update_params<O: Optimizer<Backend = M::Backend>>(
&mut self,
grads: &Gradients,
optim: &mut O,
) where
M::Backend: back::ad::Backend,
{
self.value.update_params(grads, optim);
@@ -157,6 +177,14 @@ impl<M: Module> Param<M> {
{
self.value.load_from_parent(name, state);
}

pub fn inner(&self) -> Param<M::InnerModule>
where
M: ADModule,
M::Backend: back::ad::Backend,
{
Param::new(self.value.inner())
}
}

impl<M: Module> Param<Vec<M>> {
@@ -169,8 +197,11 @@ impl<M: Module> Param<Vec<M>> {
num_params
}

pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
where
pub fn update_params<O: Optimizer<Backend = M::Backend>>(
&mut self,
grads: &Gradients,
optim: &mut O,
) where
M::Backend: back::ad::Backend,
{
for module in self.value.iter_mut() {
@@ -209,4 +240,12 @@
{
todo!();
}

pub fn inner(&self) -> Param<Vec<M::InnerModule>>
where
M: ADModule,
M::Backend: back::ad::Backend,
{
Param::new(self.value.iter().map(|v| v.inner()).collect())
}
}
6 changes: 4 additions & 2 deletions src/optim/optim.rs
@@ -1,6 +1,8 @@
use crate::tensor::back::ad::Backend;
use crate::tensor::{Gradients, Tensor};

pub trait Optimizer<B: Backend>: Send + Sync {
fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients);
pub trait Optimizer: Send + Sync {
type Backend: Backend;

fn update<const D: usize>(&mut self, tensor: &mut Tensor<Self::Backend, D>, grads: &Gradients);
}
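
Moving the backend from a generic parameter to an associated type means each optimizer names exactly one backend, and every bound in the modules changes from `O: Optimizer<B>` to `O: Optimizer<Backend = B>`. Below is a toy sketch of an SGD-style implementor under the new shape; the tensor and backend types are local stand-ins, and the const-generic dimension from the real trait is omitted:

```rust
use std::marker::PhantomData;

trait Backend: Send + Sync + 'static {}

struct Cpu;
impl Backend for Cpu {}

// Stand-in tensor carrying its gradient alongside the data.
struct Tensor<B: Backend> {
    data: Vec<f32>,
    grad: Vec<f32>,
    _backend: PhantomData<B>,
}

// Same shape as the new trait above, minus the const-generic dimension.
trait Optimizer: Send + Sync {
    type Backend: Backend;

    fn update(&mut self, tensor: &mut Tensor<Self::Backend>);
}

struct Sgd<B: Backend> {
    learning_rate: f32,
    _backend: PhantomData<B>,
}

impl<B: Backend> Optimizer for Sgd<B> {
    type Backend = B;

    fn update(&mut self, tensor: &mut Tensor<B>) {
        for (w, g) in tensor.data.iter_mut().zip(tensor.grad.iter()) {
            *w -= self.learning_rate * g;
        }
    }
}

// Callers constrain the optimizer through the associated type, as the modules now do.
fn step<B: Backend, O: Optimizer<Backend = B>>(optim: &mut O, tensor: &mut Tensor<B>) {
    optim.update(tensor);
}

fn main() {
    let mut tensor = Tensor::<Cpu> { data: vec![1.0, 2.0], grad: vec![0.1, 0.2], _backend: PhantomData };
    let mut optim = Sgd::<Cpu> { learning_rate: 0.5, _backend: PhantomData };
    step(&mut optim, &mut tensor);
    println!("{:?}", tensor.data);
}
```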