Skip to content

Commit

Permalink
Feat/op fusion decorator (tracel-ai#939)
Browse files Browse the repository at this point in the history
* WIP

* Impl backend decorator

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* Refactor

* Handle graph single ops execution

* WIP

* Starting to get concrete

* WIP

* Fix locator

* Implement add ops

* Start implementing ops

* Add more ops

* Add more ops

* More float ops

* Almost finish float ops

* Almost done with Int

* Some fix

* Into float

* Implement bool ops

* Almost done with MVP

* Fix adaptive pooling

* Add fusion as backend

* Fix memory leak

* Fix

* WIP Doc

* Doc all ops enum

* Initial docs

* Clippy

* Clippy v2

* Fix typos

* Fix doc

* Fix feature flags

* Add missing ops

* Some cleanup

* Revert u128 id

* cosmetic fixes

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
  • Loading branch information
nathanielsimard and louisfd authored Nov 10, 2023
1 parent 6011ae0 commit 322480b
Show file tree
Hide file tree
Showing 37 changed files with 7,547 additions and 60 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ resolver = "2"
members = [
"burn",
"burn-autodiff",
"burn-fusion",
"burn-candle",
"burn-common",
"burn-compute",
Expand Down
9 changes: 6 additions & 3 deletions burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]

# Backend
autodiff = ["burn-autodiff"]
fusion = ["burn-fusion", "burn-wgpu?/fusion"]

ndarray = ["__ndarray", "burn-ndarray/default"]
ndarray-no-std = ["__ndarray", "burn-ndarray"]
Expand All @@ -50,8 +51,6 @@ ndarray-blas-openblas-system = [
]
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.

wgpu = ["burn-wgpu/default"]

tch = ["burn-tch"]

candle = ["burn-candle"]
Expand All @@ -63,8 +62,11 @@ experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.

wgpu = ["burn-wgpu/default"]

[dependencies]


# ** Please make sure all dependencies support no_std when std is disabled **

burn-common = { path = "../burn-common", version = "0.11.0", default-features = false }
Expand All @@ -74,8 +76,9 @@ burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features =

# Backends
burn-ndarray = { path = "../burn-ndarray", version = "0.11.0", optional = true, default-features = false }
burn-autodiff = { path = "../burn-autodiff", version = "0.11.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.11.0", optional = true }
burn-autodiff = { path = "../burn-autodiff", version = "0.11.0", optional = true }
burn-fusion = { path = "../burn-fusion", version = "0.11.0", optional = true }
burn-tch = { path = "../burn-tch", version = "0.11.0", optional = true }
burn-candle = { path = "../burn-candle", version = "0.11.0", optional = true }

Expand Down
3 changes: 3 additions & 0 deletions burn-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ pub use burn_autodiff as autodiff;
#[cfg(feature = "autodiff")]
pub use burn_autodiff::Autodiff;

#[cfg(feature = "fusion")]
pub use burn_fusion::Fusion;

#[cfg(feature = "wgpu")]
pub use burn_wgpu as wgpu;

Expand Down
19 changes: 19 additions & 0 deletions burn-fusion/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Kernel fusion backend decorator for the Burn framework"
edition = "2021"
keywords = ["deep-learning", "machine-learning", "data"]
license = "MIT OR Apache-2.0"
name = "burn-fusion"
readme = "README.md"
repository = "https://github.com/tracel-ai/burn/tree/main/burn-fusion"
version = "0.11.0"

[features]

[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features = false }
burn-common = { path = "../burn-common", version = "0.11.0" }
derive-new = { workspace = true }
spin = { workspace = true }
3 changes: 3 additions & 0 deletions burn-fusion/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Burn Fusion

A kernel fusion backend decorator for Burn.
150 changes: 150 additions & 0 deletions burn-fusion/src/backend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use crate::{
client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor,
HandleContainer,
};
use burn_tensor::{backend::Backend, Shape};
use core::marker::PhantomData;
use std::sync::Arc;

/// Process-wide locator holding the fusion client(s), keyed by device
/// (see `FusionClientLocator`). Accessed through [get_client].
pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();

/// Return the [fusion client](FusionClient) associated with the given device,
/// looked up through the global [CLIENTS] locator.
pub(crate) fn get_client<B: FusionBackend>(device: &B::FusionDevice) -> B::FusionClient {
    CLIENTS.client(device)
}

/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
#[derive(Clone, Debug, Default)]
pub struct Fusion<B> {
    // Zero-sized marker: `Fusion` carries no state of its own; `B` only
    // selects which inner backend the fused operations run on.
    _backend: PhantomData<B>,
}

impl<B: FusionBackend> Backend for Fusion<B> {
    // Devices are the inner backend's devices, unchanged.
    type Device = B::Device;

    // TODO: Find a better way to handle full precision.
    type FullPrecisionBackend = Self;
    type FullPrecisionElem = B::FloatElem;

    // All three tensor kinds (float, int, bool) use the same `FusionTensor`
    // primitive; the element types below still come from the inner backend.
    // NOTE(review): the kind is presumably tracked by the client/graph rather
    // than the primitive type itself — the type alone does not distinguish them.
    type TensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;

    type FloatElem = B::FloatElem;

    type IntTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;

    type IntElem = B::IntElem;

    type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;

    // Diagnostic name, e.g. `Fusion<...>` wrapping the inner backend's name.
    fn name() -> String {
        format!("Fusion<{}>", B::name())
    }

    // Seeding is delegated to the inner backend.
    fn seed(seed: u64) {
        B::seed(seed);
    }
}

/// The status of a [fusion ops](FusionOps).
///
/// Returned by [FusionOps::register] to tell the caller whether more
/// operations may still be added to the fusion operation being built.
pub enum FusionStatus {
    /// No more operations can be fused.
    Closed(FusionProperties),
    /// More operations can be fused.
    Open(FusionProperties),
}

/// The properties of a [fusion ops](FusionOps).
///
/// The derived [Default] yields `score: 0, ready: false`, i.e. an empty,
/// not-yet-executable optimization.
#[derive(Debug, Clone, Copy, Default)]
pub struct FusionProperties {
    /// The score of the optimization, higher is better.
    pub score: u64,
    /// If the operation is ready to be executed.
    pub ready: bool,
}

/// The fusion operation abstraction allows implementations to fuse many
/// [tensor operations](TensorOpsDescription) into one, improving the performance of the backend.
///
/// # Notes
///
/// The implementations are free to execute the registered operations the way they want to improve
/// the speed and efficiency of the computational graph. It doesn't mean that all registered
/// operations should be fused, but that another way of executing them is more efficient.
///
/// Also, it is important to return [FusionStatus::Closed] when no more registered operation can
/// improve the performance.
pub trait FusionOps<B: FusionBackend>: Send {
    /// Register a new [tensor operation](TensorOpsDescription).
    ///
    /// The return value should be either [closed](FusionStatus::Closed) or
    /// [open](FusionStatus::Open).
    ///
    /// When [closed](FusionStatus::Closed), it's assumed that no more operation can be added
    /// to the current fusion operation. No [tensor operation](TensorOpsDescription) can be
    /// ignored, they are either accepted or rejected, and the [status](FusionStatus) describes it.
    fn register(&mut self, ops: Arc<TensorOpsDescription<B>>) -> FusionStatus;
    /// Execute the fused operation, reading and writing tensors through the
    /// given [handle container](HandleContainer).
    fn execute(&mut self, handles: &mut HandleContainer<B>);
    /// Reset the state, so a new set of operations can be registered.
    fn reset(&mut self);
    /// The number of operations fused so far.
    fn len(&self) -> usize;
    /// If no operation has been registered in the current fusion operation.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// The device id.
///
/// Returned by [FusionDevice::id]; `derive(new)` generates a
/// `DeviceId::new(type_id, index_id)` constructor.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
    /// The type id identifies the type of the device.
    pub type_id: u16,
    /// The index id identifies the device number.
    pub index_id: u32,
}

/// Trait for backend devices that can produce a unique [device id](DeviceId).
pub trait FusionDevice: Clone + Send + Sync + PartialEq {
    /// Return the [device id](DeviceId).
    fn id(&self) -> DeviceId;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [fusion operation](crate::FusionOps).
pub trait FusionBackend: Backend {
    /// The device type that can return an ID.
    ///
    /// It can be the same as [Backend::Device], but must implement [FusionDevice].
    type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device>;
    /// The type that can be used to point to a tensor of any kind.
    type Handle: Sync + Send + Clone;
    /// What kind of client should be used.
    type FusionClient: FusionClient<FusionBackend = Self>;

    /// The list of operations that will be used to optimize the computational graph.
    fn operations() -> Vec<Box<dyn FusionOps<Self>>>;

    /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
    fn float_tensor<const D: usize>(
        handle: Self::Handle,
        shape: Shape<D>,
    ) -> Self::TensorPrimitive<D>;
    /// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
    fn int_tensor<const D: usize>(
        handle: Self::Handle,
        shape: Shape<D>,
    ) -> Self::IntTensorPrimitive<D>;
    /// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
    fn bool_tensor<const D: usize>(
        handle: Self::Handle,
        shape: Shape<D>,
    ) -> Self::BoolTensorPrimitive<D>;

    /// Convert a [float tensor](Backend::TensorPrimitive) to a [handle](FusionBackend::Handle).
    fn float_tensor_handle<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Handle;
    /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle).
    fn int_tensor_handle<const D: usize>(tensor: Self::IntTensorPrimitive<D>) -> Self::Handle;
    /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle).
    fn bool_tensor_handle<const D: usize>(tensor: Self::BoolTensorPrimitive<D>) -> Self::Handle;
}
73 changes: 73 additions & 0 deletions burn-fusion/src/client/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use crate::{
graph::{GraphExecution, TensorOpsDescription},
FusionBackend, FusionTensor, TensorDescription, TensorId,
};
use burn_tensor::{
ops::{FloatElem, IntElem},
Data, Reader,
};

