WGPU: Support elemwise operation fusion (tracel-ai#948)
nathanielsimard authored Nov 15, 2023
1 parent 4fc0c27 commit 24014ac
Showing 30 changed files with 1,712 additions and 361 deletions.
1 change: 1 addition & 0 deletions backend-comparison/Cargo.toml
@@ -21,6 +21,7 @@ ndarray-blas-openblas = ["burn/ndarray-blas-openblas"]
 tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu"]
+wgpu-fusion = ["burn/wgpu", "burn/fusion"]
 
 [dependencies]
 burn = { path = "../burn" }
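With this flag, the comparison benches can target the fused backend; assuming the crate's usual feature-based backend selection, an invocation along the lines of `cargo bench --features wgpu-fusion` from `backend-comparison/` would exercise it.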
8 changes: 8 additions & 0 deletions backend-comparison/src/lib.rs
@@ -1,6 +1,14 @@
 #[macro_export]
 macro_rules! bench_on_backend {
     () => {
+        #[cfg(feature = "wgpu-fusion")]
+        {
+            use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+            use burn::backend::Fusion;
+
+            bench::<Fusion<Wgpu<AutoGraphicsApi, f32, i32>>>(&WgpuDevice::default());
+        }
+
        #[cfg(feature = "wgpu")]
        {
            use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
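For context, a hedged sketch of how a bench file consumes this macro: the macro expands to one `bench::<B>(&device)` call per enabled backend feature, so the new `wgpu-fusion` branch routes the same workload through the `Fusion` decorator. The `bench` body below is illustrative only and assumes the 0.11-era burn tensor API.

use burn::tensor::{backend::Backend, Distribution, Tensor};

fn bench<B: Backend>(device: &B::Device) {
    // An element-wise chain: with fusion enabled, these ops are recorded
    // lazily and may be compiled into a single kernel.
    let x = Tensor::<B, 2>::random([512, 512], Distribution::Default).to_device(device);
    let _y = (x * 2.0 + 1.0).exp();
}

fn main() {
    backend_comparison::bench_on_backend!();
}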
14 changes: 10 additions & 4 deletions burn-fusion/src/backend.rs
@@ -2,7 +2,7 @@ use crate::{
     client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor,
     HandleContainer,
 };
-use burn_tensor::{backend::Backend, Shape};
+use burn_tensor::{backend::Backend, Device, Shape};
 use core::marker::PhantomData;
 use std::sync::Arc;
@@ -36,12 +36,18 @@ impl<B: FusionBackend> Backend for Fusion<B> {
     type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
 
     fn name() -> String {
-        format!("Fusion<{}>", B::name())
+        format!("fusion<{}>", B::name())
     }
 
     fn seed(seed: u64) {
         B::seed(seed);
     }
+
+    fn sync(device: &Self::Device) {
+        let client = CLIENTS.client::<B::FusionClient>(&device.clone().into());
+        client.drain_graph();
+        B::sync(device)
+    }
 }
 
 /// The status of a [fusion ops](FusionOps).
@@ -116,14 +122,14 @@ pub trait FusionBackend: Backend {
     /// The device type that can return an ID.
     ///
     /// It can be the same as (Backend::Device), but must implement (FusionDevice).
-    type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device>;
+    type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device> + core::fmt::Debug;
     /// The type that can be used to point to a tensor of any kind.
     type Handle: Sync + Send + Clone;
     /// What kind of client should be used.
     type FusionClient: FusionClient<FusionBackend = Self>;
 
     /// The list of operations that will be used to optimize the computational graph.
-    fn operations() -> Vec<Box<dyn FusionOps<Self>>>;
+    fn operations(device: &Device<Self>) -> Vec<Box<dyn FusionOps<Self>>>;
 
     /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
     fn float_tensor<const D: usize>(
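The new `sync` override means a user-facing sync first drains the lazily recorded fusion graph, then delegates to the inner backend. A minimal usage sketch, reusing the type aliases from the macro above:

use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Fusion;
use burn::tensor::backend::Backend;

type B = Fusion<Wgpu<AutoGraphicsApi, f32, i32>>;

// Force completion of all pending work: the fusion client drains its
// recorded graph, then the wgpu backend waits on its queue.
fn flush(device: &WgpuDevice) {
    B::sync(device);
}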
24 changes: 8 additions & 16 deletions burn-fusion/src/client/base.rs
@@ -1,6 +1,6 @@
 use crate::{
     graph::{GraphExecution, TensorOpsDescription},
-    FusionBackend, FusionTensor, TensorDescription, TensorId,
+    FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
 };
 use burn_tensor::{
     ops::{FloatElem, IntElem},
@@ -18,26 +18,18 @@ pub trait FusionClient: Send + Sync + Clone {
     fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
     /// Register a new [tensor operation description](TensorOpsDescription).
     fn register(&self, ops: TensorOpsDescription<Self::FusionBackend>);
-    /// Sync the computation.
-    fn sync(&self);
+    /// Register all lazy computation.
+    fn drain_graph(&self);
     /// Get the current device used by all operations handled by this client.
     fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice;
-    /// Create an empty tensor.
-    fn create_tensor_empty(&self, shape: Vec<usize>) -> FusionTensor<Self>;
-    /// Create a float tensor with the given values.
-    fn create_tensor_float(
+    /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it.
+    fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self>;
+    /// Create a tensor with the given handle and shape.
+    fn register_tensor(
         &self,
-        values: Vec<FloatElem<Self::FusionBackend>>,
+        handle: Handle<Self::FusionBackend>,
         shape: Vec<usize>,
     ) -> FusionTensor<Self>;
-    /// Create an integer tensor with the given values.
-    fn create_tensor_int(
-        &self,
-        values: Vec<IntElem<Self::FusionBackend>>,
-        shape: Vec<usize>,
-    ) -> FusionTensor<Self>;
-    /// Create a bool tensor with the given values.
-    fn create_tensor_bool(&self, values: Vec<bool>, shape: Vec<usize>) -> FusionTensor<Self>;
     /// Read the values contained by a float tensor.
     fn read_tensor_float<const D: usize>(
         &self,
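A hedged sketch of the two new entry points from a caller's perspective; the helper names and module paths here are illustrative, not part of this change:

use burn_fusion::{client::FusionClient, FusionTensor, Handle};

// Wrap an existing backend allocation as a fusion tensor: reserve an id and
// attach the handle, with no new resources created.
fn import_existing<C: FusionClient>(
    client: &C,
    handle: Handle<C::FusionBackend>,
    shape: Vec<usize>,
) -> FusionTensor<C> {
    client.register_tensor(handle, shape)
}

// Reserve an id with no backing resources at all, e.g. for an op output
// that the graph will materialize later.
fn reserve_output<C: FusionClient>(client: &C, shape: Vec<usize>) -> FusionTensor<C> {
    client.tensor_uninitialized(shape)
}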
43 changes: 15 additions & 28 deletions burn-fusion/src/client/mutex.rs
@@ -1,7 +1,7 @@
 use super::FusionClient;
 use crate::{
     graph::{GraphExecution, TensorOpsDescription},
-    FusionBackend, FusionServer, FusionTensor,
+    FusionBackend, FusionServer, FusionTensor, Handle,
 };
 use burn_tensor::ops::FloatElem;
 use spin::Mutex;
@@ -49,10 +49,11 @@ where
         self.server.lock().register(ops);
     }
 
-    fn sync(&self) {
-        self.server.lock().sync();
+    fn drain_graph(&self) {
+        self.server.lock().drain_graph();
     }
-    fn create_tensor_empty(&self, shape: Vec<usize>) -> FusionTensor<Self> {
+
+    fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self> {
         let id = self.server.lock().create_empty_handle();
 
         FusionTensor::new(id, shape, self.clone())
@@ -61,38 +62,24 @@
     fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice {
         &self.device
     }
 
-    fn read_tensor_float<const D: usize>(
+    fn register_tensor(
         &self,
-        tensor: crate::TensorDescription,
-    ) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<Self::FusionBackend>, D>> {
-        self.server.lock().read_float(tensor)
-    }
-
-    fn create_tensor_float(
-        &self,
-        values: Vec<FloatElem<Self::FusionBackend>>,
+        handle: Handle<Self::FusionBackend>,
         shape: Vec<usize>,
     ) -> FusionTensor<Self> {
-        let id = self.server.lock().create_float_handle(values);
+        let mut server = self.server.lock();
+        let id = server.create_empty_handle();
+        server.handles.register_handle(id.as_ref().clone(), handle);
+        core::mem::drop(server);
 
         FusionTensor::new(id, shape, self.clone())
     }
 
-    fn create_tensor_int(
+    fn read_tensor_float<const D: usize>(
         &self,
-        values: Vec<burn_tensor::ops::IntElem<Self::FusionBackend>>,
-        shape: Vec<usize>,
-    ) -> FusionTensor<Self> {
-        let id = self.server.lock().create_int_handle(values);
-
-        FusionTensor::new(id, shape, self.clone())
-    }
-
-    fn create_tensor_bool(&self, values: Vec<bool>, shape: Vec<usize>) -> FusionTensor<Self> {
-        let id = self.server.lock().create_bool_handle(values);
-
-        FusionTensor::new(id, shape, self.clone())
+        tensor: crate::TensorDescription,
+    ) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<Self::FusionBackend>, D>> {
+        self.server.lock().read_float(tensor)
     }
 
     fn read_tensor_int<const D: usize>(
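One detail worth noting in `register_tensor` above: the guard on the `spin::Mutex` is dropped explicitly before the returned tensor is constructed, so no extra work happens while the lock is held. A self-contained illustration of that pattern (not burn code):

use spin::Mutex;

// Do all shared-state work under one lock acquisition, then release the
// guard explicitly before building the return value; holding a spin lock
// across unrelated calls risks busy-waiting in other threads.
fn next_id(counter: &Mutex<u64>) -> u64 {
    let mut guard = counter.lock();
    *guard += 1;
    let id = *guard;
    core::mem::drop(guard); // lock released here
    id
}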
9 changes: 7 additions & 2 deletions burn-fusion/src/graph/execution.rs
@@ -32,8 +32,13 @@ impl<B: FusionBackend> GraphExecution<B> for GreedyGraphExecution {
         }
 
         match find_best_optimization_index(optimizations) {
-            Some(index) => graph.execute_optimization(handles, index, optimizations),
-            None => graph.execute(handles),
+            Some(index) => {
+                graph.execute_optimization(handles, index, optimizations);
+            }
+            None => {
+                graph.execute(handles);
+                optimizations.iter_mut().for_each(|ops| ops.reset());
+            }
         }
 
         if graph.is_empty() {
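This change makes the greedy executor reset every candidate optimization after falling back to unoptimized execution. An illustrative-only sketch of that policy (names are not burn's API):

// Each optimization accumulates matching state as operations stream in; once
// the graph has been executed eagerly, that state refers to operations that
// no longer exist and must be cleared before the next ones arrive.
struct ElemwiseFusion {
    matched_ops: usize,
}

impl ElemwiseFusion {
    fn reset(&mut self) {
        self.matched_ops = 0;
    }
}

fn fallback_execute(graph: &mut Vec<&'static str>, optimizations: &mut [ElemwiseFusion]) {
    for op in graph.drain(..) {
        println!("executing `{op}` eagerly");
    }
    optimizations.iter_mut().for_each(ElemwiseFusion::reset);
}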
[Diffs for the remaining 24 changed files were not loaded in this view.]
