Feat/module visitor (tracel-ai#127)
nathanielsimard authored Dec 17, 2022
1 parent d959241 commit a599eae
Showing 26 changed files with 379 additions and 248 deletions.
4 changes: 4 additions & 0 deletions burn-autodiff/src/backend.rs
@@ -51,4 +51,8 @@ impl<B: Backend> ADBackend for ADBackendDecorator<B> {
fn from_inner<const D: usize>(tensor: B::TensorPrimitive<D>) -> ADTensor<D, B> {
ADTensor::from_tensor(tensor)
}

fn node_id<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> String {
tensor.node.id.to_string()
}
}
19 changes: 15 additions & 4 deletions burn-autodiff/src/graph/grad.rs
@@ -3,7 +3,7 @@ use crate::graph::{
traversal::{BreadthFirstSearch, GraphTraversal},
};
use burn_tensor::{
backend::{ADBackend, Gradients},
backend::{ADBackend, Backend, Gradients},
ops::Zeros,
Tensor,
};
@@ -20,16 +20,27 @@ impl<B: ADBackend> Gradients<B> for Grads {
grads: HashMap::new(),
}
}
fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>> {

fn get<const D: usize>(&self, id: &str) -> Option<Tensor<B::InnerBackend, D>> {
let grad = match self.grads.get(id) {
Some(grad) => grad,
None => return None,
};

grad.downcast_ref()
let tensor = grad.downcast_ref().map(
|primitive: &<B::InnerBackend as Backend>::TensorPrimitive<D>| {
Tensor::from_primitive(primitive.clone())
},
);
tensor
}

fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>) {
self.grads.insert(id, Box::new(value));
self.grads.insert(id, Box::new(value.into_primitive()));
}

fn len(&self) -> usize {
self.grads.len()
}
}

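For context, the `Grads` container above stores gradients type-erased and keyed by node id, and the updated `get` clones the stored primitive out instead of returning a reference. A standalone miniature of that storage pattern, using only `std` (all names here are illustrative, not part of the crate):

```rust
use std::any::Any;
use std::collections::HashMap;

// Type-erased storage keyed by a string id, in the spirit of `Grads`.
#[derive(Default)]
struct ErasedStore {
    entries: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl ErasedStore {
    fn register<T: Clone + Send + Sync + 'static>(&mut self, id: String, value: T) {
        self.entries.insert(id, Box::new(value));
    }

    // Recover the concrete type with `downcast_ref`, then hand back an owned
    // clone, mirroring how `get` now builds a tensor from the stored primitive.
    fn get<T: Clone + Send + Sync + 'static>(&self, id: &str) -> Option<T> {
        self.entries.get(id)?.downcast_ref::<T>().cloned()
    }

    fn len(&self) -> usize {
        self.entries.len()
    }
}

fn main() {
    let mut store = ErasedStore::default();
    store.register("node-0".to_string(), vec![1.0_f32, 2.0, 3.0]);
    assert_eq!(store.get::<Vec<f32>>("node-0"), Some(vec![1.0, 2.0, 3.0]));
    assert_eq!(store.len(), 1);
}
```
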
12 changes: 5 additions & 7 deletions burn-derive/src/module/base.rs
@@ -11,9 +11,8 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {

let param = Param::from_ast(ast);
let num_params_fn = param.gen_num_params_fn();
let update_params_fn = param.gen_update_params_fn();
let load_optim_state = param.gen_load_optim_state_fn();
let register_optim_state = param.gen_register_optim_state_fn();
let visit = param.gen_visit_fn();
let visit_mut = param.gen_visit_mut_fn();
let devices_fn = param.gen_devices_fn();
let to_device_fn = param.gen_to_device_fn();
let state_fn = param.gen_state_fn();
@@ -28,16 +27,15 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {

#devices_fn
#to_device_fn
#detach_fn

#state_fn
#load_fn

#num_params_fn
#update_params_fn
#detach_fn

#load_optim_state
#register_optim_state
#visit
#visit_mut
}

impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
38 changes: 6 additions & 32 deletions burn-derive/src/module/param.rs
@@ -47,59 +47,33 @@ impl Param {
}
}

pub fn gen_update_params_fn(&self) -> TokenStream {
pub fn gen_visit_fn(&self) -> TokenStream {
let mut body = quote! {};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
self.#name.update_params(grads, optim);
self.#name.visit(visitor);
});
}

quote! {
fn update_params<O: burn::optim::Optimizer<Backend = B>>(
&mut self,
grads: &<B as burn::tensor::backend::ADBackend>::Gradients,
optim: &mut O
)
where
B: burn::tensor::backend::ADBackend {
fn visit<V: burn::module::ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
#body
}
}
}

pub fn gen_load_optim_state_fn(&self) -> TokenStream {
pub fn gen_visit_mut_fn(&self) -> TokenStream {
let mut body = quote! {};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
self.#name.load_optim_state(optim, state_optim);
self.#name.visit_mut(visitor);
});
}

quote! {
fn load_optim_state<O: burn::optim::Optimizer<Backend = B>>(&self, optim: &mut O, state_optim: &burn::module::StateNamed<B::Elem>)
where
B: burn::tensor::backend::ADBackend {
#body
}
}
}

pub fn gen_register_optim_state_fn(&self) -> TokenStream {
let mut body = quote! {};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
self.#name.register_optim_state(optim, state_optim);
});
}

quote! {
fn register_optim_state<O: burn::optim::Optimizer<Backend = B>>(&self, optim: &O, state_optim: &mut burn::module::StateNamed<B::Elem>)
where
B: burn::tensor::backend::ADBackend {
fn visit_mut<V: burn::module::ModuleVisitorMut<Self::Backend>>(&mut self, visitor: &mut V) {
#body
}
}
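
The two `quote!` bodies above expand to one forwarding call per parameter field. A plain-Rust miniature of that shape, with no burn dependency (every name below is illustrative only):

```rust
// Hand-written stand-in for what the derive emits: `visit` simply forwards
// the visitor to each field in declaration order.
trait Visitor {
    fn visit_value(&mut self, value: &f32);
}

struct Field(f32);

impl Field {
    fn visit<V: Visitor>(&self, visitor: &mut V) {
        visitor.visit_value(&self.0);
    }
}

struct MyModule {
    weight: Field,
    bias: Field,
}

impl MyModule {
    // Mirrors the expansion of `gen_visit_fn`.
    fn visit<V: Visitor>(&self, visitor: &mut V) {
        self.weight.visit(visitor);
        self.bias.visit(visitor);
    }
}

struct Sum(f32);

impl Visitor for Sum {
    fn visit_value(&mut self, value: &f32) {
        self.0 += *value;
    }
}

fn main() {
    let module = MyModule {
        weight: Field(0.5),
        bias: Field(0.1),
    };
    let mut sum = Sum(0.0);
    module.visit(&mut sum);
    assert!((sum.0 - 0.6).abs() < 1e-6);
}
```
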
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/base.rs
@@ -53,6 +53,7 @@ pub trait ADBackend: Backend {
tensor: &Self::TensorPrimitive<D>,
grads: &Self::Gradients,
) -> Option<ADBackendTensorPrimitive<D, Self>>;
fn node_id<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> String;
fn inner<const D: usize>(
tensor: &Self::TensorPrimitive<D>,
) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>;
6 changes: 5 additions & 1 deletion burn-tensor/src/tensor/backend/grad.rs
@@ -3,6 +3,10 @@ use crate::Tensor;

pub trait Gradients<B: ADBackend>: Send + Sync {
fn empty() -> Self;
fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>>;
fn get<const D: usize>(&self, id: &str) -> Option<Tensor<B::InnerBackend, D>>;
fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>);
fn len(&self) -> usize;
fn is_empty(&self) -> bool {
self.len() == 0
}
}
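
From the caller's side, the revised trait hands back owned tensors and exposes a size. A minimal sketch, assuming only the signatures declared above (the helper itself is not part of the crate):

```rust
use burn_tensor::backend::{ADBackend, Gradients};
use burn_tensor::Tensor;

// Register a gradient under an id, then read it back as an owned tensor.
fn roundtrip<B: ADBackend, G: Gradients<B>, const D: usize>(
    grads: &mut G,
    id: String,
    value: Tensor<B::InnerBackend, D>,
) -> Option<Tensor<B::InnerBackend, D>> {
    grads.register(id.clone(), value);
    assert!(!grads.is_empty());
    // `get` returns `None` for an unknown id, otherwise a tensor rebuilt
    // from the stored primitive.
    grads.get(&id)
}
```
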
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/base.rs
@@ -683,6 +683,10 @@ impl<const D: usize, B: ADBackend> Tensor<B, D> {
B::grad(&self.value, grads).map(Tensor::new)
}

pub fn node_id(&self) -> String {
B::node_id(&self.value)
}

pub fn inner(&self) -> Tensor<B::InnerBackend, D> {
Tensor::new(B::inner(&self.value))
}
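
The new `node_id` accessor pairs with the string-keyed `Gradients::get` above. A hedged sketch, assuming only the signatures introduced in this commit (the helper is illustrative):

```rust
use burn_tensor::backend::{ADBackend, Gradients};
use burn_tensor::Tensor;

// Look up the gradient of a tracked tensor by the id of its autodiff node.
fn grad_by_id<B: ADBackend, G: Gradients<B>, const D: usize>(
    tensor: &Tensor<B, D>,
    grads: &G,
) -> Option<Tensor<B::InnerBackend, D>> {
    grads.get(&tensor.node_id())
}
```
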
47 changes: 15 additions & 32 deletions burn/src/module/base.rs
@@ -1,7 +1,7 @@
use super::{State, StateNamed};
use crate::optim::Optimizer;
use super::{ParamId, State};
use crate::tensor::backend::{ADBackend, Backend};
pub use burn_derive::Module;
use burn_tensor::Tensor;

/// Trait for all neural network modules.
///
@@ -46,37 +46,20 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
fn detach(&mut self);
/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize;
/// Update the module parameters with the given gradients and [optimizer](Optimizer).
fn update_params<O: Optimizer<Backend = Self::Backend>>(
&mut self,
grads: &<Self::Backend as ADBackend>::Gradients,
optim: &mut O,
) where
Self::Backend: ADBackend;
/// Load the [optimizer](Optimizer) state for the module, including all of its sub-modules.
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V);
/// Visit each tensor in the module with a [visitor](ModuleVisitorMut).
///
/// # Note
///
/// This method should only be called by generated code, see [load](Optimizer::load) to load
/// the state of the optimizer.
fn load_optim_state<O: Optimizer<Backend = Self::Backend>>(
&self,
optim: &mut O,
state_optim: &StateNamed<<Self::Backend as Backend>::Elem>,
) where
Self::Backend: ADBackend;
/// Register the [optimizer](Optimizer) state for the module, including all of its sub-modules.
///
/// # Note
///
/// This method should only be called by generated code, see [state](Optimizer::state) to get
/// the state of the optimizer.
fn register_optim_state<O: Optimizer<Backend = Self::Backend>>(
&self,
optim: &O,
state_optim: &mut StateNamed<<Self::Backend as Backend>::Elem>,
) where
Self::Backend: ADBackend;
/// Note that each tensor is mutable and may be updated by the visitor.
fn visit_mut<V: ModuleVisitorMut<Self::Backend>>(&mut self, visitor: &mut V);
}

pub trait ModuleVisitor<B: Backend> {
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
}

pub trait ModuleVisitorMut<B: Backend> {
fn visit_mut<const D: usize>(&mut self, id: &ParamId, tensor: &mut Tensor<B, D>);
}
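
A usage sketch of the new traits: a read-only visitor that counts every parameter tensor a module exposes, including those of nested sub-modules. It assumes `Module`, `ModuleVisitor`, `ParamId`, `Backend`, and `Tensor` are reachable at the `burn::module`/`burn::tensor` paths the generated code uses; the visitor itself is illustrative.

```rust
use burn::module::{Module, ModuleVisitor, ParamId};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

// Counts parameter tensors without touching them.
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {
        self.count += 1;
    }
}

// `Module::visit` drives the traversal; derived modules forward the visitor
// to each parameter field and sub-module.
fn count_param_tensors<M: Module>(module: &M) -> usize {
    let mut counter = ParamCounter { count: 0 };
    module.visit(&mut counter);
    counter.count
}
```

A `ModuleVisitorMut` is driven the same way through `visit_mut`, but receives `&mut Tensor<B, D>`, so it can update parameters in place.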

/// Module with auto-differentiation backend.
90 changes: 23 additions & 67 deletions burn/src/module/param/module.rs
@@ -1,6 +1,7 @@
use super::{load_with_id, state_with_id, Param};
use crate::module::{ADModule, LoadingError, Module, State, StateNamed};
use crate::optim::Optimizer;
use crate::module::{
ADModule, LoadingError, Module, ModuleVisitor, ModuleVisitorMut, State, StateNamed,
};
use crate::tensor::backend::{ADBackend, Backend};

impl<M: Module> Module for Param<M> {
@@ -10,36 +11,6 @@ impl<M: Module> Module for Param<M> {
self.value.num_params()
}

fn update_params<O: Optimizer<Backend = M::Backend>>(
&mut self,
grads: &<M::Backend as ADBackend>::Gradients,
optim: &mut O,
) where
M::Backend: ADBackend,
{
self.value.update_params(grads, optim);
}

fn load_optim_state<O: Optimizer<Backend = M::Backend>>(
&self,
optim: &mut O,
state_optim: &StateNamed<<M::Backend as Backend>::Elem>,
) where
M::Backend: ADBackend,
{
self.value.load_optim_state(optim, state_optim);
}

fn register_optim_state<O: Optimizer<Backend = M::Backend>>(
&self,
optim: &O,
state_optim: &mut StateNamed<<M::Backend as Backend>::Elem>,
) where
M::Backend: ADBackend,
{
self.value.register_optim_state(optim, state_optim);
}

fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
self.value.devices()
}
@@ -64,6 +35,14 @@ impl<M: Module> Module for Param<M> {
fn detach(&mut self) {
self.value.detach()
}

fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
self.value.visit(visitor);
}

fn visit_mut<V: ModuleVisitorMut<Self::Backend>>(&mut self, visitor: &mut V) {
self.value.visit_mut(visitor);
}
}

impl<M: Module> Module for Param<Vec<M>> {
@@ -78,41 +57,6 @@ impl<M: Module> Module for Param<Vec<M>> {
num_params
}

fn update_params<O: Optimizer<Backend = M::Backend>>(
&mut self,
grads: &<M::Backend as ADBackend>::Gradients,
optim: &mut O,
) where
M::Backend: ADBackend,
{
for module in self.value.iter_mut() {
module.update_params(grads, optim);
}
}

fn load_optim_state<O: Optimizer<Backend = M::Backend>>(
&self,
optim: &mut O,
state_optim: &StateNamed<<M::Backend as Backend>::Elem>,
) where
M::Backend: ADBackend,
{
for module in self.value.iter() {
module.load_optim_state(optim, state_optim);
}
}
fn register_optim_state<O: Optimizer<Backend = M::Backend>>(
&self,
optim: &O,
state_optim: &mut StateNamed<<M::Backend as Backend>::Elem>,
) where
M::Backend: ADBackend,
{
for module in self.value.iter() {
module.register_optim_state(optim, state_optim);
}
}

fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
let mut devices = Vec::new();
for module in self.value.iter() {
@@ -165,6 +109,18 @@ impl<M: Module> Module for Param<Vec<M>> {
value.detach();
}
}

fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
for module in self.value.iter() {
module.visit(visitor);
}
}

fn visit_mut<V: ModuleVisitorMut<Self::Backend>>(&mut self, visitor: &mut V) {
for module in self.value.iter_mut() {
module.visit_mut(visitor);
}
}
}

impl<M: Module> Param<Vec<M>> {
(remaining changed files not shown)