/// Define how to interact with the fusion server.
///
/// A client is tied to one [fusion device](FusionBackend::FusionDevice) (see
/// [FusionClient::device]) and is the only way tensors are created, read, and
/// registered for fused execution.
pub trait FusionClient: Send + Sync + Clone {
    /// The [fusion backend](FusionBackend) associated type.
    type FusionBackend: FusionBackend;
    /// The [graph execution](GraphExecution) associated type.
    type GraphExecution: GraphExecution<Self::FusionBackend>;

    /// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
    fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
    /// Register a new [tensor operation description](TensorOpsDescription).
    fn register(&self, ops: TensorOpsDescription<Self::FusionBackend>);
    /// Sync the computation.
    fn sync(&self);
    /// Get the current device used by all operations handled by this client.
    fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice;
    /// Create an empty tensor with the given shape.
    fn create_tensor_empty(&self, shape: Vec<usize>) -> FusionTensor<Self>;
    /// Create a float tensor with the given values.
    fn create_tensor_float(
        &self,
        values: Vec<FloatElem<Self::FusionBackend>>,
        shape: Vec<usize>,
    ) -> FusionTensor<Self>;
    /// Create an integer tensor with the given values.
    fn create_tensor_int(
        &self,
        values: Vec<IntElem<Self::FusionBackend>>,
        shape: Vec<usize>,
    ) -> FusionTensor<Self>;
    /// Create a bool tensor with the given values.
    fn create_tensor_bool(&self, values: Vec<bool>, shape: Vec<usize>) -> FusionTensor<Self>;
    /// Read the values contained by a float tensor.
    fn read_tensor_float<const D: usize>(
        &self,
        tensor: TensorDescription,
    ) -> Reader<Data<FloatElem<Self::FusionBackend>, D>>;
    /// Read the values contained by an int tensor.
    fn read_tensor_int<const D: usize>(
        &self,
        tensor: TensorDescription,
    ) -> Reader<Data<IntElem<Self::FusionBackend>, D>>;
    /// Read the values contained by a bool tensor.
    fn read_tensor_bool<const D: usize>(&self, tensor: TensorDescription) -> Reader<Data<bool, D>>;
    /// Change the client of the given float tensor.
    fn change_client_float<const D: usize>(
        &self,
        tensor: TensorDescription,
        client: Self,
    ) -> FusionTensor<Self>;
    /// Change the client of the given int tensor.
    fn change_client_int<const D: usize>(
        &self,
        tensor: TensorDescription,
        client: Self,
    ) -> FusionTensor<Self>;
    /// Change the client of the given bool tensor.
    fn change_client_bool<const D: usize>(
        &self,
        tensor: TensorDescription,
        client: Self,
    ) -> FusionTensor<Self>;
    /// Register the tensor with the given [tensor id](TensorId) as an orphan,
    /// i.e. no longer referenced by any `FusionTensor` — presumably so the
    /// server can reclaim its resources (TODO confirm against server impl).
    fn register_orphan(&self, id: &TensorId);
}
5 changes: 5 additions & 0 deletions burn-fusion/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod base;
mod mutex;

pub use base::*;
pub use mutex::*;
Loading

0 comments on commit 322480b

Please sign in to comment.