feat: inplace tensor api. (tracel-ai#187)
nathanielsimard authored Mar 1, 2023
1 parent 25deb5a commit e6e7f4d
Showing 8 changed files with 83 additions and 30 deletions.
10 changes: 8 additions & 2 deletions burn-autodiff/src/backend.rs
@@ -32,8 +32,8 @@ impl<B: Backend> ADBackend for ADBackendDecorator<B> {
     type InnerBackend = B;
     type Gradients = Gradients;
 
-    fn backward<const D: usize>(tensor: &ADTensor<B, D>) -> Gradients {
-        backward(tensor.clone())
+    fn backward<const D: usize>(tensor: ADTensor<B, D>) -> Gradients {
+        backward(tensor)
     }
 
     fn grad<const D: usize>(
@@ -43,6 +43,12 @@ impl<B: Backend> ADBackend for ADBackendDecorator<B> {
         grads.get(tensor)
     }
 
+    fn grad_remove<const D: usize>(
+        tensor: &ADTensor<B, D>,
+        grads: &mut Gradients,
+    ) -> Option<B::TensorPrimitive<D>> {
+        grads.remove(tensor)
+    }
     fn inner<const D: usize>(tensor: &ADTensor<B, D>) -> B::TensorPrimitive<D> {
         tensor.primitive.clone()
     }
10 changes: 10 additions & 0 deletions burn-autodiff/src/grads.rs
@@ -48,6 +48,16 @@ impl Gradients {
         }
     }
 
+    /// Remove a grad tensor from the container.
+    pub fn remove<B: Backend, const D: usize>(
+        &mut self,
+        tensor: &ADTensor<B, D>,
+    ) -> Option<TensorPrimitive<B, D>> {
+        self.container
+            .remove::<B, D>(&tensor.node.id.value)
+            .map(|tensor| tensor.into_primitive())
+    }
+
     pub fn get<B: Backend, const D: usize>(
         &self,
         tensor: &ADTensor<B, D>,
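A minimal sketch (not part of the commit) of the take-once semantics that `remove` introduces, written as if it lived inside burn-autodiff with the same imports as `grads.rs` above: unlike `get`, which leaves the gradient in the container, `remove` moves it out, so a second call for the same tensor yields `None`.

// Crate-internal sketch only; `Gradients`, `ADTensor` and `TensorPrimitive`
// are the types shown in the diff above, with their imports assumed in scope.
fn take_once<B: Backend, const D: usize>(
    grads: &mut Gradients,
    tensor: &ADTensor<B, D>,
) -> Option<TensorPrimitive<B, D>> {
    // First call moves the gradient out of the container (if one was registered).
    let taken = grads.remove::<B, D>(tensor);
    // A second call for the same tensor now returns None.
    debug_assert!(grads.remove::<B, D>(tensor).is_none());
    taken
}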
6 changes: 3 additions & 3 deletions burn-common/src/id.rs
@@ -37,7 +37,7 @@ mod tests {
 
     #[test]
     fn not_empty_test() {
-        assert_eq!(IdGenerator::generate().is_empty(), false);
+        assert!(!IdGenerator::generate().is_empty());
     }
 
#[test]
@@ -47,7 +47,7 @@
         let mut set: BTreeSet<String> = BTreeSet::new();
 
         for _i in 0..IDS_CNT {
-            assert_eq!(set.insert(IdGenerator::generate()), true);
+            assert!(set.insert(IdGenerator::generate()));
         }
 
         assert_eq!(set.len(), IDS_CNT);
@@ -69,7 +69,7 @@
 
            let handle = thread::spawn(move || {
                for _i in 0..NUM_REPEATS {
-                   assert_eq!(set.insert(IdGenerator::generate()), true);
+                   assert!(set.insert(IdGenerator::generate()));
                }
            });
            handles.push(handle);
18 changes: 10 additions & 8 deletions burn-core/src/optim/adam.rs
@@ -64,14 +64,16 @@ impl<B: ADBackend> Optimizer for Adam<B> {
         tensor: &mut Tensor<B, D>,
         grad: Tensor<B::InnerBackend, D>,
     ) {
-        let grad = match &mut self.weight_decay {
-            Some(weight_decay) => weight_decay.transform(id, grad),
-            None => grad,
-        };
-        let grad = self.momentum.transform::<B, D>(id, grad);
-
-        let delta = grad.mul_scalar(self.learning_rate);
-        tensor.update(tensor.inner() - delta);
+        tensor.inplace(grad, |tensor, grad| {
+            let grad = match &mut self.weight_decay {
+                Some(weight_decay) => weight_decay.transform(id, grad),
+                None => grad,
+            };
+            let grad = self.momentum.transform::<B, D>(id, grad);
+            let delta = grad.mul_scalar(self.learning_rate);
+
+            Tensor::from_inner(tensor.inner() - delta)
+        })
     }
 
     fn register_param_state<const D: usize>(&self, id: &ParamId, state: &mut StateNamed<B::Elem>) {
24 changes: 13 additions & 11 deletions burn-core/src/optim/sgd.rs
@@ -54,17 +54,19 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
         tensor: &mut Tensor<B, D>,
         grad: Tensor<B::InnerBackend, D>,
     ) {
-        let grad = match &mut self.weight_decay {
-            Some(weight_decay) => weight_decay.transform(id, grad),
-            None => grad,
-        };
-        let grad = match &mut self.momentum {
-            Some(momentum) => momentum.transform(id, grad),
-            None => grad,
-        };
-
-        let delta = grad.mul_scalar(self.learning_rate);
-        tensor.update(tensor.inner() - delta);
+        tensor.inplace(grad, |tensor, grad| {
+            let grad = match &mut self.weight_decay {
+                Some(weight_decay) => weight_decay.transform(id, grad),
+                None => grad,
+            };
+            let grad = match &mut self.momentum {
+                Some(momentum) => momentum.transform(id, grad),
+                None => grad,
+            };
+            let delta = grad.mul_scalar(self.learning_rate);
+
+            Tensor::from_inner(tensor.inner() - delta)
+        })
     }
 
     fn register_param_state<const D: usize>(&self, id: &ParamId, state: &mut StateNamed<B::Elem>) {
4 changes: 2 additions & 2 deletions burn-core/src/optim/visitor.rs
@@ -49,7 +49,7 @@ impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitorMut<B>
     for ModuleTensorUpdater<'a, O>
 {
     fn visit_mut<const D: usize>(&mut self, id: &ParamId, tensor: &mut Tensor<B, D>) {
-        if let Some(grad) = self.grads.get::<B::InnerBackend, D>(id) {
+        if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
             self.optimizer.update_tensor(id, tensor, grad);
         }
     }
@@ -64,7 +64,7 @@ impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitor<B> for Gradients
 
 impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
     fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
-        if let Some(grad) = tensor.grad(&self.grads) {
+        if let Some(grad) = tensor.grad_remove(&mut self.grads) {
             self.grads_params
                 .register::<B::InnerBackend, D>(id.clone(), grad);
         }
6 changes: 5 additions & 1 deletion burn-tensor/src/tensor/backend/base.rs
@@ -46,11 +46,15 @@ pub trait ADBackend: Backend {
     type InnerBackend: Backend<Device = Self::Device, Elem = Self::Elem>;
     type Gradients: Send + Sync;
 
-    fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Self::Gradients;
+    fn backward<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Gradients;
     fn grad<const D: usize>(
         tensor: &Self::TensorPrimitive<D>,
         grads: &Self::Gradients,
     ) -> Option<ADBackendTensorPrimitive<D, Self>>;
+    fn grad_remove<const D: usize>(
+        tensor: &Self::TensorPrimitive<D>,
+        grads: &mut Self::Gradients,
+    ) -> Option<ADBackendTensorPrimitive<D, Self>>;
     fn inner<const D: usize>(
         tensor: &Self::TensorPrimitive<D>,
     ) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>;
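A hedged sketch (not from the commit) of how the revised trait reads from calling code: a helper generic over any `B: ADBackend` that runs the now by-value `backward` on a root primitive and immediately drains its gradient with the new `grad_remove`. The `use` paths are assumptions.

use burn_tensor::backend::{ADBackend, ADBackendTensorPrimitive}; // paths assumed

fn backward_and_take<B: ADBackend, const D: usize>(
    root: B::TensorPrimitive<D>,
) -> Option<ADBackendTensorPrimitive<D, B>> {
    // `backward` now consumes the primitive, so clone it while the root is still needed.
    let mut grads = B::backward::<D>(root.clone());
    // Move the gradient out of the container instead of cloning it.
    B::grad_remove::<D>(&root, &mut grads)
}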
35 changes: 32 additions & 3 deletions burn-tensor/src/tensor/base.rs
@@ -460,6 +460,11 @@ where
         Self::new(tensor)
     }
 
+    /// Create an empty tensor of the given shape.
+    pub fn empty<S: Into<Shape<D>>>(shape: S) -> Self {
+        Tensor::new(B::empty(shape.into(), &B::Device::default()))
+    }
+
     /// Create a tensor of the given shape where each element is one.
     pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: &B::Device) -> Self {
         let tensor = B::ones(shape.into(), device);
@@ -731,19 +736,43 @@ where
 
 impl<const D: usize, B: ADBackend> Tensor<B, D> {
     pub fn backward(&self) -> B::Gradients {
-        B::backward::<D>(&self.value)
+        B::backward::<D>(self.value.clone())
     }
 
+    /// Get the gradients of a tensor if it exists.
+    ///
+    /// Returns a new reference to the same tensor, so the same grad tensor can
+    /// be accessed multiple times. If you only need to read the gradients once,
+    /// consider using [grad_remove](Tensor::grad_remove) for better performance.
     pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
         B::grad(&self.value, grads).map(Tensor::new)
     }
 
+    /// Remove the grad tensor from the [grads](ADBackend::Gradients) struct, returning the result.
+    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
+        B::grad_remove(&self.value, grads).map(Tensor::new)
+    }
+
     pub fn inner(&self) -> Tensor<B::InnerBackend, D> {
         Tensor::new(B::inner(&self.value))
     }
 
-    pub fn update(&mut self, other_inner: Tensor<B::InnerBackend, D>) {
-        self.value = B::from_inner(other_inner.value);
+    /// Executes an operation on the tensor and modifies its value.
+    ///
+    /// # Notes
+    ///
+    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
+    /// no other reference pointing to the same tensor.
+    ///
+    /// Wrapping operations with `inplace` is not an optimization; it is mainly there if you
+    /// want to mutate a tensor by using owned operations. A plausible usage would be to
+    /// update the weights of a mutable model reference.
+    pub fn inplace<F: FnOnce(Self, Args) -> Self, Args>(&mut self, args: Args, func: F) {
+        let mut tensor_owned = Tensor::empty([0; D]);
+        core::mem::swap(&mut tensor_owned, self);
+
+        let mut tensor_new = func(tensor_owned, args);
+        core::mem::swap(&mut tensor_new, self);
     }
 
     pub fn from_inner(inner: Tensor<B::InnerBackend, D>) -> Self {
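A hedged usage sketch (not in the commit) that ties the new pieces together: `grad_remove` drains the gradient from the container once, and `inplace` rewrites the parameter through an owned closure, mirroring what the Adam and SGD optimizers above now do. The function name `sgd_like_step` is illustrative, the `use` paths are assumptions, and learning-rate scaling is omitted for brevity.

use burn_tensor::{backend::ADBackend, Tensor}; // paths assumed

fn sgd_like_step<B: ADBackend, const D: usize>(param: &mut Tensor<B, D>, grads: &mut B::Gradients) {
    // The gradient can only be taken out of the container once per backward pass.
    if let Some(grad) = param.grad_remove(grads) {
        // `inplace` swaps the tensor out, hands the owned value to the closure,
        // and swaps the closure's result back in.
        param.inplace(grad, |param, grad| {
            Tensor::from_inner(param.inner() - grad)
        });
    }
}

Because `inplace` swaps in a placeholder built with `Tensor::empty([0; D])`, the closure must return a tensor to put back; the point is to enable owned operations behind a `&mut` reference, not to guarantee buffer reuse.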
