
Fix/constant tensors #984

Merged: 11 commits, Nov 21, 2023
24 changes: 24 additions & 0 deletions burn-autodiff/src/backend.rs
@@ -79,4 +79,28 @@ impl<B: Backend> AutodiffBackend for Autodiff<B> {
grads.remove(tensor);
grads.register::<B, D>(tensor.node.clone(), grad);
}

fn int_inner<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
) -> burn_tensor::ops::IntTensor<Self::InnerBackend, D> {
tensor
}

fn bool_inner<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D>,
) -> burn_tensor::ops::BoolTensor<Self::InnerBackend, D> {
tensor
}

fn int_from_inner<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self::InnerBackend, D>,
) -> burn_tensor::ops::IntTensor<Self, D> {
tensor
}

fn bool_from_inner<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self::InnerBackend, D>,
) -> burn_tensor::ops::BoolTensor<Self, D> {
tensor
}
}
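For context, the four methods above are identity conversions: int and bool tensors carry no autodiff graph, so crossing the `Autodiff` wrapper boundary costs nothing. Below is a minimal sketch of what this enables at the `Tensor` level; the backend name and the `from_ints` signature are assumptions based on burn's API around this release, not part of this diff:

```rust
use burn_autodiff::Autodiff;
use burn_ndarray::NdArray;
use burn_tensor::{Int, Tensor};

type AD = Autodiff<NdArray>;

fn main() {
    // Assumed: `from_ints` builds an int tensor on the default device.
    let t: Tensor<AD, 1, Int> = Tensor::from_ints([1, 2, 3]);

    // `inner()` strips the autodiff wrapper; for int/bool kinds it routes
    // through the new `int_inner`/`bool_inner`, which are identity moves.
    let inner: Tensor<NdArray, 1, Int> = t.inner();

    // `from_inner` wraps the tensor back, via `int_from_inner`.
    let _back: Tensor<AD, 1, Int> = Tensor::from_inner(inner);
}
```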
114 changes: 47 additions & 67 deletions burn-core/src/module/base.rs
@@ -1,51 +1,41 @@
use alloc::vec::Vec;

use super::ParamId;
use crate::{
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
use alloc::vec::Vec;
pub use burn_derive::Module;
use burn_tensor::Tensor;
use burn_tensor::{Bool, Int, Tensor};

/// Type alias to `Vec<B::Device>` which supports `no_std` environments by automatically
/// using the `alloc` crate.
pub type Devices<B> = Vec<<B as Backend>::Device>;

// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
(map=$module:ident, ops=$item:expr) => {{
struct Mapper;
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
fn map_float<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D>,
) -> Tensor<B, D> {
let func = $item;
func(tensor)
}
}
let mut mapper = Mapper;
$module.map(&mut mapper)
}};
(map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
struct Mapper<'a, B: Backend> {
capture: &'a $ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let func = $item;
func(tensor, self.capture)
}
}
let mut mapper = Mapper {
capture: $capture,
backend: core::marker::PhantomData,
};
$module.map(&mut mapper)
}};
(visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
(visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
struct Visitor<'a, B: Backend> {
state: &'a mut $state_ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
let func = $item;
func(tensor, &mut self.state)
}
@@ -94,43 +84,17 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record;

/// Get the device list of the module and all of its sub-modules.
fn devices(&self) -> Vec<B::Device> {
module!(
visit = self,
ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
let device = tensor.device();
if !state.contains(&device) {
state.push(device);
}
},
state = Vec<B::Device>,
init = Vec::new
)
}
/// Collect the devices used by the module and all of its sub-modules into the given
/// vector, without duplicates, and return it.
fn devices(&self, devices: Devices<B>) -> Devices<B>;

/// Fork the module and all of its sub-modules to the given device.
///
/// # Notes
///
/// This is similar to [to_device](Module::to_device), but it ensures the module will
/// have its own autodiff graph.
fn fork(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| {
let is_require_grad = tensor.is_require_grad();
let mut tensor = tensor.to_device(device).detach();

if is_require_grad {
tensor = tensor.require_grad();
}

tensor
},
capture = { device: B::Device }
)
}
fn fork(self, device: &B::Device) -> Self;

/// Move the module and all of its sub-modules to the given device.
///
@@ -139,13 +103,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// The device operations will be registered in the autodiff graph. Therefore, be sure to call
/// backward only once, even if you have the same module on multiple devices. If you want to
/// call backward multiple times, look into using [fork](Module::fork) instead.
fn to_device(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
capture = { device: B::Device }
)
}
fn to_device(self, device: &B::Device) -> Self;

/// Each tensor in the module tree will not require grad.
///
Expand All @@ -164,18 +122,18 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize {
module!(
visit = self,
visit_float = self,
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
*state += tensor.shape().num_elements();
},
state = usize,
init = || 0
)
}
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
/// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);

/// Map each tensor in the module with a [mapper](ModuleMapper).
/// Map each tensor parameter in the module with a [mapper](ModuleMapper).
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;

/// Load the module state from a record.
@@ -233,14 +191,36 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {

/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
/// Visit a tensor in the module.
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
/// Visit a float tensor in the module.
fn visit_float<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {}
/// Visit an int tensor in the module.
fn visit_int<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Int>) {}
/// Visit a bool tensor in the module.
fn visit_bool<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Bool>) {}
}
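Since `visit_int` and `visit_bool` now default to no-ops, a visitor only overrides the tensor kinds it cares about. A sketch of a float-only visitor, mirroring how `num_params` is implemented above; the import paths assume burn-core's usual re-exports:

```rust
use burn::module::{ModuleVisitor, ParamId};
use burn::tensor::{backend::Backend, Tensor};

/// Counts float parameter elements; int/bool tensors fall through to the
/// default no-op visit methods.
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
        self.count += tensor.shape().num_elements();
    }
}
```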

/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a tensor in the module.
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
/// Map a float tensor in the module.
fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor
}
/// Map an int tensor in the module.
fn map_int<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Int>,
) -> Tensor<B, D, Int> {
tensor
}
/// Map a bool tensor in the module.
fn map_bool<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Bool>,
) -> Tensor<B, D, Bool> {
tensor
}
}
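The mapper side works the same way: override only the kinds you want to transform, and the default identity implementations handle the rest. A sketch, again assuming burn-core's re-exports, that rescales float parameters and leaves int/bool tensors untouched:

```rust
use burn::module::{ModuleMapper, ParamId};
use burn::tensor::{backend::Backend, Tensor};

struct Scale {
    factor: f32,
}

impl<B: Backend> ModuleMapper<B> for Scale {
    fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
        // Int/bool tensors hit the default identity `map_int`/`map_bool`.
        tensor.mul_scalar(self.factor)
    }
}
```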

/// Module with auto-differentiation backend.
79 changes: 53 additions & 26 deletions burn-core/src/module/param/constant.rs
@@ -1,17 +1,14 @@
use core::marker::PhantomData;

use crate::{
self as burn,
module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor},
module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
record::Record,
};
use burn::record::PrecisionSettings;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
BasicAutodiffOps, BasicOps, Tensor,
};

use super::ParamId;
use core::marker::PhantomData;

/// Record used for constant type implementing the [module](crate::module::Module) trait.
#[derive(Debug, Clone, Copy, new)]
@@ -69,6 +66,18 @@ macro_rules! constant {
fn into_record(self) -> Self::Record {
burn::module::ConstantRecord::new()
}

fn to_device(self, _: &B::Device) -> Self {
self
}

fn fork(self, _: &B::Device) -> Self {
self
}

fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
devices
}
};

(ad_module, $type:ty) => {
@@ -113,27 +122,13 @@ constant!(i32);
constant!(i16);
constant!(i8);

impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
type Record = ConstantRecord;

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
// Important:
// We need to implement visit method for Tensor Module because
// to_device will be called during the visit method of the ModuleVisitor

// We are using a dummy param id because the visit method requires a param id
let dummy_param_id = ParamId::new();
visitor.visit(&dummy_param_id, self)
}

fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
// Important:
// We need to implement visit method for Tensor Module because
// to_device will be called during the visit method of the ModuleVisitor
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}

// We are using a dummy param id because the visit method requires a param id
let dummy_param_id = ParamId::new();
mapper.map(&dummy_param_id, self)
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}

fn into_record(self) -> Self::Record {
@@ -143,10 +138,30 @@ impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
fn load_record(self, _record: Self::Record) -> Self {
self
}

fn to_device(self, device: &B::Device) -> Self {
self.to_device(device)
}

fn fork(self, device: &B::Device) -> Self {
self.to_device(device)
}

fn devices(&self, mut devices: Devices<B>) -> Devices<B> {
let device = self.device();

if !devices.contains(&device) {
devices.push(device)
}

devices
}
}
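Together with the derive macro, this impl is what the PR title refers to: a bare `Tensor` field of any kind acts as a constant, skipped by `visit`/`map` (so optimizers ignore it) but still moved by `to_device`/`fork` and reported by `devices`. A sketch of such a module; `MaskedLinear` and its fields are illustrative:

```rust
use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Bool, Tensor};

#[derive(Module, Debug)]
struct MaskedLinear<B: Backend> {
    weight: Param<Tensor<B, 2>>, // trainable: visited and mapped
    mask: Tensor<B, 2, Bool>,    // constant: skipped by visit/map, moved by to_device
}
```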

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Tensor<B, D> {
type InnerModule = Tensor<B::InnerBackend, D>;
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
for Tensor<B, D, K>
{
type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;

fn valid(&self) -> Self::InnerModule {
self.clone().inner()
Expand All @@ -171,6 +186,18 @@ impl<B: Backend> Module<B> for PhantomData<B> {
fn into_record(self) -> Self::Record {
ConstantRecord::new()
}

fn to_device(self, _: &<B as Backend>::Device) -> Self {
self
}

fn fork(self, _: &<B as Backend>::Device) -> Self {
self
}

fn devices(&self, devices: Devices<B>) -> Devices<B> {
devices
}
}

impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {