Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf/transaction #2521

Merged
merged 14 commits into from
Nov 21, 2024
Prev Previous commit
Next Next commit
Update:
  • Loading branch information
nathanielsimard committed Nov 20, 2024
commit 0e813529fc430900d0d8693b71b8eecf3d7dfd08
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,6 @@ pub fn conv_transpose2d<R: JitRuntime, E: FloatElement, I: IntElement>(

#[allow(unused)]
pub(crate) fn debug_data<R: JitRuntime, E: JitElement>(tensor: JitTensor<R, E>) -> TensorData {
let bytes = tensor.client.read(tensor.handle.binding());
let bytes = tensor.client.read_one(tensor.handle.binding());
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}
6 changes: 3 additions & 3 deletions crates/burn-jit/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ pub(crate) fn from_data<R: JitRuntime, E: JitElement>(
pub(crate) async fn into_data<R: JitRuntime, E: JitElement>(tensor: JitTensor<R, E>) -> TensorData {
let tensor = kernel::into_contiguous(tensor);

let bytes = tensor.client.read_async(tensor.handle.binding()).await;
let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}

#[allow(unused, reason = "useful for debugging kernels")]
pub(crate) fn into_data_sync<R: JitRuntime, E: JitElement>(tensor: JitTensor<R, E>) -> TensorData {
let tensor = kernel::into_contiguous(tensor);

let bytes = tensor.client.read(tensor.handle.binding());
let bytes = tensor.client.read_one(tensor.handle.binding());
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}

pub(crate) async fn bool_into_data<R: JitRuntime>(tensor: JitTensor<R, u32>) -> TensorData {
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read_async(tensor.handle.binding()).await;
let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
TensorData::new(
u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(),
tensor.shape,
Expand Down
5 changes: 4 additions & 1 deletion crates/burn-jit/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ where
let strategy = tensor.strategy();
let qtensor = kernel::into_contiguous(tensor.qtensor);

let bytes = qtensor.client.read_async(qtensor.handle.binding()).await;
let bytes = qtensor
.client
.read_one_async(qtensor.handle.binding())
.await;

// TensorData keeps quantized values packed into 32-bit unsigned integers so we can
// keep the current representation, just cast the bytes as u32.
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/ops/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ where

async move {
let mut data = client
.read_many_async(bindings)
.read_async(bindings)
.await
.into_iter()
.map(|t| Some(t))
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ where
device: R::Device,
) -> Self {
let bytes = burn_common::reader::try_read_sync(
self.client.read_async(self.handle.clone().binding()),
self.client.read_one_async(self.handle.clone().binding()),
)
.expect("Can only change client synchronously");
let handle = client.create(&bytes);
Expand Down