Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable burn fusion by default #1223

Merged
merged 11 commits into from
Feb 8, 2024
Prev Previous commit
Next Next commit
Update book
  • Loading branch information
nathanielsimard committed Feb 7, 2024
commit 058481b24970d60c632cead1ebec403c49527b59
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
```

Subsequently, we'll go into implementing our custom backend trait for the WGPU backend.
Note that we won't go into supporting the `fusion` feature flag in this tutorial, so
we implement the trait for the raw `WgpuBackend` type.

```rust, ignore
/// Implement our custom backend trait for the existing backend `Wgpu`.
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for Wgpu<G, F, I> {
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G, F, I> {
fn fused_matmul_add_relu<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
Expand Down Expand Up @@ -419,7 +421,7 @@ operation nodes.
The only remaining part is to implement our autodiff-decorated backend trait for our WGPU Backend.

```rust, ignore
impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend for Autodiff<Wgpu<G, F, I>>
impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend for Autodiff<WgpuBackend<G, F, I>>
{
}
```
Expand Down
8 changes: 4 additions & 4 deletions burn-wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ mod fusion;
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
///
/// This backend can target multiple graphics APIs, including:
/// - [Vulkan](crate::Vulkan) on Linux, Windows, and Android.
/// - [Vulkan] on Linux, Windows, and Android.
/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.
/// - [DirectX 12](crate::Dx12) on Windows.
/// - [Metal](crate::Metal) on Apple hardware.
/// - [Metal] on Apple hardware.
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
///
/// # Notes
Expand All @@ -56,10 +56,10 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = burn_fusion::Fusion<WgpuB
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
///
/// This backend can target multiple graphics APIs, including:
/// - [Vulkan](crate::Vulkan) on Linux, Windows, and Android.
/// - [Vulkan] on Linux, Windows, and Android.
/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.
/// - [DirectX 12](crate::Dx12) on Windows.
/// - [Metal](crate::Metal) on Apple hardware.
/// - [Metal] on Apple hardware.
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
///
/// # Notes
Expand Down