Skip to content

Commit

Permalink
Add float cast tensor op (tracel-ai#2483)
Browse files Browse the repository at this point in the history
* Remove elem type generic from TchTensor

* Remove elem type generic from CandleTensor

* Add float cast tensor op

* Add float tensor cast to book

* Fix candle into_data for int dtypes

* Fix fmt

* Fix clippy

* Add dtype check to test

* Edit type promotion comment

* Clean up candle into_data

* Add candle float cast
  • Loading branch information
laggui authored Nov 18, 2024
1 parent f8c845d commit 56dc4c0
Show file tree
Hide file tree
Showing 27 changed files with 606 additions and 569 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ Those operations are only available for `Float` tensors.
| Burn API | PyTorch Equivalent |
|-----------------------------------------------| ---------------------------------- |
| `Tensor::one_hot(index, num_classes, device)` | N/A |
| `tensor.cast(dtype)` | `tensor.to(dtype)` |
| `tensor.ceil()` | `tensor.ceil()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.erf()` | `tensor.erf()` |
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2558,6 +2558,10 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}

/// Casts the float tensor to `dtype` by delegating to the inner backend `B`.
///
/// Only the wrapped primitive is forwarded to `B::float_cast`; the converted primitive is then
/// wrapped again with `AutodiffTensor::new`.
// NOTE(review): no backward step is registered in this function, so it looks like gradients
// are not tracked through the cast — confirm this is intended.
fn float_cast(tensor: FloatTensor<Self>, dtype: burn_tensor::FloatDType) -> FloatTensor<Self> {
    let casted = B::float_cast(tensor.primitive, dtype);
    AutodiffTensor::new(casted)
}

