Skip to content

Commit

Permalink
feat: improve bool tensor (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Dec 2, 2022
1 parent cdc29a0 commit 7c38a98
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 1 deletion.
7 changes: 7 additions & 0 deletions burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::bool_reshape(tensor, shape)
}

fn bool_to_device<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D>,
device: <ADBackendDecorator<B> as Backend>::Device,
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D> {
B::bool_to_device(tensor, device)
}

fn device<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::Device {
Expand Down
7 changes: 7 additions & 0 deletions burn-ndarray/src/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
Data::new(values, tensor.shape)
}

fn bool_to_device<const D: usize>(
tensor: &NdArrayTensor<bool, D>,
_device: NdArrayDevice,
) -> NdArrayTensor<bool, D> {
tensor.clone()
}

fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &NdArrayTensor<bool, D1>,
shape: Shape<D2>,
Expand Down
15 changes: 15 additions & 0 deletions burn-tch/src/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
Data::new(values, tensor.shape)
}

fn bool_to_device<const D: usize>(
tensor: &TchTensor<bool, D>,
device: TchDevice,
) -> TchTensor<bool, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor {
kind: tensor.kind,
tensor: tensor.tensor.to(device),
shape: tensor.shape,
}
}

fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &TchTensor<bool, D1>,
shape: Shape<D2>,
Expand Down
7 changes: 6 additions & 1 deletion burn-tensor/src/tensor/backend/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ pub trait Backend:
+ 'static
+ std::fmt::Debug;

type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + std::fmt::Debug;
type BoolTensorPrimitive<const D: usize>: Clone
+ Send
+ Sync
+ 'static
+ std::fmt::Debug
+ From<<Self::IntegerBackend as Backend>::BoolTensorPrimitive<D>>;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
Expand Down
8 changes: 8 additions & 0 deletions burn-tensor/src/tensor/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ where
B::bool_shape(&self.value)
}

pub fn to_device(&self, device: B::Device) -> Self {
Self::new(B::bool_to_device(&self.value, device))
}

/// Returns the dimensions of the current tensor.
///
/// Equivalent to `tensor.shape().dims`.
Expand All @@ -44,6 +48,10 @@ where
Tensor::from_data(data.convert())
}

pub fn from_int_backend(tensor: BoolTensor<B::IntegerBackend, D>) -> Self {
Self::new(tensor.value.into())
}

/// Reshape the tensor to have the given shape.
///
/// # Panics
Expand Down
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ pub trait TensorOps<B: Backend> {
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> &Shape<D>;
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_to_device<const D: usize>(
tensor: &B::BoolTensorPrimitive<D>,
device: B::Device,
) -> B::BoolTensorPrimitive<D>;
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &B::BoolTensorPrimitive<D1>,
shape: Shape<D2>,
Expand Down

0 comments on commit 7c38a98

Please sign in to comment.