Fix/constant tensors (tracel-ai#984)
* Generalize autodiff tensor

* Can have int const module

* Update example

* Support no-std with burn-import

* Fix typos

* Fix alloc problems

* Revert burn-import changes

* Fix examples

* Support Int and Bool Params

* Fix

* Add comment
nathanielsimard authored Nov 21, 2023
1 parent 2f079e9 commit cabbaab
Showing 25 changed files with 679 additions and 181 deletions.
24 changes: 24 additions & 0 deletions burn-autodiff/src/backend.rs
@@ -79,4 +79,28 @@ impl<B: Backend> AutodiffBackend for Autodiff<B> {
grads.remove(tensor);
grads.register::<B, D>(tensor.node.clone(), grad);
}

fn int_inner<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
) -> burn_tensor::ops::IntTensor<Self::InnerBackend, D> {
tensor
}

fn bool_inner<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D>,
) -> burn_tensor::ops::BoolTensor<Self::InnerBackend, D> {
tensor
}

fn int_from_inner<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self::InnerBackend, D>,
) -> burn_tensor::ops::IntTensor<Self, D> {
tensor
}

fn bool_from_inner<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self::InnerBackend, D>,
) -> burn_tensor::ops::BoolTensor<Self, D> {
tensor
}
}
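
A minimal usage sketch of the new hooks (not part of this commit; the helper name strip_graph and the fully generic signature are illustrative). Since integer and boolean tensors carry no gradient graph, the conversions are simple pass-throughs, as the bodies above show:

use burn_autodiff::Autodiff;
use burn_tensor::backend::{AutodiffBackend, Backend};
use burn_tensor::ops::IntTensor;

// Round-trip an int tensor through the inner backend and back.
fn strip_graph<B: Backend, const D: usize>(
    tensor: IntTensor<Autodiff<B>, D>,
) -> IntTensor<Autodiff<B>, D> {
    // Both calls return the tensor unchanged: int tensors are never tracked.
    let inner = <Autodiff<B> as AutodiffBackend>::int_inner(tensor);
    <Autodiff<B> as AutodiffBackend>::int_from_inner(inner)
}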
114 changes: 47 additions & 67 deletions burn-core/src/module/base.rs
@@ -1,51 +1,41 @@
use alloc::vec::Vec;

use super::ParamId;
use crate::{
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
use alloc::vec::Vec;
pub use burn_derive::Module;
use burn_tensor::Tensor;
use burn_tensor::{Bool, Int, Tensor};

/// Type alias to `Vec<B::Device>`, which supports `no_std` environments by automatically using
/// the `alloc` crate.
pub type Devices<B> = Vec<<B as Backend>::Device>;

// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
(map=$module:ident, ops=$item:expr) => {{
struct Mapper;
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
fn map_float<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D>,
) -> Tensor<B, D> {
let func = $item;
func(tensor)
}
}
let mut mapper = Mapper;
$module.map(&mut mapper)
}};
(map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
struct Mapper<'a, B: Backend> {
capture: &'a $ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let func = $item;
func(tensor, self.capture)
}
}
let mut mapper = Mapper {
capture: $capture,
backend: core::marker::PhantomData,
};
$module.map(&mut mapper)
}};
(visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
(visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
struct Visitor<'a, B: Backend> {
state: &'a mut $state_ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
let func = $item;
func(tensor, &mut self.state)
}
@@ -94,43 +84,17 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record;

/// Get the device list of the module and all of its sub-modules.
fn devices(&self) -> Vec<B::Device> {
module!(
visit = self,
ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
let device = tensor.device();
if !state.contains(&device) {
state.push(device);
}
},
state = Vec<B::Device>,
init = Vec::new
)
}
/// Collects the devices used by the module and all of its sub-modules into the given vector,
/// without duplicates, and returns it.
fn devices(&self, devices: Devices<B>) -> Devices<B>;

/// Fork the module and all of its sub-modules to the given device.
///
/// # Notes
///
/// This is similar to [to_device](Module::to_device), but it ensures the module will
/// have its own autodiff graph.
fn fork(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| {
let is_require_grad = tensor.is_require_grad();
let mut tensor = tensor.to_device(device).detach();

if is_require_grad {
tensor = tensor.require_grad();
}

tensor
},
capture = { device: B::Device }
)
}
fn fork(self, device: &B::Device) -> Self;

/// Move the module and all of its sub-modules to the given device.
///
@@ -139,13 +103,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// The device operations will be registered in the autodiff graph. Therefore, be sure to call
/// backward only one time even if you have the same module on multiple devices. If you want to
/// call backward multiple times, look into using [fork](Module::fork) instead.
fn to_device(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
capture = { device: B::Device }
)
}
fn to_device(self, device: &B::Device) -> Self;

/// Each tensor in the module tree will not require grad.
///
@@ -164,18 +122,18 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize {
module!(
visit = self,
visit_float = self,
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
*state += tensor.shape().num_elements();
},
state = usize,
init = || 0
)
}
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
/// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);

/// Map each tensor in the module with a [mapper](ModuleMapper).
/// Map each tensor parameter in the module with a [mapper](ModuleMapper).
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;

/// Load the module state from a record.
@@ -233,14 +191,36 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {

/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
/// Visit a tensor in the module.
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
/// Visit a float tensor in the module.
fn visit_float<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {}
/// Visit an int tensor in the module.
fn visit_int<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Int>) {}
/// Visit a bool tensor in the module.
fn visit_bool<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Bool>) {}
}

/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a tensor in the module.
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
/// Map a float tensor in the module.
fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor
}
/// Map an int tensor in the module.
fn map_int<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Int>,
) -> Tensor<B, D, Int> {
tensor
}
/// Map a bool tensor in the module.
fn map_bool<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Bool>,
) -> Tensor<B, D, Bool> {
tensor
}
}
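
As a usage note for the ModuleVisitor and ModuleMapper traits above (a sketch, not part of this commit; the struct name IntParamCounter is illustrative): because every method now has a no-op or identity default, an implementation only overrides the tensor kinds it cares about.

struct IntParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for IntParamCounter {
    fn visit_int<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D, Int>) {
        // Only int parameters are counted; float and bool use the empty defaults.
        self.count += tensor.shape().num_elements();
    }
}

// Typical call site:
// let mut counter = IntParamCounter { count: 0 };
// module.visit(&mut counter);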

/// Module with auto-differentiation backend.
79 changes: 53 additions & 26 deletions burn-core/src/module/param/constant.rs
@@ -1,17 +1,14 @@
use core::marker::PhantomData;

use crate::{
self as burn,
module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor},
module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
record::Record,
};
use burn::record::PrecisionSettings;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
BasicAutodiffOps, BasicOps, Tensor,
};

use super::ParamId;
use core::marker::PhantomData;

/// Record used for constant type implementing the [module](crate::module::Module) trait.
#[derive(Debug, Clone, Copy, new)]
@@ -69,6 +66,18 @@ macro_rules! constant {
fn into_record(self) -> Self::Record {
burn::module::ConstantRecord::new()
}

fn to_device(self, _: &B::Device) -> Self {
self
}

fn fork(self, _: &B::Device) -> Self {
self
}

fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
devices
}
};

(ad_module, $type:ty) => {
@@ -113,27 +122,13 @@ constant!(i32);
constant!(i16);
constant!(i8);

impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
type Record = ConstantRecord;

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
// Important:
// We need to implement visit method for Tensor Module because
// to_device will be called during the visit method of the ModuleVisitor

// We are using a dummy param id because the visit method requires a param id
let dummy_param_id = ParamId::new();
visitor.visit(&dummy_param_id, self)
}

fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
// Important:
// We need to implement visit method for Tensor Module because
// to_device will be called during the visit method of the ModuleVisitor
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}

// We are using a dummy param id because the visit method requires a param id
let dummy_param_id = ParamId::new();
mapper.map(&dummy_param_id, self)
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}

fn into_record(self) -> Self::Record {
@@ -143,10 +138,30 @@ impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
fn load_record(self, _record: Self::Record) -> Self {
self
}

fn to_device(self, device: &B::Device) -> Self {
self.to_device(device)
}

fn fork(self, device: &B::Device) -> Self {
self.to_device(device)
}

fn devices(&self, mut devices: Devices<B>) -> Devices<B> {
let device = self.device();

if !devices.contains(&device) {
devices.push(device)
}

devices
}
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Tensor<B, D> {
type InnerModule = Tensor<B::InnerBackend, D>;
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
for Tensor<B, D, K>
{
type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;

fn valid(&self) -> Self::InnerModule {
self.clone().inner()
@@ -171,6 +186,18 @@ impl<B: Backend> Module<B> for PhantomData<B> {
fn into_record(self) -> Self::Record {
ConstantRecord::new()
}

fn to_device(self, _: &<B as Backend>::Device) -> Self {
self
}

fn fork(self, _: &<B as Backend>::Device) -> Self {
self
}

fn devices(&self, devices: Devices<B>) -> Devices<B> {
devices
}
}

impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
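
To illustrate what the generic Module impl for Tensor<B, D, K> enables at the user level (a sketch, not from this commit; the module and field names are hypothetical): non-float tensors can now appear directly as constant fields of a derived module. They are moved by to_device/fork, reported by devices, and skipped by the record, since ConstantRecord holds no data.

use burn::module::Module;
use burn::tensor::{backend::Backend, Bool, Int, Tensor};

#[derive(Module, Debug)]
struct EmbeddingMask<B: Backend> {
    // Constant tensors: not registered as Param, not persisted in the record.
    positions: Tensor<B, 1, Int>,
    valid: Tensor<B, 2, Bool>,
}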