Skip to content

Commit

Permalink
Add float cast op for JIT backend (tracel-ai#2511)
Browse files Browse the repository at this point in the history
* Fix backend FloatElem docstring

* Remove elem type generic from JitTensor (float ops still using default dtype)

* Use execute_with_dtype macro for float ops

* Add cast for fusion

* Add warning for type promotion

* Add cast to backend router

* Add Primitive trait to correctly display the tensor dtype

* Fix primitive associated type

* Fix jit bool tensor display

* Fix CI

* Fix clippy

* Whoops

* Fix candle bool tensor display

* Add shape to primitive trait

* Add primitive tests

* Rename trait to TensorMetadata

* Add missing changes to example

* Fix split from merge
  • Loading branch information
laggui authored Nov 21, 2024
1 parent a6c7a2b commit 9a2b841
Show file tree
Hide file tree
Showing 149 changed files with 2,225 additions and 1,627 deletions.
28 changes: 17 additions & 11 deletions burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,14 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F,
.empty(shape_out.num_elements() * core::mem::size_of::<F>());
// Create the output tensor primitive.
let output =
JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
// Create the output tensor primitive.
let output = JitTensor::new_contiguous(
lhs.client.clone(),
lhs.device.clone(),
shape_out,
buffer,
F::dtype(),
);
// Declare the wgsl workgroup with the number of cubes in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
Expand All @@ -186,10 +192,10 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F,
&lhs.client,
cube_count,
cube_dim,
lhs.as_tensor_arg(1),
rhs.as_tensor_arg(1),
bias.as_tensor_arg(1),
output.as_tensor_arg(1),
lhs.as_tensor_arg::<F>(1),
rhs.as_tensor_arg::<F>(1),
bias.as_tensor_arg::<F>(1),
output.as_tensor_arg::<F>(1),
);
// Return the output tensor.
Expand Down Expand Up @@ -251,12 +257,12 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
// Set our state.
let (lhs_state, rhs_state, output, shape_bias) = ops.state;
let lhs = checkpointer.retrieve_node_output(lhs_state);
let rhs = checkpointer.retrieve_node_output(rhs_state);
let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);
// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();
// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
Expand Down Expand Up @@ -314,7 +320,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
// compute bound operation.
let lhs_state = prep.checkpoint(&lhs);
let rhs_state = prep.checkpoint(&rhs);
let bias_shape = B::float_shape(&bias.primitive);
let bias_shape = bias.primitive.shape();
let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
Expand Down
21 changes: 13 additions & 8 deletions burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,19 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
.empty(shape_out.num_elements() * core::mem::size_of::<F>());
// Create the output tensor primitive.
let output =
JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
let output = JitTensor::new_contiguous(
lhs.client.clone(),
lhs.device.clone(),
shape_out,
buffer,
F::dtype(),
);
// Create the kernel.
let kernel = FusedMatmulAddRelu::<F>::new(cube_dim);
// Build info buffer with tensor information needed by the kernel, such as shapes and strides.
let info = build_info(&[&lhs, &rhs, &output]);
let info = build_info::<_, F>(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
// Declare the wgsl workgroup with the number of cubes in x, y and z.
Expand Down Expand Up @@ -331,12 +336,12 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
// Set our state.
let (lhs_state, rhs_state, output, shape_bias) = ops.state;
let lhs = checkpointer.retrieve_node_output(lhs_state);
let rhs = checkpointer.retrieve_node_output(rhs_state);
let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);
// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();
// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
Expand Down Expand Up @@ -392,7 +397,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
// during the backward pass. Here we choose to save it in the state because it's a compute bound operation.
let lhs_state = prep.checkpoint(&lhs);
let rhs_state = prep.checkpoint(&rhs);
let bias_shape = B::float_shape(&bias.primitive);
let bias_shape = bias.primitive.shape();
let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/grads.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use burn_tensor::{backend::Backend, container::TensorContainer, ops::FloatTensor};
use burn_tensor::{backend::Backend, container::TensorContainer, ops::FloatTensor, TensorMetadata};

use crate::{
graph::{NodeRef, Requirement},
Expand All @@ -22,7 +22,7 @@ impl Gradients {
};
gradients.register::<B>(
root_node.id,
B::float_ones(B::float_shape(&root_tensor), &B::float_device(&root_tensor)),
B::float_ones(root_tensor.shape(), &B::float_device(&root_tensor)),
);
gradients
}
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
tensor::AutodiffTensor,
};
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape};
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata};
use std::marker::PhantomData;

/// Operation in preparation.
Expand Down Expand Up @@ -292,7 +292,7 @@ impl<const N: usize> Step for UntrackedOpsStep<N> {
/// If broadcasting happened during the forward pass, the gradients will be summed along the
/// broadcasted dimension.
pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {
let shape_grad = B::float_shape(&grad);
let shape_grad = grad.shape();
let ndims = shape_grad.num_dims();

for i in 0..ndims {
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_from_data(data, device)
}

/// Returns the shape of the given bool tensor by delegating to the inner backend `B`.
fn bool_shape(tensor: &BoolTensor<B>) -> Shape {
B::bool_shape(tensor)
}

/// Consumes the bool tensor and converts it into [`TensorData`], awaiting the inner
/// backend `B`'s `bool_into_data`.
async fn bool_into_data(tensor: BoolTensor<B>) -> TensorData {
B::bool_into_data(tensor).await
}
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-autodiff/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_from_data(data, device)
}

/// Returns the shape of the given int tensor by delegating to the inner backend `B`.
fn int_shape(tensor: &IntTensor<B>) -> Shape {
B::int_shape(tensor)
}

/// Consumes the int tensor and converts it into [`TensorData`], awaiting the inner
/// backend `B`'s `int_into_data`.
async fn int_into_data(tensor: IntTensor<B>) -> TensorData {
B::int_into_data(tensor).await
}
Expand Down
Loading

0 comments on commit 9a2b841

Please sign in to comment.