Feat/multi device (tracel-ai#128)
nathanielsimard authored Dec 20, 2022
1 parent a599eae commit 3a91c2c
Showing 34 changed files with 650 additions and 273 deletions.
4 changes: 0 additions & 4 deletions burn-autodiff/src/backend.rs
@@ -51,8 +51,4 @@ impl<B: Backend> ADBackend for ADBackendDecorator<B> {
fn from_inner<const D: usize>(tensor: B::TensorPrimitive<D>) -> ADTensor<D, B> {
ADTensor::from_tensor(tensor)
}

fn node_id<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> String {
tensor.node.id.to_string()
}
}
38 changes: 2 additions & 36 deletions burn-autodiff/src/graph/grad.rs
@@ -2,48 +2,14 @@ use crate::graph::{
node::{BackwardNode, ForwardNode},
traversal::{BreadthFirstSearch, GraphTraversal},
};
use burn_tensor::{
backend::{ADBackend, Backend, Gradients},
ops::Zeros,
Tensor,
};
use burn_tensor::ops::Zeros;
use std::{any::Any, collections::HashMap, ops::Add};

#[derive(Default)]
#[derive(Default, Debug)]
pub struct Grads {
grads: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl<B: ADBackend> Gradients<B> for Grads {
fn empty() -> Self {
Self {
grads: HashMap::new(),
}
}

fn get<const D: usize>(&self, id: &str) -> Option<Tensor<B::InnerBackend, D>> {
let grad = match self.grads.get(id) {
Some(grad) => grad,
None => return None,
};

let tensor = grad.downcast_ref().map(
|primitive: &<B::InnerBackend as Backend>::TensorPrimitive<D>| {
Tensor::from_primitive(primitive.clone())
},
);
tensor
}

fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>) {
self.grads.insert(id, Box::new(value.into_primitive()));
}

fn len(&self) -> usize {
self.grads.len()
}
}

impl Grads {
pub fn register_node<T>(&mut self, node: &BackwardNode<T>)
where
5 changes: 5 additions & 0 deletions burn-derive/src/module/base.rs
@@ -19,6 +19,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let load_fn = param.gen_load_fn();
let inner_fn = param.gen_inner_fn();
let detach_fn = param.gen_detach_fn();
let clone_fn = param.gen_clone_fn();
let generics_names_except_backend = generics_names_except_backend(&ast.generics);

let gen = quote! {
@@ -48,6 +49,10 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
impl #generics std::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}

impl #generics Clone for #name #generics_ty #generics_where {
#clone_fn
}
};

gen.into()
27 changes: 27 additions & 0 deletions burn-derive/src/module/param.rs
@@ -164,6 +164,33 @@ impl Param {
}
}

pub fn gen_clone_fn(&self) -> TokenStream {
let mut body = quote! {};
let mut names = Vec::new();
let mut fields = Vec::new();

fields.append(&mut self.fields_param.clone());
fields.append(&mut self.fields_other.clone());
for field in fields {
let name = field.ident();
names.push(name.clone());

body.extend(quote! {
let #name = self.#name.clone();
});
}

quote! {
fn clone(&self) -> Self {
#body

Self {
#(#names),*
}
}
}
}

pub fn gen_state_fn(&self) -> TokenStream {
let mut body = quote! {
let mut state = burn::module::StateNamed::new();
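
For orientation, here is a hand-expanded sketch of what this gen_clone_fn, combined with the base.rs change above, would emit for a hypothetical module. The struct name and field names below are illustrative and do not appear in the diff.

// Hypothetical input: a derived module struct `MyModule<B: Backend>` with
// fields `linear` and `dropout`. Approximate generated output:
impl<B: Backend> Clone for MyModule<B> {
    fn clone(&self) -> Self {
        let linear = self.linear.clone();
        let dropout = self.dropout.clone();

        Self {
            linear,
            dropout
        }
    }
}
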
4 changes: 1 addition & 3 deletions burn-tensor/src/tensor/backend/base.rs
@@ -1,4 +1,3 @@
use super::Gradients;
use crate::ops::*;
use crate::tensor::Element;

@@ -46,14 +45,13 @@ pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =

pub trait ADBackend: Backend {
type InnerBackend: Backend<Device = Self::Device, Elem = Self::Elem>;
type Gradients: Gradients<Self>;
type Gradients: Send + Sync;

fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Self::Gradients;
fn grad<const D: usize>(
tensor: &Self::TensorPrimitive<D>,
grads: &Self::Gradients,
) -> Option<ADBackendTensorPrimitive<D, Self>>;
fn node_id<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> String;
fn inner<const D: usize>(
tensor: &Self::TensorPrimitive<D>,
) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>;
12 changes: 0 additions & 12 deletions burn-tensor/src/tensor/backend/grad.rs

This file was deleted.

2 changes: 0 additions & 2 deletions burn-tensor/src/tensor/backend/mod.rs
@@ -1,7 +1,5 @@
mod base;
pub use base::*;
mod grad;
pub use grad::*;

// Not needed for now, usefull for different tensor memory layout
// pub mod conversion;
4 changes: 0 additions & 4 deletions burn-tensor/src/tensor/base.rs
@@ -683,10 +683,6 @@ impl<const D: usize, B: ADBackend> Tensor<B, D> {
B::grad(&self.value, grads).map(Tensor::new)
}

pub fn node_id(&self) -> String {
B::node_id(&self.value)
}

pub fn inner(&self) -> Tensor<B::InnerBackend, D> {
Tensor::new(B::inner(&self.value))
}
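
With node_id gone from both the ADBackend trait and Tensor, gradient lookup now flows through the backward pass and Tensor::grad directly. A minimal sketch of the calling side, assuming the usual Tensor::backward wrapper over ADBackend::backward; the variable names are illustrative:

// Sketch only: `loss` and `weight` are tensors built on the same ADBackend.
let grads = loss.backward();        // ADBackend::Gradients
let grad = weight.grad(&grads);     // Option<Tensor<B::InnerBackend, D>>, no string node id involved
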
73 changes: 73 additions & 0 deletions burn-tensor/src/tensor/container.rs
@@ -0,0 +1,73 @@
use crate::{backend::Backend, Tensor};
use std::{any::Any, collections::HashMap};

/// Contains tensor of arbitrary dimension.
#[derive(Debug)]
pub struct TensorContainer<B: Backend, ID> {
tensors: HashMap<ID, Box<dyn Any + Send + Sync>>,
_b: B,
}

impl<B, ID> Default for TensorContainer<B, ID>
where
B: Backend,
ID: std::hash::Hash + PartialEq + Eq,
{
fn default() -> Self {
Self::new()
}
}

impl<B, ID> TensorContainer<B, ID>
where
B: Backend,
ID: std::hash::Hash + PartialEq + Eq,
{
/// Create an empty container.
pub fn new() -> Self {
Self {
tensors: HashMap::new(),
_b: B::default(),
}
}

/// Get a tensor with the given ID.
pub fn get<const D: usize>(&self, id: &ID) -> Option<Tensor<B, D>> {
let grad = match self.tensors.get(id) {
Some(grad) => grad,
None => return None,
};

let tensor = grad
.downcast_ref()
.map(|primitive: &B::TensorPrimitive<D>| Tensor::from_primitive(primitive.clone()));
tensor
}

/// Register a new tensor for the given ID.
///
/// # Notes
///
/// If a tensor is already registered for the given ID, it will be replaced.
pub fn register<const D: usize>(&mut self, id: ID, value: Tensor<B, D>) {
self.tensors.insert(id, Box::new(value.into_primitive()));
}

/// Remove a tensor for the given ID and returns it.
pub fn remove<const D: usize>(&mut self, id: &ID) -> Option<Tensor<B, D>> {
self.tensors
.remove(id)
.map(|item| item.downcast::<B::TensorPrimitive<D>>().unwrap())
.map(|primitive| Tensor::from_primitive(*primitive))
}

/// The number of tensors registered.
pub fn len(&self) -> usize {
self.tensors.len()
}

/// If any tensor is contained.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
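
A minimal usage sketch for the new container; the String ID type and the dimension below are illustrative, since any Hash + Eq ID type satisfies the bounds above.

use burn_tensor::{backend::Backend, container::TensorContainer, Tensor};

// Registers a tensor under a string ID and reads it back; the dimension is
// recovered through the const generic parameter on `get`.
fn round_trip<B: Backend>(tensor: Tensor<B, 2>) -> Option<Tensor<B, 2>> {
    let mut container = TensorContainer::<B, String>::new();
    container.register("weight".to_string(), tensor);
    container.get::<2>(&"weight".to_string())
}
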
3 changes: 2 additions & 1 deletion burn-tensor/src/tensor/mod.rs
@@ -1,4 +1,3 @@
pub mod ops;
pub(crate) mod stats;

mod base;
@@ -15,8 +14,10 @@ pub use shape::*;

pub mod activation;
pub mod backend;
pub mod container;
pub mod loss;
pub mod module;
pub mod ops;

#[cfg(feature = "experimental-named-tensor")]
mod named;
2 changes: 1 addition & 1 deletion burn/src/module/param/base.rs
@@ -3,7 +3,7 @@ use crate::module::{LoadingError, State, StateNamed};
use crate::tensor::Element;

/// Define a trainable parameter.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Param<T> {
pub(super) id: ParamId,
pub(super) value: T,
35 changes: 15 additions & 20 deletions burn/src/optim/base.rs
@@ -1,8 +1,7 @@
use super::visitor::{GradientsLoader, GradientsRegister, ModuleTensorUpdater};
use crate::module::{LoadingError, Module, ParamId, State, StateNamed};
use super::visitor::{GradientsLoader, GradientsParams, GradientsRegister, ModuleTensorUpdater};
use crate::module::{ADModule, LoadingError, Module, ParamId, State, StateNamed};
use crate::tensor::backend::{ADBackend, Backend};
use crate::tensor::{Data, Tensor};
use burn_tensor::backend::Gradients;

pub trait Optimizer: Send + Sync {
type Backend: ADBackend;
@@ -12,13 +11,13 @@ pub trait Optimizer: Send + Sync {
&mut self,
id: &ParamId,
tensor: &mut Tensor<Self::Backend, D>,
grads: &<Self::Backend as ADBackend>::Gradients,
grad: Tensor<<Self::Backend as ADBackend>::InnerBackend, D>,
);

/// Update the parameters of the given module using the given the gradients.
fn update_module<M>(&mut self, module: &mut M, grads: &<Self::Backend as ADBackend>::Gradients)
fn update_module<M>(&mut self, module: &mut M, grads: GradientsParams<M::ADBackend>)
where
M: Module<Backend = Self::Backend>,
M: ADModule<ADBackend = Self::Backend>,
Self: Sized,
{
let mut visitor = ModuleTensorUpdater::new(self, grads);
@@ -92,31 +91,27 @@ pub trait Optimizer: Send + Sync {
}
}

pub(super) fn register_state_gradients<const D: usize, B: ADBackend, F: Fn(&str) -> String>(
pub(super) fn register_state_gradients<const D: usize, B: ADBackend, F: Fn(&ParamId) -> String>(
id: &ParamId,
state: &mut StateNamed<B::Elem>,
grads: &B::Gradients,
grads: &GradientsParams<B>,
id_to_key: F,
) {
let id = id.to_string();

if let Some(velocity) = grads.get::<D>(&id) {
let data = State::Data(velocity.to_data().serialize());
state.register_state(id_to_key(&id).as_str(), data);
if let Some(grad) = grads.get::<D>(id) {
let data = State::Data(grad.to_data().serialize());
state.register_state(id_to_key(id).as_str(), data);
};
}

pub(super) fn load_state_gradients<const D: usize, B: ADBackend, F: Fn(&str) -> String>(
pub(super) fn load_state_gradients<const D: usize, B: ADBackend, F: Fn(&ParamId) -> String>(
id: &ParamId,
state: &StateNamed<B::Elem>,
grads: &mut B::Gradients,
grads: &mut GradientsParams<B>,
id_to_key: F,
device: &B::Device,
) {
let id = id.to_string();

if let Some(State::Data(data)) = state.get(id_to_key(&id).as_str()) {
let velocity = Tensor::<B::InnerBackend, D>::from_data_device(Data::from(data), *device);
grads.register(id, velocity);
if let Some(State::Data(data)) = state.get(id_to_key(id).as_str()) {
let tensor = Tensor::<B::InnerBackend, D>::from_data_device(Data::from(data), *device);
grads.register(id.clone(), tensor);
};
}
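
To show the new calling convention in context, a hedged sketch written as if appended to this same base.rs, so the imports at the top of the diff are in scope; the helper itself is illustrative and not part of the PR.

// Illustrative helper only. Gradients now arrive keyed by ParamId inside
// GradientsParams, and update_module dispatches each parameter's gradient
// tensor to Optimizer::update.
fn apply_update<M, O>(module: &mut M, optimizer: &mut O, grads: GradientsParams<M::ADBackend>)
where
    M: ADModule,
    O: Optimizer<Backend = M::ADBackend>,
{
    optimizer.update_module(module, grads);
}
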
14 changes: 6 additions & 8 deletions burn/src/optim/decay.rs
@@ -1,11 +1,11 @@
use crate as burn;

use super::visitor::GradientsParams;
use super::{load_state_gradients, register_state_gradients};
use crate::config::Config;
use crate::module::{ParamId, StateNamed};
use crate::tensor::backend::ADBackend;
use crate::tensor::{ElementConversion, Tensor};
use burn_tensor::backend::Gradients;

/// Configuration to create [WeightDecay](WeightDecay).
#[derive(Config)]
@@ -17,14 +17,14 @@ pub struct WeightDecayConfig {
/// Weight decay implementation that transforms gradients.
pub struct WeightDecay<B: ADBackend> {
penalty: B::Elem,
gradients: B::Gradients,
gradients: GradientsParams<B>,
}

impl<B: ADBackend> WeightDecay<B> {
pub fn new(config: &WeightDecayConfig) -> Self {
Self {
penalty: config.penalty.to_elem(),
gradients: B::Gradients::empty(),
gradients: GradientsParams::<B>::new(),
}
}

@@ -33,15 +33,13 @@ impl<B: ADBackend> WeightDecay<B> {
id: &ParamId,
grad: Tensor<B::InnerBackend, D>,
) -> Tensor<B::InnerBackend, D> {
let id = id.to_string();

let grad = match self.gradients.get::<D>(&id) {
let grad = match self.gradients.get::<D>(id) {
Some(grad_last_step) => grad_last_step.mul_scalar(self.penalty).add(&grad),
None => grad,
};

// Update gradients
self.gradients.register(id, grad.clone());
self.gradients.register(id.clone(), grad.clone());

grad
}
@@ -58,7 +56,7 @@ impl<B: ADBackend> WeightDecay<B> {
load_state_gradients::<D, B, _>(id, state, &mut self.gradients, Self::state_key, device);
}

fn state_key(id: &str) -> String {
fn state_key(id: &ParamId) -> String {
format!("weight-decay-{}", id)
}
}
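
And a short sketch of driving the decay transform after this change, written as if inside this same decay.rs. The WeightDecayConfig::new constructor and the 5e-4 value are assumptions for illustration; the Config derive's generated constructor is not shown in this diff.

// Illustrative only: per-parameter decay is now keyed by &ParamId instead of a String.
fn decay_example<B: ADBackend, const D: usize>(
    id: &ParamId,
    grad: Tensor<B::InnerBackend, D>,
) -> Tensor<B::InnerBackend, D> {
    // Assumed constructor from the Config derive.
    let mut decay = WeightDecay::<B>::new(&WeightDecayConfig::new(5e-4));
    // The previous gradient stored for this id (if any) is scaled by the penalty
    // and added to the incoming gradient; the result is stored for the next step.
    decay.transform(id, grad)
}
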
(Diff for the remaining changed files not shown.)