[Breaking] Refactor Backend Names (tracel-ai#904)
nathanielsimard authored Oct 29, 2023
1 parent e2a3329 commit 96524d4
Showing 185 changed files with 1,734 additions and 1,788 deletions.
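In practical terms, this commit is a set of type renames across the backend crates. As a rough migration sketch (the re-export paths used below are assumptions mirroring the `burn::backend` imports touched in this diff, not an official migration guide):

```rust
// Hedged sketch of the renames introduced here; exact re-export paths are an
// assumption based on the imports changed in the files below.
use burn::backend::{Autodiff, NdArray};

// Before this commit, the same alias would have been written as:
// type MyBackend = ADBackendDecorator<NdArrayBackend<f32>>;
type MyBackend = Autodiff<NdArray<f32>>;

// Other renames that show up in the files below:
//   WgpuBackend   -> Wgpu              TchBackend    -> LibTorch
//   TchDevice     -> LibTorchDevice    CandleBackend -> Candle
//   ADBackend     -> AutodiffBackend   ADTensor      -> AutodiffTensor
```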
4 changes: 2 additions & 2 deletions ARCHITECTURE.md
@@ -207,7 +207,7 @@ Since float and int can have multiple precisions, the float and int element type
Note that the backend chooses the precision and not the user.
Since not all backends will support the same element types, no assumptions must be made.
Therefore, there are no methods on tensors to change the precision, except for the `to_full_precision` function, which ensures numerical stability on the current backend.
Backend implementations can provide a way to choose the precision, which can be accomplished with a generic parameter (e.g. `NdArrayBackend<f32>`).
Backend implementations can provide a way to choose the precision, which can be accomplished with a generic parameter (e.g. `NdArray<f32>`).

To be as general as possible, tensor operations are implemented as plain functions.
There is no object or self, just functions that take tensors as input and often return tensors as output as well.
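To make the precision point concrete, a minimal sketch (assuming the renamed `NdArray` backend keeps an element-type parameter, as the example above suggests): the precision is fixed once per backend type, not per tensor.

```rust
use burn::backend::NdArray;

// Precision is chosen at the backend level via the generic parameter;
// individual tensors cannot switch element type on this backend.
type Cpu32 = NdArray<f32>; // single-precision floats
type Cpu64 = NdArray<f64>; // double-precision floats
```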
@@ -218,5 +218,5 @@ Note that Burn is a dynamic graph deep learning framework, so backends may have

As of now, there is only one backend decorator that supports autodiff.
It follows the decorator pattern, making any backend differentiable.
However, the `ADBackend` trait abstracts how gradients are calculated, and other approaches to autodiff might be added later.
However, the `AutodiffBackend` trait abstracts how gradients are calculated, and other approaches to autodiff might be added later.
For more information about how the current autodiff backend works, you can read this [blog post](https://burn-rs.github.io/blog/burn-rusty-approach-to-tensor-handling).
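For a feel of the decorator in user code, here is a hedged sketch written against a recent Burn release (the tensor-creation API shown, `ones(shape, &device)` and `require_grad`, is an assumption about the current API, not something this commit touches):

```rust
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::{Autodiff, NdArray};
use burn::tensor::Tensor;

// Autodiff decorates a base backend and adds backpropagation to it.
type B = Autodiff<NdArray<f32>>;

fn main() {
    let device = NdArrayDevice::Cpu;
    let x = Tensor::<B, 2>::ones([2, 2], &device).require_grad();
    let y = x.clone().mul_scalar(3.0).sum();

    let grads = y.backward(); // provided through the AutodiffBackend trait
    if let Some(dx) = x.grad(&grads) {
        // d(sum(3 * x))/dx is 3 everywhere; `dx` lives on the inner backend.
        println!("{}", dx);
    }
}
```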
30 changes: 15 additions & 15 deletions backend-comparison/src/lib.rs
@@ -3,28 +3,28 @@ macro_rules! bench_on_backend {
() => {
#[cfg(feature = "wgpu")]
{
use burn::backend::wgpu::{AutoGraphicsApi, WgpuBackend, WgpuDevice};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};

bench::<WgpuBackend<AutoGraphicsApi, f32, i32>>(&WgpuDevice::default());
bench::<Wgpu<AutoGraphicsApi, f32, i32>>(&WgpuDevice::default());
}

#[cfg(feature = "tch-gpu")]
{
use burn::backend::{tch::TchDevice, TchBackend};
use burn::backend::{libtorch::LibTorchDevice, LibTorch};

#[cfg(not(target_os = "macos"))]
let device = TchDevice::Cuda(0);
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = TchDevice::Mps;
bench::<TchBackend>(&device);
let device = LibTorchDevice::Mps;
bench::<LibTorch>(&device);
}

#[cfg(feature = "tch-cpu")]
{
use burn::backend::{tch::TchDevice, TchBackend};
use burn::backend::{libtorch::LibTorchDevice, LibTorch};

let device = TchDevice::Cpu;
bench::<TchBackend>(&device);
let device = LibTorchDevice::Cpu;
bench::<LibTorch>(&device);
}

#[cfg(any(
@@ -35,28 +35,28 @@ macro_rules! bench_on_backend {
))]
{
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::NdArrayBackend;
use burn::backend::NdArray;

let device = NdArrayDevice::Cpu;
bench::<NdArrayBackend>(&device);
bench::<NdArray>(&device);
}

#[cfg(feature = "candle-cpu")]
{
use burn::backend::candle::CandleDevice;
use burn::backend::CandleBackend;
use burn::backend::Candle;

let device = CandleDevice::Cpu;
bench::<CandleBackend>(&device);
bench::<Candle>(&device);
}

#[cfg(feature = "candle-cuda")]
{
use burn::backend::candle::CandleDevice;
use burn::backend::CandleBackend;
use burn::backend::Candle;

let device = CandleDevice::Cuda(0);
bench::<CandleBackend>(&device);
bench::<Candle>(&device);
}
};
}
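For orientation, each `#[cfg]` branch above expands to a `bench::<SomeBackend>(&device)` call with the newly named types; a hypothetical caller would look like the sketch below (the `bench` signature and the crate-root macro path are assumptions inferred from the macro body, not shown in this diff):

```rust
use burn::tensor::backend::Backend;

// The macro expects a generic `bench` function like this in the calling scope.
fn bench<B: Backend>(device: &B::Device) {
    // ... build tensors on `device` and time the workload for backend `B` ...
    let _ = device;
}

fn main() {
    // Expands to one `bench::<...>(&device)` call per enabled backend feature.
    backend_comparison::bench_on_backend!();
}
```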
38 changes: 21 additions & 17 deletions burn-autodiff/src/backend.rs
@@ -1,19 +1,23 @@
use crate::{grads::Gradients, graph::backward::backward, tensor::ADTensor};
use burn_tensor::backend::{ADBackend, Backend};

/// A decorator for a backend that enables automatic differentiation.
use crate::{grads::Gradients, graph::backward::backward, tensor::AutodiffTensor};
use burn_tensor::backend::{AutodiffBackend, Backend};
use core::marker::PhantomData;

/// Enable auto-differentiation on a backend.
///
/// This works as a backend decorator, extending the functionality of any backend with
/// backpropagation.
#[derive(Clone, Copy, Debug, Default)]
pub struct ADBackendDecorator<B> {
_b: B,
pub struct Autodiff<B> {
_b: PhantomData<B>,
}

impl<B: Backend> Backend for ADBackendDecorator<B> {
impl<B: Backend> Backend for Autodiff<B> {
type Device = B::Device;

type FullPrecisionElem = B::FullPrecisionElem;
type FullPrecisionBackend = ADBackendDecorator<B::FullPrecisionBackend>;
type FullPrecisionBackend = Autodiff<B::FullPrecisionBackend>;

type TensorPrimitive<const D: usize> = ADTensor<B, D>;
type TensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
type FloatElem = B::FloatElem;

type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;
@@ -38,37 +42,37 @@ impl<B: Backend> Backend for ADBackendDecorator<B> {
}
}

impl<B: Backend> ADBackend for ADBackendDecorator<B> {
impl<B: Backend> AutodiffBackend for Autodiff<B> {
type InnerBackend = B;
type Gradients = Gradients;

fn backward<const D: usize>(tensor: ADTensor<B, D>) -> Gradients {
fn backward<const D: usize>(tensor: AutodiffTensor<B, D>) -> Gradients {
backward(tensor)
}

fn grad<const D: usize>(
tensor: &ADTensor<B, D>,
tensor: &AutodiffTensor<B, D>,
grads: &Gradients,
) -> Option<B::TensorPrimitive<D>> {
grads.get(tensor)
}

fn grad_remove<const D: usize>(
tensor: &ADTensor<B, D>,
tensor: &AutodiffTensor<B, D>,
grads: &mut Gradients,
) -> Option<B::TensorPrimitive<D>> {
grads.remove(tensor)
}
fn inner<const D: usize>(tensor: ADTensor<B, D>) -> B::TensorPrimitive<D> {
fn inner<const D: usize>(tensor: AutodiffTensor<B, D>) -> B::TensorPrimitive<D> {
tensor.primitive
}

fn from_inner<const D: usize>(tensor: B::TensorPrimitive<D>) -> ADTensor<B, D> {
ADTensor::new(tensor)
fn from_inner<const D: usize>(tensor: B::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
AutodiffTensor::new(tensor)
}

fn grad_replace<const D: usize>(
tensor: &ADTensor<B, D>,
tensor: &AutodiffTensor<B, D>,
grads: &mut Self::Gradients,
grad: B::TensorPrimitive<D>,
) {
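One change in burn-autodiff/src/backend.rs above goes beyond renaming: the decorator now stores `PhantomData<B>` instead of a `B` value. A minimal, general-Rust sketch of what that buys (the `Decorator` type here is illustrative, not from the codebase):

```rust
use core::marker::PhantomData;

// PhantomData<B> ties the struct to B at the type level without ever storing
// or constructing a B value, so the decorator stays zero-sized.
#[derive(Clone, Copy, Debug, Default)]
struct Decorator<B> {
    _b: PhantomData<B>,
}

fn main() {
    let _d: Decorator<u32> = Decorator::default();
    assert_eq!(core::mem::size_of::<Decorator<u32>>(), 0);
}
```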
6 changes: 3 additions & 3 deletions burn-autodiff/src/grads.rs
@@ -2,7 +2,7 @@ use burn_tensor::{backend::Backend, container::TensorContainer, Tensor};

use crate::{
graph::{NodeRef, Requirement},
tensor::ADTensor,
tensor::AutodiffTensor,
};

/// Gradient identifier.
@@ -54,7 +54,7 @@ impl Gradients {
/// Removes a grad tensor from the container.
pub fn remove<B: Backend, const D: usize>(
&mut self,
tensor: &ADTensor<B, D>,
tensor: &AutodiffTensor<B, D>,
) -> Option<TensorPrimitive<B, D>> {
self.container
.remove::<B, D>(&tensor.node.id.value)
@@ -64,7 +64,7 @@ impl Gradients {
/// Gets a grad tensor from the container.
pub fn get<B: Backend, const D: usize>(
&self,
tensor: &ADTensor<B, D>,
tensor: &AutodiffTensor<B, D>,
) -> Option<TensorPrimitive<B, D>> {
self.container
.get::<B, D>(&tensor.node.id.value)
4 changes: 2 additions & 2 deletions burn-autodiff/src/graph/backward.rs
@@ -1,10 +1,10 @@
use burn_tensor::backend::Backend;

use crate::{grads::Gradients, tensor::ADTensor};
use crate::{grads::Gradients, tensor::AutodiffTensor};

use super::{traversal::BreadthFirstSearch, Graph, NodeRef, StepBoxed};

pub fn backward<B: Backend, const D: usize>(root: ADTensor<B, D>) -> Gradients {
pub fn backward<B: Backend, const D: usize>(root: AutodiffTensor<B, D>) -> Gradients {
let grads = Gradients::new::<B, D>(root.node.clone(), root.primitive);
let tape = build_tape(root.node, root.graph);

14 changes: 8 additions & 6 deletions burn-autodiff/src/ops/activation.rs
@@ -1,13 +1,15 @@
use crate::{
grads::Gradients,
ops::{unary, Backward, Ops, OpsKind},
tensor::ADTensor,
ADBackendDecorator,
Autodiff,
};
use burn_tensor::{
backend::Backend,
ops::{ActivationOps, FloatTensor},
};
use burn_tensor::{backend::Backend, ops::ActivationOps};

impl<B: Backend> ActivationOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn gelu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
impl<B: Backend> ActivationOps<Autodiff<B>> for Autodiff<B> {
fn gelu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct Gelu<const D: usize>;

@@ -32,7 +34,7 @@ impl<B: Backend> ActivationOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
}
}

fn relu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct Relu;

16 changes: 10 additions & 6 deletions burn-autodiff/src/ops/base.rs
@@ -4,7 +4,7 @@ use crate::{
graph::{
NodeRef, Requirement, {Graph, Step},
},
tensor::ADTensor,
tensor::AutodiffTensor,
};
use burn_tensor::{backend::Backend, Shape};
use std::marker::PhantomData;
@@ -37,7 +37,7 @@ where
BO: Backward<B, D, N, State = ()>,
{
/// Prepare a stateless operation.
pub fn stateless(self, output: <B as Backend>::TensorPrimitive<D>) -> ADTensor<B, D> {
pub fn stateless(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
match self.stateful() {
OpsKind::Tracked(prep) => prep.finish((), output),
OpsKind::UnTracked(prep) => prep.finish(output),
Expand Down Expand Up @@ -77,8 +77,8 @@ where
BO: Backward<B, D, N, State = S>,
{
/// Finish the preparation of an untracked operation and returns the output tensor.
pub fn finish(self, output: <B as Backend>::TensorPrimitive<D>) -> ADTensor<B, D> {
ADTensor::from_parents(
pub fn finish(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
AutodiffTensor::from_parents(
output,
&self.nodes,
self.graphs.into_iter(),
@@ -94,8 +94,12 @@ where
BO: Backward<B, D, N, State = S>,
{
/// Finish the preparation of a tracked operation and returns the output tensor.
pub fn finish(self, state: S, output: <B as Backend>::TensorPrimitive<D>) -> ADTensor<B, D> {
let output = ADTensor::from_parents(
pub fn finish(
self,
state: S,
output: <B as Backend>::TensorPrimitive<D>,
) -> AutodiffTensor<B, D> {
let output = AutodiffTensor::from_parents(
output,
&self.nodes,
self.graphs.into_iter(),
32 changes: 15 additions & 17 deletions burn-autodiff/src/ops/bool_tensor.rs
@@ -1,12 +1,13 @@
use crate::{
tensor::{ADTensor, BoolTensor, IntTensor},
ADBackendDecorator,
};
use crate::{tensor::AutodiffTensor, Autodiff};

use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Reader, Shape};
use burn_tensor::{
backend::Backend,
ops::{BoolTensor, BoolTensorOps, IntTensor},
Data, Device, Reader, Shape,
};

impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn bool_from_data<const D: usize>(data: Data<bool, D>, device: &B::Device) -> BoolTensor<B, D> {
impl<B: Backend> BoolTensorOps<Self> for Autodiff<B> {
fn bool_from_data<const D: usize>(data: Data<bool, D>, device: &Device<B>) -> BoolTensor<B, D> {
B::bool_from_data(data, device)
}

@@ -28,12 +29,12 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>

fn bool_to_device<const D: usize>(
tensor: BoolTensor<B, D>,
device: &B::Device,
device: &Device<B>,
) -> BoolTensor<B, D> {
B::bool_to_device(tensor, device)
}

fn bool_device<const D: usize>(tensor: &BoolTensor<B, D>) -> B::Device {
fn bool_device<const D: usize>(tensor: &BoolTensor<B, D>) -> Device<B> {
B::bool_device(tensor)
}

@@ -51,10 +52,7 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
B::bool_slice(tensor, ranges)
}

fn bool_empty<const D: usize>(
shape: Shape<D>,
device: &<ADBackendDecorator<B> as Backend>::Device,
) -> BoolTensor<B, D> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<B>) -> BoolTensor<B, D> {
B::bool_empty(shape, device)
}

@@ -83,15 +81,15 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>

fn bool_into_float<const D: usize>(
tensor: BoolTensor<B, D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
ADTensor::new(B::bool_into_float(tensor))
) -> <Autodiff<B> as Backend>::TensorPrimitive<D> {
AutodiffTensor::new(B::bool_into_float(tensor))
}

fn bool_swap_dims<const D: usize>(
tensor: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D>,
tensor: <Autodiff<B> as Backend>::BoolTensorPrimitive<D>,
dim1: usize,
dim2: usize,
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D> {
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive<D> {
B::bool_swap_dims(tensor, dim1, dim2)
}
}