Fix/burn remote deadlock + burn fusion draining (tracel-ai#2492)
nathanielsimard authored Nov 15, 2024
1 parent 81a4e54 commit fb05a55
Showing 23 changed files with 217 additions and 311 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files (Cargo.lock) are not rendered by default.

2 changes: 0 additions & 2 deletions Cargo.toml
@@ -17,7 +17,6 @@ members = [
exclude = [
"examples/notebook",
"examples/raspberry-pi-pico", # will cause dependency building issues otherwise
# "crates/burn-cuda", # comment this line to work on burn-cuda
]

[workspace.package]
@@ -168,4 +167,3 @@ tracel-xtask = { version = "~1.1" }

[profile.dev]
debug = 0 # Speed up compilation time and not necessary.
opt-level = 2
8 changes: 7 additions & 1 deletion crates/burn-core/Cargo.toml
@@ -22,9 +22,10 @@ default = [
"burn-tch?/default",
"burn-tensor/default",
"burn-wgpu?/default",
"burn-router?/default",
"burn-cuda?/default",
"burn-hip?/default",
"burn-autodiff?/default",
"burn-hip?/default",
]
doc = [
"std",
@@ -40,6 +41,7 @@ doc = [
"vision",
"autodiff",
"remote",
"router",
"server",
# Doc features
"burn-candle/doc",
@@ -49,6 +51,7 @@ doc = [
"burn-tch/doc",
"burn-tensor/doc",
"burn-wgpu/doc",
"burn-router/doc",
"burn-cuda/doc",
"burn-hip/doc",
]
@@ -63,6 +66,7 @@ std = [
"burn-ndarray?/std",
"burn-tensor/std",
"burn-wgpu?/std",
"burn-router?/std",
"burn-cuda?/std",
"burn-hip?/std",
"flate2",
@@ -89,6 +93,7 @@ openblas = ["burn-ndarray?/blas-openblas"]
openblas-system = ["burn-ndarray?/blas-openblas-system"]
template = ["burn-wgpu?/template"]
remote = ["burn-remote/client"]
router = ["burn-router"]
server = ["burn-remote/server"]

candle = ["burn-candle"]
@@ -136,6 +141,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true,
burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true }
burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true }

data-encoding = { workspace = true }
uuid = { workspace = true }
3 changes: 3 additions & 0 deletions crates/burn-core/src/backend.rs
@@ -44,3 +44,6 @@ pub use burn_tch as libtorch;

#[cfg(feature = "tch")]
pub use burn_tch::LibTorch;

#[cfg(feature = "router")]
pub use burn_router::Router;
49 changes: 10 additions & 39 deletions crates/burn-fusion/src/stream/base.rs
@@ -1,6 +1,10 @@
use std::collections::BTreeSet;

use super::{execution::Operation, OperationConverter, RelativeOps};
use crate::FusionRuntime;
use burn_tensor::repr::OperationDescription;
use burn_tensor::repr::{OperationDescription, TensorId};

pub use burn_common::id::StreamId;

/// A growing list of [tensor operation descriptions](OperationDescription).
pub struct OperationQueue<R: FusionRuntime> {
@@ -15,6 +19,7 @@ pub struct OperationQueue<R: FusionRuntime> {
pub(crate) relative: Vec<OperationDescription>,
pub(crate) converter: OperationConverter,
pub(crate) operations: Vec<Box<dyn Operation<R>>>,
pub(crate) ids: BTreeSet<TensorId>,
}

impl<R: FusionRuntime> Default for OperationQueue<R> {
@@ -23,44 +28,6 @@ impl<R: FusionRuntime> Default for OperationQueue<R> {
}
}

/// The stream id.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct StreamId {
#[cfg(feature = "std")]
value: std::thread::ThreadId,
#[cfg(not(feature = "std"))]
value: (),
}

impl StreamId {
/// Get the current stream id.
pub fn current() -> Self {
Self {
#[cfg(feature = "std")]
value: Self::id(),
#[cfg(not(feature = "std"))]
value: (),
}
}

#[cfg(feature = "std")]
fn id() -> std::thread::ThreadId {
std::thread_local! {
static ID: std::cell::OnceCell::<std::thread::ThreadId> = const { std::cell::OnceCell::new() };
};

// Getting the current thread is expensive, so we cache the value into a thread local
// variable, which is very fast.
ID.with(|cell| *cell.get_or_init(|| std::thread::current().id()))
}
}

impl core::fmt::Display for StreamId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("StreamID({:?})", self.value))
}
}

impl<R: FusionRuntime> OperationQueue<R> {
/// Create a new empty queue.
pub fn new() -> Self {
@@ -69,6 +36,7 @@ impl<R: FusionRuntime> OperationQueue<R> {
relative: Vec::new(),
converter: OperationConverter::default(),
operations: Vec::new(),
ids: BTreeSet::new(),
}
}

@@ -78,6 +46,9 @@ impl<R: FusionRuntime> OperationQueue<R> {
/// representation that can be reused when the same pattern emerges in different but similar
/// scenarios, so that the same optimization can be used.
pub fn add(&mut self, global: OperationDescription, operation: Box<dyn Operation<R>>) {
for node in global.nodes() {
self.ids.insert(node.id);
}
let relative = global.to_relative(&mut self.converter);
self.relative.push(relative);
self.global.push(global);
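The queue now records the id of every tensor touched by its pending operations (the new ids: BTreeSet<TensorId> field populated in add). A minimal sketch of that bookkeeping using plain u64 ids and string placeholders; the type and method names below are illustrative stand-ins, not burn-fusion APIs.

// Toy queue that remembers which tensor ids its pending operations reference,
// so another stream can later test for overlap before forcing a drain.
use std::collections::BTreeSet;

struct ToyQueue {
    operations: Vec<String>, // stand-in for Box<dyn Operation<R>>
    ids: BTreeSet<u64>,      // stand-in for BTreeSet<TensorId>
}

impl ToyQueue {
    fn new() -> Self {
        Self { operations: Vec::new(), ids: BTreeSet::new() }
    }

    // Register an operation together with every tensor id it reads or writes.
    fn add(&mut self, op: String, tensor_ids: &[u64]) {
        self.ids.extend(tensor_ids.iter().copied());
        self.operations.push(op);
    }

    // True if this queue holds work on any of the given tensors.
    fn overlaps(&self, tensor_ids: &[u64]) -> bool {
        tensor_ids.iter().any(|id| self.ids.contains(id))
    }
}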
39 changes: 26 additions & 13 deletions crates/burn-fusion/src/stream/multi.rs
@@ -1,4 +1,4 @@
use burn_tensor::repr::{HandleContainer, OperationDescription};
use burn_tensor::repr::{HandleContainer, OperationDescription, TensorDescription};

use super::{
execution::{ExecutionMode, Operation, Processor, StreamSegment},
@@ -32,7 +32,7 @@ impl<R: FusionRuntime> MultiStream<R> {
operation: Box<dyn Operation<R>>,
handles: &mut HandleContainer<R::FusionHandle>,
) {
let id = self.maybe_drain(streams, handles);
let id = self.resolve_streams(streams, handles, &desc);

let stream = match self.streams.get_mut(&id) {
Some(stream) => stream,
@@ -64,7 +64,7 @@ impl<R: FusionRuntime> MultiStream<R> {
}
}

/// Drain the streams
/// Drain a stream
pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
if let Some(mut stream) = self.streams.remove(&id) {
stream.processor.process(
@@ -79,29 +79,42 @@ impl<R: FusionRuntime> MultiStream<R> {
/// When one of the provided streams is different from the current stream, we drain them.
///
/// Returns the current stream id.
fn maybe_drain(
fn resolve_streams(
&mut self,
streams: Vec<StreamId>,
handles: &mut HandleContainer<R::FusionHandle>,
op: &OperationDescription,
) -> StreamId {
let streams = Self::remove_duplicate(streams);
let current = StreamId::current();

if streams.len() == 1 {
// The only case where we don't need to drain, because we will process
// the operation queue of the current stream right after this.
if streams[0] == current {
return current;
}
}

for id in streams {
self.drain(handles, id);
if id != current {
self.resolve_stream(handles, id, op.nodes());
}
}

current
}

/// Drain the stream only if one of the tensors in the given nodes is also included in the
/// stream queue.
fn resolve_stream(
&mut self,
handles: &mut HandleContainer<R::FusionHandle>,
id: StreamId,
nodes: Vec<&TensorDescription>,
) {
if let Some(stream) = self.streams.get(&id) {
for node in nodes {
if stream.queue.ids.contains(&node.id) {
self.drain(handles, id);
return;
}
}
}
}

fn remove_duplicate(items: Vec<StreamId>) -> Vec<StreamId> {
if items.len() == 1 {
return items;
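With the per-queue id set in place, another stream only has to be drained when its queued work actually references one of the tensors used by the incoming operation, instead of unconditionally on every cross-stream access. A sketch of that decision with toy ids and containers; ToyMultiStream, queued_ids, and resolve_streams are illustrative names, not the real burn-fusion types.

use std::collections::{BTreeSet, HashMap};

// Drain another stream only when its queued operations touch a tensor that the
// incoming operation is about to use.
struct ToyMultiStream {
    // Per-stream set of tensor ids currently sitting in that stream's queue.
    queued_ids: HashMap<u64, BTreeSet<u64>>,
}

impl ToyMultiStream {
    fn resolve_streams(&mut self, streams: &[u64], current: u64, op_tensor_ids: &[u64]) {
        for &stream in streams {
            if stream == current {
                // The current stream's queue is processed right after this call,
                // so it never needs an eager drain here.
                continue;
            }
            let overlaps = self
                .queued_ids
                .get(&stream)
                .map(|ids| op_tensor_ids.iter().any(|id| ids.contains(id)))
                .unwrap_or(false);
            if overlaps {
                // Stand-in for MultiStream::drain(handles, stream): execute the
                // pending queue and remove it.
                self.queued_ids.remove(&stream);
            }
        }
    }
}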
22 changes: 12 additions & 10 deletions crates/burn-hip/src/lib.rs
@@ -18,13 +18,15 @@ pub type Hip<F = f32, I = i32> = JitBackend<HipRuntime, F, I>;
#[cfg(feature = "fusion")]
pub type Hip<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<HipRuntime, F, I>>;

#[cfg(target_os = "linux")]
#[cfg(test)]
mod tests {
use burn_jit::JitBackend;

pub type TestRuntime = cubecl::hip::HipRuntime;
pub use half::{bf16, f16};

burn_jit::testgen_all!();
}
// TODO: These tests hang the computer when an AMD GPU isn't available.
//
// #[cfg(target_os = "linux")]
// #[cfg(test)]
// mod tests {
// use burn_jit::JitBackend;
//
// pub type TestRuntime = cubecl::hip::HipRuntime;
// pub use half::{bf16, f16};
//
// burn_jit::testgen_all!();
// }
5 changes: 3 additions & 2 deletions crates/burn-remote/Cargo.toml
@@ -14,7 +14,7 @@ version.workspace = true
[features]
default = []
doc = []
client = ["tokio-tungstenite"]
client = ["tokio-tungstenite", "async-channel", "tokio/sync"]
server = ["axum", "tracing-core", "tracing-subscriber"]


@@ -28,13 +28,14 @@ derive-new = {workspace = true }
log = { workspace = true }

# Shared dependencies
tokio = { version = "1.37", features = ["sync", "rt-multi-thread"] }
tokio = { version = "1.37", features = ["rt-multi-thread"] }
serde = { workspace = true, features = ["derive"] }
serde_bytes = { workspace = true }
rmp-serde = { workspace = true }
futures-util = { version = "0.3" }

# Client dependencies
async-channel = { workspace = true, optional = true }
tokio-tungstenite = { version = "0.24", optional = true }

# Server dependencies
39 changes: 17 additions & 22 deletions crates/burn-remote/src/client/base.rs
@@ -1,12 +1,12 @@
use super::worker::{ClientRequest, ClientWorker};
use crate::shared::{ComputeTask, ConnectionId, Task, TaskResponseContent};
use async_channel::Sender;
use burn_common::id::StreamId;
use burn_tensor::repr::TensorId;
use std::{
future::Future,
sync::{atomic::AtomicU64, Arc},
};
use tokio::sync::mpsc::Sender;

pub use super::WsDevice;

@@ -46,22 +46,19 @@ pub(crate) struct WsSender {
}

impl WsSender {
pub(crate) fn send(&self, task: ComputeTask) -> impl Future<Output = ()> + Send {
pub(crate) fn send(&self, task: ComputeTask) {
let position = self
.position_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let stream_id = StreamId::current();
let sender = self.sender.clone();

async move {
sender
.send(ClientRequest::WithoutCallback(Task::Compute(
task,
ConnectionId::new(position, stream_id),
)))
.await
.unwrap();
}
sender
.send_blocking(ClientRequest::WithoutCallback(Task::Compute(
task,
ConnectionId::new(position, stream_id),
)))
.unwrap();
}

pub(crate) fn new_tensor_id(&self) -> TensorId {
@@ -79,20 +76,18 @@ impl WsSender {
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let stream_id = StreamId::current();
let sender = self.sender.clone();
let (callback_sender, mut callback_recv) = tokio::sync::mpsc::channel(1);
let (callback_sender, callback_recv) = async_channel::bounded(1);
sender
.send_blocking(ClientRequest::WithSyncCallback(
Task::Compute(task, ConnectionId::new(position, stream_id)),
callback_sender,
))
.unwrap();

async move {
sender
.send(ClientRequest::WithSyncCallback(
Task::Compute(task, ConnectionId::new(position, stream_id)),
callback_sender,
))
.await
.unwrap();

match callback_recv.recv().await {
Some(val) => val,
None => panic!(""),
Ok(val) => val,
Err(err) => panic!("{err:?}"),
}
}
}
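The client now enqueues work with async-channel's send_blocking from ordinary synchronous code instead of returning a tokio::sync::mpsc send future that had to be awaited inside the runtime, which is consistent with the deadlock fix named in the commit title. A minimal sketch of that producer/consumer pattern with toy payloads, assuming the async-channel crate and a Tokio runtime built with the rt-multi-thread feature; none of the names below are burn-remote APIs.

// A synchronous producer pushes onto a bounded async-channel without being
// inside the Tokio runtime, while an async worker drains it.
fn main() {
    let (sender, receiver) = async_channel::bounded::<String>(32);

    // Producer side: plain blocking code, as in WsSender::send.
    let producer = std::thread::spawn(move || {
        sender.send_blocking("compute task".to_string()).unwrap();
        // Dropping the sender closes the channel once it is empty.
    });

    // Consumer side: an async worker receiving until the channel closes.
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async move {
        while let Ok(task) = receiver.recv().await {
            println!("processing: {task}");
        }
    });

    producer.join().unwrap();
}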