Skip to content

Commit

Permalink
Remove dtype rewrite (#2528)
Browse files Browse the repository at this point in the history
* Remove dtype rewrite

* Remove test

* Fix tensor display for bool encoding abstraction

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
  • Loading branch information
ArthurBrussee and laggui authored Nov 25, 2024
1 parent a35321c commit fe3e43a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 22 deletions.
4 changes: 1 addition & 3 deletions crates/burn-candle/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ pub struct CandleTensor {
impl TensorMetadata for CandleTensor {
fn dtype(&self) -> DType {
match self.tensor.dtype() {
// NOTE: bool tensors are stored as u32, we currently make this assumption
// since `TensorMetadata::dtype()` is used for display purposes only at this time.
candle_core::DType::U8 => DType::Bool,
candle_core::DType::U8 => DType::U8,
candle_core::DType::U32 => DType::U32,
candle_core::DType::I64 => DType::I64,
candle_core::DType::BF16 => DType::BF16,
Expand Down
7 changes: 1 addition & 6 deletions crates/burn-jit/src/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,7 @@ where

impl<R: JitRuntime> TensorMetadata for JitTensor<R> {
fn dtype(&self) -> DType {
match self.dtype {
// NOTE: bool tensors are stored as u32, we currently make this assumption
// since `TensorMetadata::dtype()` is used for display purposes only at this time.
DType::U32 => DType::Bool,
_ => self.dtype,
}
self.dtype
}

fn shape(&self) -> Shape {
Expand Down
11 changes: 10 additions & 1 deletion crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use alloc::string::String;
use alloc::vec;

use burn_common::stub::RwLock;
use core::any::TypeId;
use core::future::Future;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
Expand Down Expand Up @@ -1873,7 +1874,15 @@ where
writeln!(f, " device: {:?},", self.device())?;
writeln!(f, " backend: {:?},", B::name())?;
writeln!(f, " kind: {:?},", K::name())?;
writeln!(f, " dtype: {:?},", self.primitive.dtype().name())?;

// Bool tensors might be encoded in a different type, which we abstract for the display
let dtype = if TypeId::of::<K::Elem>() == TypeId::of::<bool>() {
DType::Bool
} else {
self.primitive.dtype()
};

writeln!(f, " dtype: {:?},", dtype.name())?;
write!(f, "}}")
}
}
Expand Down
12 changes: 0 additions & 12 deletions crates/burn-tensor/src/tests/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,4 @@ mod tests {
<TestBackend as Backend>::IntElem::dtype() // default int elem type
);
}

#[test]
fn should_support_bool_dtype() {
let tensor =
TestTensorBool::<2>::from([[false, true, true], [false, false, true]]).into_primitive();

assert_eq!(
burn_tensor::TensorMetadata::shape(&tensor),
Shape::new([2, 3])
);
assert_eq!(burn_tensor::TensorMetadata::dtype(&tensor), DType::Bool);
}
}

0 comments on commit fe3e43a

Please sign in to comment.