Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf/transaction #2521

Merged
merged 14 commits into from
Nov 21, 2024
Prev Previous commit
Next Next commit
TransactionQuery
  • Loading branch information
nathanielsimard committed Nov 20, 2024
commit 8c000fc3b5c37f0fb06928f81d88ad928068deb9
10 changes: 5 additions & 5 deletions crates/burn-autodiff/src/ops/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ impl<B: Backend, C: CheckpointStrategy> TransactionOps<Self> for Autodiff<B, C>
) -> impl std::future::Future<Output = burn_tensor::ops::TransactionResult> + 'static + Send
{
B::tr_execute(Transaction {
floats: transaction
.floats
read_floats: transaction
.read_floats
.into_iter()
.map(|t| t.primitive)
.collect(),
qfloats: transaction.qfloats,
ints: transaction.ints,
bools: transaction.bools,
read_qfloats: transaction.read_qfloats,
read_ints: transaction.read_ints,
read_bools: transaction.read_bools,
})
}
}
16 changes: 8 additions & 8 deletions crates/burn-fusion/src/ops/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,23 @@ impl<B: FusionBackend> TransactionOps<Fusion<B>> for Fusion<B> {
) -> impl std::future::Future<Output = burn_tensor::ops::TransactionResult> + 'static + Send
{
B::tr_execute(Transaction {
floats: transaction
.floats
read_floats: transaction
.read_floats
.into_iter()
.map(|t| t.client.clone().resolve_tensor_float::<B>(t))
.collect(),
qfloats: transaction
.qfloats
read_qfloats: transaction
.read_qfloats
.into_iter()
.map(|_t| todo!("Quantization not supported yet"))
.collect(),
ints: transaction
.ints
read_ints: transaction
.read_ints
.into_iter()
.map(|t| t.client.clone().resolve_tensor_int::<B>(t))
.collect(),
bools: transaction
.bools
read_bools: transaction
.read_bools
.into_iter()
.map(|t| t.client.clone().resolve_tensor_bool::<B>(t))
.collect(),
Expand Down
12 changes: 6 additions & 6 deletions crates/burn-jit/src/ops/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ where

let mut kinds = Vec::new();

transaction.floats.into_iter().for_each(|t| {
transaction.read_floats.into_iter().for_each(|t| {
if client.is_none() {
client = Some(t.client.clone());
}
Expand All @@ -37,7 +37,7 @@ where
num_bindings += 1;
bindings.push(t.handle.binding())
});
transaction.ints.into_iter().for_each(|t| {
transaction.read_ints.into_iter().for_each(|t| {
if client.is_none() {
client = Some(t.client.clone());
}
Expand All @@ -46,7 +46,7 @@ where
num_bindings += 1;
bindings.push(t.handle.binding())
});
transaction.bools.into_iter().for_each(|t| {
transaction.read_bools.into_iter().for_each(|t| {
if client.is_none() {
client = Some(t.client.clone());
}
Expand All @@ -72,23 +72,23 @@ where
match kind {
Kind::Float(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.floats.push(TensorData {
result.read_floats.push(TensorData {
bytes,
shape,
dtype,
});
}
Kind::Int(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.ints.push(TensorData {
result.read_ints.push(TensorData {
bytes,
shape,
dtype,
});
}
Kind::Bool(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.bools.push(TensorData {
result.read_bools.push(TensorData {
bytes,
shape,
dtype,
Expand Down
28 changes: 28 additions & 0 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use crate::{
};
use crate::{DType, Element, TensorPrimitive};

use super::TransactionQuery;

/// A tensor with a given backend, shape and data type.
///
/// # Indexing
Expand Down Expand Up @@ -2028,6 +2030,18 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
tensor: Self::Primitive,
) -> impl Future<Output = TensorData> + 'static + Send;

/// Read the data from the tensor using a transaction.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
/// which is more high-level and designed for public use.
fn read(tr: &mut TransactionQuery<B>, tensor: Self::Primitive);

/// Creates a tensor from the given data.
///
/// # Arguments
Expand Down Expand Up @@ -2257,6 +2271,10 @@ impl<B: Backend> BasicOps<B> for Float {
TensorPrimitive::Float(B::float_empty(shape, device))
}

/// Registers this float tensor with the transaction query so its data is
/// fetched as part of the batched read (dispatches to the float/qfloat path).
fn read(tr: &mut TransactionQuery<B>, tensor: Self::Primitive) {
tr.read_float(tensor);
}

fn shape(tensor: &Self::Primitive) -> Shape {
match tensor {
TensorPrimitive::Float(tensor) => B::float_shape(tensor),
Expand Down Expand Up @@ -2445,6 +2463,11 @@ impl<B: Backend> BasicOps<B> for Int {
fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
B::int_empty(shape, device)
}

/// Registers this int tensor with the transaction query so its data is
/// fetched as part of the batched read.
fn read(tr: &mut TransactionQuery<B>, tensor: Self::Primitive) {
tr.read_int(tensor);
}

fn shape(tensor: &Self::Primitive) -> Shape {
B::int_shape(tensor)
}
Expand Down Expand Up @@ -2544,6 +2567,11 @@ impl<B: Backend> BasicOps<B> for Bool {
fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
B::bool_empty(shape, device)
}

/// Registers this bool tensor with the transaction query so its data is
/// fetched as part of the batched read.
fn read(tr: &mut TransactionQuery<B>, tensor: Self::Primitive) {
tr.read_bool(tensor);
}

fn shape(tensor: &Self::Primitive) -> Shape {
B::bool_shape(tensor)
}
Expand Down
54 changes: 31 additions & 23 deletions crates/burn-tensor/src/tensor/api/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use super::{Bool, Int, Tensor, TensorPrimitive};
use crate::{backend::Backend, ops::Transaction, TensorData};
use super::{BasicOps, Tensor, TensorPrimitive};
use crate::{
backend::Backend,
ops::{BoolTensor, IntTensor, Transaction},
TensorData,
};

#[derive(Default)]
pub struct TransactionBuilder<B: Backend> {
pub struct TransactionQuery<B: Backend> {
op: Transaction<B>,
orders: Vec<Order>,
}
Expand All @@ -14,38 +18,42 @@ enum Order {
Bool(usize),
}

impl<B: Backend> TransactionBuilder<B> {
pub fn float<const D: usize>(mut self, tensor: Tensor<B, D>) -> Self {
match tensor.into_primitive() {
impl<B: Backend> TransactionQuery<B> {
pub fn read<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
K::read(&mut self, tensor.into_primitive());
self
}

pub(crate) fn read_float(&mut self, tensor: TensorPrimitive<B>) {
match tensor {
TensorPrimitive::Float(tensor) => {
self.orders.push(Order::Float(self.op.floats.len()));
self.op.floats.push(tensor);
self.orders.push(Order::Float(self.op.read_floats.len()));
self.op.read_floats.push(tensor);
}
TensorPrimitive::QFloat(tensor) => {
self.orders.push(Order::QFloat(self.op.qfloats.len()));
self.op.qfloats.push(tensor);
self.orders.push(Order::QFloat(self.op.read_qfloats.len()));
self.op.read_qfloats.push(tensor);
}
}
self
}
pub fn int<const D: usize>(mut self, tensor: Tensor<B, D, Int>) -> Self {
self.orders.push(Order::Int(self.op.ints.len()));
self.op.ints.push(tensor.into_primitive());
self

pub(crate) fn read_int(&mut self, tensor: IntTensor<B>) {
self.orders.push(Order::Int(self.op.read_ints.len()));
self.op.read_ints.push(tensor);
}
pub fn bool<const D: usize>(mut self, tensor: Tensor<B, D, Bool>) -> Self {
self.orders.push(Order::Bool(self.op.bools.len()));
self.op.bools.push(tensor.into_primitive());
self

pub(crate) fn read_bool(&mut self, tensor: BoolTensor<B>) {
self.orders.push(Order::Bool(self.op.read_bools.len()));
self.op.read_bools.push(tensor);
}

pub fn execute(self) -> Vec<TensorData> {
let result = burn_common::future::block_on(B::tr_execute(self.op));

let mut floats: Vec<_> = result.floats.into_iter().map(|t| Some(t)).collect();
let mut qfloats: Vec<_> = result.qfloats.into_iter().map(|t| Some(t)).collect();
let mut ints: Vec<_> = result.ints.into_iter().map(|t| Some(t)).collect();
let mut bools: Vec<_> = result.bools.into_iter().map(|t| Some(t)).collect();
let mut floats: Vec<_> = result.read_floats.into_iter().map(|t| Some(t)).collect();
let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(|t| Some(t)).collect();
let mut ints: Vec<_> = result.read_ints.into_iter().map(|t| Some(t)).collect();
let mut bools: Vec<_> = result.read_bools.into_iter().map(|t| Some(t)).collect();

self.orders
.into_iter()
Expand Down
32 changes: 16 additions & 16 deletions crates/burn-tensor/src/tensor/ops/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ use crate::{backend::Backend, TensorData};

#[derive(Default)]
pub struct Transaction<B: Backend> {
pub floats: Vec<FloatTensor<B>>,
pub qfloats: Vec<QuantizedTensor<B>>,
pub ints: Vec<IntTensor<B>>,
pub bools: Vec<BoolTensor<B>>,
pub read_floats: Vec<FloatTensor<B>>,
pub read_qfloats: Vec<QuantizedTensor<B>>,
pub read_ints: Vec<IntTensor<B>>,
pub read_bools: Vec<BoolTensor<B>>,
}

#[derive(Default)]
pub struct TransactionResult {
pub floats: Vec<TensorData>,
pub qfloats: Vec<TensorData>,
pub ints: Vec<TensorData>,
pub bools: Vec<TensorData>,
pub read_floats: Vec<TensorData>,
pub read_qfloats: Vec<TensorData>,
pub read_ints: Vec<TensorData>,
pub read_bools: Vec<TensorData>,
}

pub trait TransactionOps<B: Backend> {
Expand All @@ -29,24 +29,24 @@ pub trait TransactionOps<B: Backend> {
let mut ints = Vec::new();
let mut bools = Vec::new();

for t in transaction.floats {
for t in transaction.read_floats {
floats.push(B::float_into_data(t).await);
}
for t in transaction.qfloats {
for t in transaction.read_qfloats {
qfloats.push(B::q_into_data(t).await);
}
for t in transaction.ints {
for t in transaction.read_ints {
ints.push(B::int_into_data(t).await);
}
for t in transaction.bools {
for t in transaction.read_bools {
bools.push(B::bool_into_data(t).await);
}

TransactionResult {
floats,
qfloats,
ints,
bools,
read_floats: floats,
read_qfloats: qfloats,
read_ints: ints,
read_bools: bools,
}
}
}
Expand Down
32 changes: 27 additions & 5 deletions crates/burn-train/src/learner/classification.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::metric::processor::ItemLazy;
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor, TransactionBuilder};
use burn_core::tensor::{Int, Tensor, TransactionQuery};
use burn_ndarray::NdArray;

/// Simple classification output adapted for multiple metrics.
Expand All @@ -21,10 +21,10 @@ impl<B: Backend> ItemLazy for ClassificationOutput<B> {
type ItemSync = ClassificationOutput<NdArray>;

fn sync(self) -> Self::ItemSync {
let [output, loss, targets] = TransactionBuilder::default()
.float(self.output)
.float(self.loss)
.int(self.targets)
let [output, loss, targets] = TransactionQuery::default()
.read(self.output)
.read(self.loss)
.read(self.targets)
.execute()
.try_into()
.expect("Correct amount of data");
Expand Down Expand Up @@ -64,6 +64,28 @@ pub struct MultiLabelClassificationOutput<B: Backend> {
pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
type ItemSync = MultiLabelClassificationOutput<NdArray>;

/// Synchronizes the output, loss and target tensors to an `NdArray`-backed
/// copy by reading all three in a single transaction, then rebuilding the
/// struct from the returned data on the default device.
fn sync(self) -> Self::ItemSync {
// The three reads come back in registration order, so the destructuring
// pattern below must list [output, loss, targets] in the same order as
// the `.read(...)` calls. Exactly 3 items are expected; `try_into` on a
// fixed-size array panics via `expect` otherwise.
let [output, loss, targets] = TransactionQuery::default()
.read(self.output)
.read(self.loss)
.read(self.targets)
.execute()
.try_into()
.expect("Correct amount of data");

// NOTE(review): `&Default::default()` presumably yields the default
// NdArray device — confirm this matches other `sync` impls in this file.
let device = &Default::default();

MultiLabelClassificationOutput {
output: Tensor::from_data(output, device),
loss: Tensor::from_data(loss, device),
targets: Tensor::from_data(targets, device),
}
}
}

impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> HammingScoreInput<B> {
HammingScoreInput::new(self.output.clone(), self.targets.clone())
Expand Down
Loading