Skip to content

Commit

Permalink
Refactor Param wrapping only for Tensor (tracel-ai#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Mar 31, 2023
1 parent 7364d09 commit 32d38be
Show file tree
Hide file tree
Showing 44 changed files with 881 additions and 911 deletions.
33 changes: 11 additions & 22 deletions burn-core/src/module/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ use burn_tensor::Tensor;
/// This will make your module trainable, savable and loadable via
/// [state](Module::state) and [load](Module::load).
///
/// Module concrete types should define their parameters via the [Param](crate::module::Param)
/// struct.
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
Expand All @@ -26,39 +23,34 @@ use burn_tensor::Tensor;
///
/// use burn::{
/// nn,
/// module::{Param, Module},
/// module::Module,
/// tensor::Tensor,
/// tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
/// my_param: Param<nn::Linear<B>>,
/// my_param: nn::Linear<B>,
/// my_other_field: usize,
/// }
/// ```
pub trait Module: Clone + Send + Sync + core::fmt::Debug + core::fmt::Display {
type Backend: Backend;

pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Get the device list of the module and all of its sub-modules.
fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;
fn devices(&self) -> Vec<B::Device>;
/// Move the module and all of its sub-modules to the given device.
fn to_device(self, device: &<Self::Backend as Backend>::Device) -> Self;
fn to_device(self, device: &B::Device) -> Self;
/// Load the module state.
fn load(
self,
state: &State<<Self::Backend as Backend>::FloatElem>,
) -> Result<Self, LoadingError>;
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError>;
/// Get the module state.
fn state(&self) -> State<<Self::Backend as Backend>::FloatElem>;
fn state(&self) -> State<B::FloatElem>;
/// Detach the module from the graph.
fn detach(self) -> Self;
/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize;
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V);
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
/// Map each tensor in the module with a [mapper](ModuleMapper).
fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self;
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
}

pub trait ModuleVisitor<B: Backend> {
Expand All @@ -70,11 +62,8 @@ pub trait ModuleMapper<B: Backend> {
}

/// Module with auto-differentiation backend.
pub trait ADModule:
Module<Backend = Self::ADBackend> + Send + Sync + core::fmt::Debug + core::fmt::Display
{
type ADBackend: ADBackend;
type InnerModule: Module<Backend = <Self::ADBackend as ADBackend>::InnerBackend>;
pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
type InnerModule: Module<B::InnerBackend>;

/// Get the same module, but on the inner backend without auto-differentiation.
fn inner(self) -> Self::InnerModule;
Expand Down
86 changes: 86 additions & 0 deletions burn-core/src/module/param/constant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate as burn;

#[macro_export]
macro_rules! constant {
(module) => {
fn devices(&self) -> alloc::vec::Vec<<B as burn_tensor::backend::Backend>::Device> {
alloc::vec::Vec::new()
}

fn to_device(self, _device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self
}

fn load(
self,
_state: &burn::module::State<<B as burn_tensor::backend::Backend>::FloatElem>,
) -> Result<Self, burn::module::LoadingError> {
Ok(self)
}

fn state(&self) -> burn::module::State<<B as burn_tensor::backend::Backend>::FloatElem> {
burn::module::State::StateNamed(burn::module::StateNamed::new())
}

fn detach(self) -> Self {
self
}

fn num_params(&self) -> usize {
0
}

fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}

fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
};

(ad_module, $type:ty) => {
type InnerModule = $type;

fn inner(self) -> Self::InnerModule {
self
}

fn from_inner(module: Self::InnerModule) -> Self {
module
}
};

($type:ty) => {
impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for $type {
constant!(module);
}

impl<B: burn::tensor::backend::ADBackend> burn::module::ADModule<B> for $type {
constant!(ad_module, $type);
}
};
}

// General Types
constant!(alloc::string::String);
constant!(bool);

// Float Types
constant!(f64);
constant!(f32);
constant!(half::bf16);
constant!(half::f16);

// Unsigned Integer Types
constant!(usize);
constant!(u64);
constant!(u32);
constant!(u16);
constant!(u8);

// Signed Integer Types
constant!(i64);
constant!(i32);
constant!(i16);
constant!(i8);
4 changes: 2 additions & 2 deletions burn-core/src/module/param/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
mod base;
mod constant;
mod id;
mod module;
mod primitive;
mod running;
mod tensor;
mod visitor;

pub use base::*;
pub use id::*;
pub use module::*;
pub use running::*;
pub use tensor::*;
pub use visitor::*;
201 changes: 0 additions & 201 deletions burn-core/src/module/param/module.rs

This file was deleted.

Loading

0 comments on commit 32d38be

Please sign in to comment.