Skip to content

Commit

Permalink
[Fusion] Support multi-precision fusion (tracel-ai#1718)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored May 2, 2024
1 parent 6b14bb8 commit 5d959e2
Show file tree
Hide file tree
Showing 46 changed files with 2,039 additions and 1,434 deletions.
8 changes: 0 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions crates/burn-common/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ impl<T> Reader<T> {
}

/// Map the current reader to another type.
pub fn map<O, F: FnOnce(T) -> O>(self, mapper: F) -> Reader<O>
pub fn map<O, F>(self, mapper: F) -> Reader<O>
where
T: 'static + Send,
O: 'static + Send,
F: 'static + Send,
F: FnOnce(T) -> O + 'static + Send,
{
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
return Reader::Async(Box::new(MappedReader::new(self, mapper)));
Expand Down
11 changes: 11 additions & 0 deletions crates/burn-compute/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
}

impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
where
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
Server: ComputeServer,
Channel: ComputeChannel<Server>,
{
fn default() -> Self {
Self::new()
}
}

impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
where
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
Expand Down
12 changes: 6 additions & 6 deletions crates/burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,13 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
fn clone_unsafely<T>(thing: &T) -> T {
unsafe {
// Allocate memory for the clone.
let clone = ptr::null_mut();
// Correcting pointer usage based on feedback
let clone = ptr::addr_of_mut!(*clone);
let mut clone = std::mem::MaybeUninit::<T>::uninit();
// Get a mutable pointer to the allocated memory.
let clone_ptr = clone.as_mut_ptr();
// Copy the memory
ptr::copy_nonoverlapping(thing as *const T, clone, 1);
// Transmute the cloned data pointer into an owned instance of T.
ptr::read(clone)
ptr::copy_nonoverlapping(thing as *const T, clone_ptr, 1);
// Assume the cloned data is initialized and convert it to an owned instance of T.
clone.assume_init()
}
}

Expand Down
24 changes: 10 additions & 14 deletions crates/burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ impl SerializerTrait for Serializer {
Ok(self)
}

fn serialize_newtype_struct<T: ?Sized>(
fn serialize_newtype_struct<T>(
self,
_name: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(self)
}
Expand Down Expand Up @@ -128,9 +128,9 @@ impl SerializerTrait for Serializer {
unimplemented!()
}

fn serialize_some<T: ?Sized>(self, value: &T) -> Result<Self::Ok, Self::Error>
fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(self)
}
Expand All @@ -152,15 +152,15 @@ impl SerializerTrait for Serializer {
unimplemented!()
}

fn serialize_newtype_variant<T: ?Sized>(
fn serialize_newtype_variant<T>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
unimplemented!()
}
Expand Down Expand Up @@ -207,13 +207,9 @@ impl SerializeStruct for Serializer {
type Ok = NestedValue;
type Error = Error;

fn serialize_field<T: ?Sized>(
&mut self,
key: &'static str,
value: &T,
) -> Result<(), Self::Error>
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
let serialized_value = value.serialize(Serializer::new())?;

Expand Down Expand Up @@ -248,9 +244,9 @@ impl SerializeSeq for Serializer {
type Ok = NestedValue;
type Error = Error;

fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
let serialized_value = value.serialize(Serializer::new())?;

Expand Down
74 changes: 54 additions & 20 deletions crates/burn-fusion/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use crate::{
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
};
use burn_tensor::{
backend::Backend,
backend::{Backend, DeviceOps},
ops::FloatTensor,
repr::{OperationDescription, ReprBackend},
Device,
};
Expand All @@ -11,30 +12,30 @@ use std::marker::PhantomData;

pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();

pub(crate) fn get_client<B: FusionBackend>(device: &B::Device) -> B::FusionClient {
CLIENTS.client(device)
pub(crate) fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
CLIENTS.client::<B::FusionRuntime>(device)
}

/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
#[derive(Clone, Debug, Default)]
pub struct Fusion<B> {
pub struct Fusion<B: FusionBackend> {
_backend: PhantomData<B>,
}

impl<B: FusionBackend> Backend for Fusion<B> {
type Device = B::Device;

type FullPrecisionBridge = PrecisionBridge;
type FullPrecisionBridge = PrecisionBridge<B::FullPrecisionBackend>;

type FloatTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type FloatTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;

type FloatElem = B::FloatElem;

type IntTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type IntTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;

type IntElem = B::IntElem;

type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionRuntime>;

fn name() -> String {
format!("fusion<{}>", B::name())
Expand All @@ -45,10 +46,14 @@ impl<B: FusionBackend> Backend for Fusion<B> {
}

fn sync(device: &Self::Device) {
let client = CLIENTS.client::<B::FusionClient>(&device.clone());
let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
client.drain();
B::sync(device)
}

fn ad_enabled() -> bool {
false
}
}

/// The status of a [builder](OptimizationBuilder).
Expand Down Expand Up @@ -101,32 +106,61 @@ pub trait OptimizationBuilder<O>: Send {
}

/// The operation created from the [builder](OptimizationBuilder).
pub trait Optimization<B: FusionBackend>: Send {
pub trait Optimization<R: FusionRuntime>: Send {
/// Execute the operation.
fn execute(&mut self, context: &mut Context<'_, B>);
fn execute(&mut self, context: &mut Context<'_, R::FusionHandle>);
/// The number of registered operations in this optimization.
fn len(&self) -> usize;
/// If the current optimization is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the state that can be serialized.
fn to_state(&self) -> B::OptimizationState;
fn to_state(&self) -> R::OptimizationState;
/// Create the optimization from the state.
fn from_state(device: &B::Device, state: B::OptimizationState) -> Self;
fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend: Backend + ReprBackend {
/// Type alias for `<R as FusionRuntime>::FusionDevice`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// Type alias for `<R as FusionRuntime>::FusionHandle`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// Type alias for `<R as FusionRuntime>::FusionClient`.
pub type Client<R> = <R as FusionRuntime>::FusionClient;

/// Trait that defines a runtime that will benefits from fused operations.
pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
/// The state that can be serialized for an optimization.
type OptimizationState: Serialize + DeserializeOwned;
/// Optimization type for the backend.
type Optimization: Optimization<Self>;
/// What kind of client should be used.
type FusionClient: FusionClient<FusionBackend = Self>;
/// Handle used to store tensor dynamically.
type FusionHandle: Clone + Send;
/// Device used by the runtime.
type FusionDevice: DeviceOps;
/// The client to interact with the runtime.
type FusionClient: FusionClient<Self>;

/// The list of optimizations that will be used to optimize the computational graph.
fn optimizations(device: Device<Self>)
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
fn optimizations(
device: Self::FusionDevice,
) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend:
ReprBackend<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
{
/// The runtime used for this backend.
type FusionRuntime: FusionRuntime;

/// Cast a float tensor and returns the resulting handle.
fn cast_float<const D: usize>(
tensor: FloatTensor<Self, D>,
dtype: burn_tensor::DType,
) -> Self::Handle;

/// Pointer to the full precision fusion backend.
type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
}
95 changes: 82 additions & 13 deletions crates/burn-fusion/src/bridge.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,94 @@
use burn_tensor::backend::BackendBridge;

use crate::{Fusion, FusionBackend};
use crate::{
client::FusionClient, stream::execution::Operation, Fusion, FusionBackend, FusionRuntime,
};
use burn_tensor::{
backend::BackendBridge,
ops::FloatTensor,
repr::{
BaseOperationDescription, HandleContainer, OperationDescription, UnaryOperationDescription,
},
Element,
};
use std::marker::PhantomData;

#[derive(Debug)]
/// Fusion bridge.
pub struct PrecisionBridge;
pub struct PrecisionBridge<B: FusionBackend> {
_backend: PhantomData<B>,
}

impl<B: FusionBackend> BackendBridge<Fusion<B>> for PrecisionBridge {
type Target = Fusion<B>;
impl<R, BInput, BTarget> BackendBridge<Fusion<BInput>> for PrecisionBridge<BTarget>
where
BInput: FusionBackend<FusionRuntime = R>,
BTarget: FusionBackend<FusionRuntime = R>,
R: FusionRuntime + 'static,
{
type Target = Fusion<BTarget>;

fn into_target<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Fusion<B>, D>,
tensor: FloatTensor<Fusion<BInput>, D>,
_device: Option<burn_tensor::Device<Self::Target>>,
) -> burn_tensor::ops::FloatTensor<Self::Target, D> {
tensor
) -> FloatTensor<Self::Target, D> {
cast::<R, BInput, BTarget, D>(tensor)
}

fn from_target<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self::Target, D>,
_device: Option<burn_tensor::Device<Fusion<B>>>,
) -> burn_tensor::ops::FloatTensor<Fusion<B>, D> {
tensor
tensor: FloatTensor<Self::Target, D>,
_device: Option<burn_tensor::Device<Fusion<BInput>>>,
) -> FloatTensor<Fusion<BInput>, D> {
cast::<R, BTarget, BInput, D>(tensor)
}
}

fn cast<R, BInput, BTarget, const D: usize>(
input: FloatTensor<Fusion<BInput>, D>,
) -> FloatTensor<Fusion<BTarget>, D>
where
BInput: FusionBackend<FusionRuntime = R>,
BTarget: FusionBackend<FusionRuntime = R>,
R: FusionRuntime + 'static,
{
#[derive(new)]
struct Cast<R: FusionRuntime, BInput: FusionBackend, BTarget: FusionBackend, const D: usize> {
desc: UnaryOperationDescription,
_bi: PhantomData<BInput>,
_bt: PhantomData<BTarget>,
_runtime: PhantomData<R>,
}

impl<const D: usize, R, BInput, BTarget> Operation<BTarget::FusionRuntime>
for Cast<R, BInput, BTarget, D>
where
BInput: FusionBackend<FusionRuntime = R>,
BTarget: FusionBackend<FusionRuntime = R>,
R: FusionRuntime,
{
fn execute(
self: Box<Self>,
handles: &mut HandleContainer<<BTarget::FusionRuntime as FusionRuntime>::FusionHandle>,
) {
let input = handles.get_float_tensor::<BInput, D>(&self.desc.input);
let output = BInput::cast_float(input, BTarget::FloatElem::dtype());

handles.register_handle(self.desc.out.id, output);
}
}

let stream = input.stream;
let out = input
.client
.tensor_uninitialized(input.shape.clone(), BTarget::FloatElem::dtype());

let desc = UnaryOperationDescription {
input: input.into_description(),
out: out.to_description_out(),
};

out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Cast(desc.clone())),
Cast::<R, BInput, BTarget, D>::new(desc),
);

out
}
Loading

0 comments on commit 5d959e2

Please sign in to comment.