// TODO: Implement float_prod and float_sum
// https://github.com/tracel-ai/burn/issues/1458
}
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {

type FullPrecisionBridge = PrecisionBridge<f32>;

type FloatTensorPrimitive = CandleTensor<Self::FloatElem>;
type FloatTensorPrimitive = CandleTensor;
type FloatElem = F;

type IntTensorPrimitive = CandleTensor<Self::IntElem>;
type IntTensorPrimitive = CandleTensor;
type IntElem = I;

type BoolTensorPrimitive = CandleTensor<u8>;
type BoolTensorPrimitive = CandleTensor;

type QuantizedTensorPrimitive = CandleQTensor;
type QuantizedEncoding = u8;
Expand Down
102 changes: 47 additions & 55 deletions crates/burn-candle/src/ops/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::marker::PhantomData;

use burn_tensor::{backend::Backend, Shape, TensorData};
use burn_tensor::{backend::Backend, Element, Shape, TensorData};
use candle_core::WithDType;
use half::{bf16, f16};

use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
Expand All @@ -9,47 +11,52 @@ use crate::{

use super::tensor;

pub fn cat<E: CandleElement>(tensors: Vec<CandleTensor<E>>, dim: usize) -> CandleTensor<E> {
/// Concatenates `tensors` along dimension `dim`.
///
/// # Panics
///
/// Panics if `candle_core::Tensor::cat` returns an error.
pub fn cat(tensors: Vec<CandleTensor>, dim: usize) -> CandleTensor {
    // Unwrap each backend tensor into the raw candle tensor that `Tensor::cat` expects.
    let inner: Vec<candle_core::Tensor> = tensors
        .into_iter()
        .map(|wrapped| wrapped.tensor)
        .collect();
    let joined = candle_core::Tensor::cat(&inner, dim).unwrap();
    CandleTensor::new(joined)
}

pub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor<E> {
CandleTensor::from_data(data, device.clone())
/// Builds a tensor on `device` from `data`, with `E` as the element type handed to
/// `CandleTensor::from_data`.
pub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor {
    // `CandleTensor::from_data` takes the device by value, so clone the borrowed one.
    let owned_device = device.clone();
    CandleTensor::from_data::<E>(data, owned_device)
}
pub fn into_data<E: CandleElement>(tensor: CandleTensor<E>) -> TensorData {
TensorData::new(
tensor.tensor.flatten_all().unwrap().to_vec1::<E>().unwrap(),
tensor.shape(),
)
/// Converts `tensor` into `TensorData`, choosing the element type from the tensor's runtime
/// candle dtype.
///
/// # Panics
///
/// Panics if flattening the tensor or reading it back as a vector fails.
pub fn into_data(tensor: CandleTensor) -> TensorData {
    /// Flattens `source` and reads all of its elements as `T`, keeping the original shape.
    fn read_as<T: WithDType + Element>(source: &CandleTensor) -> TensorData {
        let flat = source.tensor.flatten_all().unwrap();
        let values = flat.to_vec1::<T>().unwrap();
        TensorData::new(values, source.shape())
    }

    let dtype = tensor.tensor.dtype();
    match dtype {
        candle_core::DType::F64 => read_as::<f64>(&tensor),
        candle_core::DType::F32 => read_as::<f32>(&tensor),
        candle_core::DType::F16 => read_as::<f16>(&tensor),
        candle_core::DType::BF16 => read_as::<bf16>(&tensor),
        candle_core::DType::I64 => read_as::<i64>(&tensor),
        candle_core::DType::U32 => read_as::<u32>(&tensor),
        candle_core::DType::U8 => read_as::<u8>(&tensor),
    }
}

pub fn to_device<E: CandleElement>(
tensor: CandleTensor<E>,
device: &CandleDevice,
) -> CandleTensor<E> {
/// Moves `tensor` to `device` using candle's `to_device`.
///
/// # Panics
///
/// Panics if candle's `to_device` returns an error.
pub fn to_device(tensor: CandleTensor, device: &CandleDevice) -> CandleTensor {
    let moved = tensor.tensor.to_device(&device.clone().into()).unwrap();
    CandleTensor::new(moved)
}

pub fn empty<E: CandleElement>(shape: Shape, device: &CandleDevice) -> CandleTensor<E> {
pub fn empty(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {
CandleTensor::new(
candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(device.clone()).into()).unwrap(),
candle_core::Tensor::zeros(shape.dims, dtype, &(device.clone()).into()).unwrap(),
)
}

pub fn swap_dims<E: CandleElement>(
mut tensor: CandleTensor<E>,
dim1: usize,
dim2: usize,
) -> CandleTensor<E> {
/// Swaps dimensions `dim1` and `dim2` of `tensor` using candle's `transpose`.
///
/// # Panics
///
/// Panics if candle's `transpose` returns an error.
pub fn swap_dims(tensor: CandleTensor, dim1: usize, dim2: usize) -> CandleTensor {
    // `tensor` is only read here, so the binding does not need to be `mut`.
    CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap())
}

pub fn permute<E: CandleElement>(tensor: CandleTensor<E>, axes: &[usize]) -> CandleTensor<E> {
/// Reorders the dimensions of `tensor` according to `axes` using candle's `permute`.
///
/// # Panics
///
/// Panics if candle's `permute` returns an error.
pub fn permute(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {
    let permuted = tensor.tensor.permute(axes).unwrap();
    CandleTensor::new(permuted)
}

pub fn flip<E: CandleElement>(tensor: CandleTensor<E>, axes: &[usize]) -> CandleTensor<E> {
pub fn flip(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {
// FIXME: Replace with an appropriate method when Candle provides one.
let mut tensor = tensor.tensor;
for &axis in axes {
Expand All @@ -66,22 +73,19 @@ pub fn flip<E: CandleElement>(tensor: CandleTensor<E>, axes: &[usize]) -> Candle
CandleTensor::new(tensor)
}

pub fn reshape<E: CandleElement>(tensor: CandleTensor<E>, shape: Shape) -> CandleTensor<E> {
/// Reshapes `tensor` to the dimensions stored in `shape`.
///
/// # Panics
///
/// Panics if candle's `reshape` returns an error.
pub fn reshape(tensor: CandleTensor, shape: Shape) -> CandleTensor {
    let reshaped = tensor.tensor.reshape(shape.dims).unwrap();
    CandleTensor::new(reshaped)
}

pub fn device<E: CandleElement>(tensor: &CandleTensor<E>) -> CandleDevice {
/// Returns the device of `tensor`, converted from candle's device type into `CandleDevice`.
pub fn device(tensor: &CandleTensor) -> CandleDevice {
    let candle_device = tensor.tensor.device().clone();
    candle_device.into()
}

pub fn shape<E: CandleElement>(tensor: &CandleTensor<E>) -> Shape {
/// Returns the shape of `tensor` by calling `tensor.shape()`.
pub fn shape(tensor: &CandleTensor) -> Shape {
tensor.shape()
}

pub fn slice<E: CandleElement>(
tensor: CandleTensor<E>,
ranges: &[std::ops::Range<usize>],
) -> CandleTensor<E> {
pub fn slice(tensor: CandleTensor, ranges: &[std::ops::Range<usize>]) -> CandleTensor {
let mut narrow_tensor = tensor.tensor;
for (i, range) in ranges.iter().enumerate().take(ranges.len()) {
narrow_tensor = narrow_tensor
Expand All @@ -91,55 +95,43 @@ pub fn slice<E: CandleElement>(
CandleTensor::new(narrow_tensor)
}

pub fn slice_assign<E: CandleElement>(
tensor: CandleTensor<E>,
pub fn slice_assign(
tensor: CandleTensor,
ranges: &[std::ops::Range<usize>],
value: CandleTensor<E>,
) -> CandleTensor<E> {
value: CandleTensor,
) -> CandleTensor {
CandleTensor::new(tensor.tensor.slice_assign(ranges, &value.tensor).unwrap())
}

pub fn narrow<E: CandleElement>(
tensor: CandleTensor<E>,
dim: usize,
start: usize,
length: usize,
) -> CandleTensor<E> {
/// Narrows `tensor` along `dim` to `length` elements starting at `start`, using candle's
/// `narrow`.
///
/// # Panics
///
/// Panics with a message starting with `"error narrow from Candle"` if candle rejects the
/// request; the underlying candle error is appended to the message.
pub fn narrow(tensor: CandleTensor, dim: usize, start: usize, length: usize) -> CandleTensor {
    match tensor.tensor.narrow(dim, start, length) {
        Ok(narrowed) => CandleTensor::new(narrowed),
        // Include candle's error so the failure cause is not silently discarded.
        Err(e) => panic!("error narrow from Candle: {:?}", e),
    }
}

pub fn chunk<E: CandleElement>(
tensor: CandleTensor<E>,
chunks: usize,
dim: usize,
) -> Vec<CandleTensor<E>> {
pub fn chunk(tensor: CandleTensor, chunks: usize, dim: usize) -> Vec<CandleTensor> {
let tensors = tensor.tensor.chunk(chunks, dim);
match tensors {
Ok(tensors) => tensors
.into_iter()
.map(|tensor| CandleTensor::new(tensor))
.collect(),
Ok(tensors) => tensors.into_iter().map(CandleTensor::new).collect(),
Err(e) => panic!("error chunk from Candle"),
}
}

pub fn expand<E: CandleElement>(tensor: CandleTensor<E>, shape: Shape) -> CandleTensor<E> {
/// Broadcasts `tensor` to the dimensions stored in `shape` using candle's `broadcast_as`.
///
/// # Panics
///
/// Panics if candle's `broadcast_as` returns an error.
pub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor {
    let broadcasted = tensor.tensor.broadcast_as(shape.dims).unwrap();
    CandleTensor::new(broadcasted)
}

pub fn sign<E: CandleElement>(tensor: CandleTensor<E>) -> CandleTensor<E> {
/// Applies candle's `sign` to the wrapped tensor and wraps the result.
///
/// # Panics
///
/// Panics if candle's `sign` returns an error.
pub fn sign(tensor: CandleTensor) -> CandleTensor {
    let signed = tensor.tensor.sign().unwrap();
    CandleTensor::new(signed)
}

pub fn mask_where_broadcasted<E: CandleElement>(
tensor: CandleTensor<E>,
mask: CandleTensor<u8>,
value: CandleTensor<E>,
) -> CandleTensor<E> {
pub fn mask_where_broadcasted(
tensor: CandleTensor,
mask: CandleTensor,
value: CandleTensor,
) -> CandleTensor {
let shape = tensor
.tensor
.shape()
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-candle/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::base::{expand, permute};

impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
super::base::empty(shape, device)
super::base::empty(shape, device, candle_core::DType::U8)
}

fn bool_shape(tensor: &BoolTensor<Self>) -> Shape {
Expand All @@ -27,7 +27,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<

fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
let data: TensorData = TensorData::new(data.iter::<bool>().collect(), data.shape);
super::base::from_data(data, device)
super::base::from_data::<u8>(data, device)
}

fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::base::{expand, permute, sign};

impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
super::base::empty(shape, device)
super::base::empty(shape, device, I::DTYPE)
}

fn int_shape(tensor: &IntTensor<Self>) -> Shape {
Expand All @@ -24,7 +24,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
}

fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
super::base::from_data(data, device)
super::base::from_data::<I>(data, device)
}

fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
Expand Down Expand Up @@ -251,7 +251,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F

fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
let sum = tensor.tensor.sum_all().unwrap().to_scalar::<I>().unwrap();
CandleTensor::from_data(
CandleTensor::from_data::<I>(
TensorData::new([sum].into(), [1]),
Self::int_device(&tensor),
)
Expand Down
35 changes: 25 additions & 10 deletions crates/burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::borrow::Borrow;

use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor},
Device, Distribution, ElementConversion, Shape, TensorData,
Device, Distribution, ElementConversion, FloatDType, Shape, TensorData,
};
use candle_core::{backend::BackendStorage, shape, Tensor};

Expand All @@ -14,8 +14,8 @@ use crate::{
use super::base::{expand, permute, sign};

impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor<F> {
CandleTensor::from_data(data, device.clone())
/// Creates a float tensor on `device` from `data`, using the backend float element type `F`.
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
    // `CandleTensor::from_data` takes the device by value, so clone the borrowed one.
    let owned_device = device.clone();
    CandleTensor::from_data::<F>(data, owned_device)
}

fn float_random(
Expand Down Expand Up @@ -52,28 +52,28 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
}
}

fn float_shape(tensor: &CandleTensor<F>) -> Shape {
/// Returns the shape of a float tensor; forwards to `super::base::shape`.
fn float_shape(tensor: &CandleTensor) -> Shape {
super::base::shape(tensor)
}

async fn float_into_data(tensor: CandleTensor<F>) -> TensorData {
/// Converts a float tensor into `TensorData`; forwards to `super::base::into_data`.
// The body contains no `.await`; the conversion itself is a plain synchronous call.
async fn float_into_data(tensor: CandleTensor) -> TensorData {
super::base::into_data(tensor)
}

fn float_device(tensor: &CandleTensor<F>) -> Device<Self> {
/// Returns the device of a float tensor; forwards to `super::base::device`.
fn float_device(tensor: &CandleTensor) -> Device<Self> {
super::base::device(tensor)
}

fn float_to_device(tensor: CandleTensor<F>, device: &Device<Self>) -> CandleTensor<F> {
/// Moves a float tensor to `device`; forwards to `super::base::to_device`.
fn float_to_device(tensor: CandleTensor, device: &Device<Self>) -> CandleTensor {
super::base::to_device(tensor, device)
}

fn float_into_int(tensor: CandleTensor<F>) -> IntTensor<Self> {
/// Converts a float tensor into an int tensor by casting it to the backend int dtype
/// `I::DTYPE`.
///
/// # Panics
///
/// Panics if candle's `to_dtype` returns an error.
fn float_into_int(tensor: CandleTensor) -> IntTensor<Self> {
    let as_int = tensor.tensor.to_dtype(I::DTYPE).unwrap();
    CandleTensor::new(as_int)
}

fn float_empty(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
super::base::empty(shape, device)
super::base::empty(shape, device, F::DTYPE)
}

fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
Expand Down Expand Up @@ -298,7 +298,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle

fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
let sum = tensor.tensor.sum_all().unwrap().to_scalar::<F>().unwrap();
CandleTensor::from_data(
CandleTensor::from_data::<F>(
TensorData::new([sum].into(), [1]),
Self::float_device(&tensor),
)
Expand Down Expand Up @@ -470,4 +470,19 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
/// Computes the sign of a float tensor; forwards to the `sign` helper imported from
/// `super::base`.
fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
sign(tensor)
}

/// Converts `tensor` to the candle dtype that corresponds to `dtype`.
///
/// A tensor that already has the requested dtype is returned unchanged, without calling
/// `to_dtype`.
///
/// # Panics
///
/// Panics if candle's `to_dtype` returns an error.
fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
    // Translate the burn float dtype into its candle counterpart.
    let target = match dtype {
        FloatDType::BF16 => candle_core::DType::BF16,
        FloatDType::F16 => candle_core::DType::F16,
        FloatDType::F32 => candle_core::DType::F32,
        FloatDType::F64 => candle_core::DType::F64,
    };

    // Nothing to convert: hand the input back as-is.
    if tensor.tensor.dtype() == target {
        return tensor;
    }

    CandleTensor::new(tensor.tensor.to_dtype(target).unwrap())
}
}
16 changes: 5 additions & 11 deletions crates/burn-candle/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::marker::PhantomData;

use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
Element, Shape, TensorData,
Expand All @@ -9,18 +7,14 @@ use crate::{element::CandleElement, CandleDevice};

/// A tensor that uses the candle backend.
#[derive(Debug, Clone)]
pub struct CandleTensor<E: CandleElement> {
pub struct CandleTensor {
pub(crate) tensor: candle_core::Tensor,
phantom: PhantomData<E>,
}

impl<E: CandleElement> CandleTensor<E> {
impl CandleTensor {
/// Create a new tensor.
pub fn new(tensor: candle_core::Tensor) -> Self {
Self {
tensor,
phantom: PhantomData,
}
Self { tensor }
}

/// Creates a new tensor from data and a device.
Expand All @@ -33,7 +27,7 @@ impl<E: CandleElement> CandleTensor<E> {
/// # Returns
///
/// A new tensor.
pub fn from_data(data: TensorData, device: CandleDevice) -> Self {
pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
let candle_shape: candle_core::Shape = data.shape.clone().into();
let tensor = candle_core::Tensor::from_slice(
data.convert::<E>().as_slice::<E>().unwrap(),
Expand All @@ -53,7 +47,7 @@ impl<E: CandleElement> CandleTensor<E> {
pub struct CandleQTensor {
/// The quantized tensor.
// NOTE: candle does not implement `WithDType` for i8
pub qtensor: CandleTensor<u8>,
pub qtensor: CandleTensor,
/// The quantization scheme.
pub scheme: QuantizationScheme,
}
Expand Down
Loading

0 comments on commit 56dc4c0

Please sign in to comment.