Remove GraphicsAPI generic for WgpuRuntime #1888

Merged
merged 4 commits into from
Jun 17, 2024
11 changes: 3 additions & 8 deletions backend-comparison/src/lib.rs
@@ -62,14 +62,9 @@ macro_rules! bench_on_backend {

#[cfg(feature = "wgpu")]
{
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};

bench::<Wgpu<AutoGraphicsApi, f32, i32>>(
&WgpuDevice::default(),
feature_name,
url,
token,
);
use burn::backend::wgpu::{Wgpu, WgpuDevice};

bench::<Wgpu<f32, i32>>(&WgpuDevice::default(), feature_name, url, token);
}

#[cfg(feature = "tch-gpu")]
@@ -198,7 +198,7 @@ the raw `WgpuBackend` type.

```rust, ignore
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
fn fused_matmul_add_relu<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
11 changes: 5 additions & 6 deletions burn-book/src/basic-workflow/backend.md
@@ -11,13 +11,13 @@ entrypoint of our program, namely the `main` function defined in `src/main.rs`.
#
use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
backend::{Autodiff, Wgpu},
# data::dataset::Dataset,
optim::AdamConfig,
};

fn main() {
type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

let device = burn::backend::wgpu::WgpuDevice::default();
@@ -32,10 +32,9 @@ fn main() {

In this example, we use the `Wgpu` backend which is compatible with any operating system and will
use the GPU. For other options, see the Burn README. This backend type takes the graphics API, the
float type and the int type as generic arguments that will be used during the training. By leaving
the graphics API as `AutoGraphicsApi`, it should automatically use an API available on your machine.
The autodiff backend is simply the same backend, wrapped within the `Autodiff` struct which imparts
differentiability to any backend.
float type and the int type as generic arguments that will be used during the training. The autodiff
backend is simply the same backend, wrapped within the `Autodiff` struct, which imparts
differentiability to any backend.
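
The removal of the graphics-API generic shows up directly in user code. A minimal before/after sketch (the explicit `Vulkan` line is only needed when the automatic selection is not wanted; `init_sync` is the runtime-initialization function this PR points users to):

```rust, ignore
// Before this change: the graphics API was a type parameter of the backend.
// type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;

// After: only the float and int element types remain.
type MyBackend = Wgpu<f32, i32>;

// Choosing a specific graphics API is now a runtime-initialization concern:
fn select_vulkan() {
    let device = burn::backend::wgpu::WgpuDevice::default();
    burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
        &device,
        Default::default(),
    );
}
```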

We call the `train` function defined earlier with a directory for artifacts, the configuration of
the model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer
4 changes: 2 additions & 2 deletions burn-book/src/basic-workflow/inference.md
@@ -56,13 +56,13 @@ Add the call to `infer` to the `main.rs` file after the `train` function call:
#
# use crate::{model::ModelConfig, training::TrainingConfig};
# use burn::{
# backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
# backend::{Autodiff, Wgpu},
# data::dataset::Dataset,
# optim::AdamConfig,
# };
#
# fn main() {
# type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
# type MyBackend = Wgpu<f32, i32>;
# type MyAutodiffBackend = Autodiff<MyBackend>;
#
# let device = burn::backend::wgpu::WgpuDevice::default();
4 changes: 2 additions & 2 deletions crates/burn-wgpu/README.md
@@ -16,12 +16,12 @@ The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.
#[cfg(feature = "wgpu")]
mod wgpu {
use burn_autodiff::Autodiff;
use burn_wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn_wgpu::{Wgpu, WgpuDevice};
use mnist::training;

pub fn run() {
let device = WgpuDevice::default();
training::run::<Autodiff<Wgpu<AutoGraphicsApi, f32, i32>>>(device);
training::run::<Autodiff<Wgpu<f32, i32>>>(device);
}
}
```
37 changes: 33 additions & 4 deletions crates/burn-wgpu/src/lib.rs
@@ -37,15 +37,29 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
/// - [Metal] on Apple hardware.
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
///
/// To configure the wgpu backend, e.g. to select which graphics API or memory strategy to use,
/// you have to initialize the runtime manually. For example:
///
/// ```rust, ignore
/// fn custom_init() {
/// let device = Default::default();
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
/// &device,
/// Default::default(),
/// );
/// }
/// ```
/// This initializes the given device (in this case the default) to use Vulkan as the graphics API.
/// It's also possible to reuse an existing wgpu device by calling `init_existing_device`.
///
/// # Notes
///
/// This version of the [wgpu] backend uses [burn_fusion] to compile and optimize streams of tensor
/// operations for improved performance.
///
/// You can disable the `fusion` feature flag to remove that functionality, which might be
/// necessary on `wasm` for now.
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
burn_fusion::Fusion<JitBackend<WgpuRuntime<G>, F, I>>;
pub type Wgpu<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<WgpuRuntime, F, I>>;

#[cfg(not(feature = "fusion"))]
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
@@ -57,20 +57,35 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
/// - [Metal] on Apple hardware.
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
///
/// To configure the wgpu backend, e.g. to select which graphics API or memory strategy to use,
/// you have to initialize the runtime manually. For example:
///
/// ```rust, ignore
/// fn custom_init() {
/// let device = Default::default();
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
/// &device,
/// Default::default(),
/// );
/// }
/// ```
/// This initializes the given device (in this case the default) to use Vulkan as the graphics API.
/// It's also possible to reuse an existing wgpu device by calling `init_existing_device`.
///
/// # Notes
///
/// This version of the [wgpu] backend doesn't use [burn_fusion] to compile and optimize streams of tensor
/// operations.
///
/// You can enable the `fusion` feature flag to add that functionality, which might improve
/// performance.
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = JitBackend<WgpuRuntime<G>, F, I>;
pub type Wgpu<F = f32, I = i32> = JitBackend<WgpuRuntime, F, I>;

#[cfg(test)]
mod tests {
use super::*;

pub type TestRuntime = crate::WgpuRuntime<AutoGraphicsApi>;
pub type TestRuntime = crate::WgpuRuntime;

burn_jit::testgen_all!();
burn_cube::testgen_all!();
33 changes: 14 additions & 19 deletions crates/burn-wgpu/src/runtime.rs
@@ -1,7 +1,7 @@
use crate::{
compiler::wgsl,
compute::{WgpuServer, WgpuStorage},
GraphicsApi, WgpuDevice,
AutoGraphicsApi, GraphicsApi, WgpuDevice,
};
use alloc::sync::Arc;
use burn_common::stub::RwLock;
@@ -15,21 +15,16 @@ use burn_compute::{
use burn_cube::Runtime;
use burn_jit::JitRuntime;
use burn_tensor::backend::{DeviceId, DeviceOps};
use std::{
marker::PhantomData,
sync::atomic::{AtomicBool, Ordering},
};
use std::sync::atomic::{AtomicBool, Ordering};
use wgpu::{AdapterInfo, DeviceDescriptor};

/// Runtime that uses the [wgpu] crate with the wgsl compiler.
///
/// The [graphics api](GraphicsApi) type is passed as generic.
/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_sync`] to pass in runtime options or to select a
/// specific graphics API.
#[derive(Debug)]
pub struct WgpuRuntime<G: GraphicsApi> {
_g: PhantomData<G>,
}
pub struct WgpuRuntime {}

impl<G: GraphicsApi> JitRuntime for WgpuRuntime<G> {
impl JitRuntime for WgpuRuntime {
type JitDevice = WgpuDevice;
type JitServer = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
}
@@ -42,7 +37,7 @@ type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;

static SUBGROUP: AtomicBool = AtomicBool::new(false);

impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {
impl Runtime for WgpuRuntime {
type Compiler = wgsl::WgslCompiler;
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;

@@ -51,7 +46,8 @@ impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {

fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
RUNTIME.client(device, move || {
let (adapter, device_wgpu, queue) = pollster::block_on(create_wgpu_setup::<G>(device));
let (adapter, device_wgpu, queue) =
pollster::block_on(create_wgpu_setup::<AutoGraphicsApi>(device));
Comment on lines +49 to +50

Member: So if I understand it well, this is where the default runtime is actually created for a device type. It uses the AutoGraphicsApi to select the most appropriate adapter for the device. If you want to manually specify the graphics API for a device type, you have to use init_sync or init_async!

I don't think this is a bad thing; now the WgpuBackend only has typical backend generics, and runtime options can be changed using the init functions! Before merging, I would add a bit more documentation on the WgpuRuntime and WgpuBackend pointing in the right direction to specify backend and runtime options.

Contributor Author: Yep, that's right! It seemed coherent to me to have one "advanced initialize" place, glad you agree :)

I've added some docs now, and a small code cleanup; let me know how this looks!

create_client(adapter, device_wgpu, queue, RuntimeOptions::default())
})
}
@@ -125,14 +121,13 @@ pub fn init_existing_device(
device_id
}

/// Init the client sync, useful to configure the runtime options.
/// Initialize a client on the given device with the given options. This function is useful to configure the runtime options
/// or to pick a different graphics API. On wasm, it is necessary to use [`init_async`] instead.
pub fn init_sync<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
let (adapter, device_wgpu, queue) = pollster::block_on(create_wgpu_setup::<G>(device));
let client = create_client(adapter, device_wgpu, queue, options);
RUNTIME.register(device, client)
pollster::block_on(init_async::<G>(device, options));
}

/// Init the client async, necessary for wasm.
/// Like [`init_sync`], but asynchronous, which is necessary on wasm.
pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
let (adapter, device_wgpu, queue) = create_wgpu_setup::<G>(device).await;
let client = create_client(adapter, device_wgpu, queue, options);
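
For wasm targets, where blocking on the adapter setup is not possible, the same configuration goes through `init_async`. A hedged sketch of a wasm-side entry point (the function name is illustrative; the call mirrors the `init_sync` signature shown in the diff above):

```rust, ignore
// On wasm, the runtime must be initialized asynchronously before the first
// tensor operation; `init_async` awaits the adapter/device setup that
// `init_sync` would otherwise block on via pollster.
async fn init_wgpu_for_wasm() {
    let device = burn::backend::wgpu::WgpuDevice::default();
    burn::backend::wgpu::init_async::<burn::backend::wgpu::AutoGraphicsApi>(
        &device,
        Default::default(),
    )
    .await;
}
```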
4 changes: 2 additions & 2 deletions examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs
@@ -1,5 +1,5 @@
use burn::{
backend::wgpu::{AutoGraphicsApi, WgpuRuntime},
backend::wgpu::WgpuRuntime,
tensor::{Distribution, Tensor},
};
use custom_wgpu_kernel::{
@@ -71,7 +71,7 @@ fn autodiff<B: AutodiffBackend>(device: &B::Device) {
}

fn main() {
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi>, f32, i32>;
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime, f32, i32>;
type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;
let device = Default::default();
inference::<MyBackend>(&device);
7 changes: 2 additions & 5 deletions examples/custom-wgpu-kernel/src/backward.rs
@@ -9,15 +9,12 @@ use burn::{
ops::{broadcast_shape, Backward, Ops, OpsKind},
Autodiff, NodeID,
},
wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend, WgpuRuntime},
wgpu::{FloatElement, IntElement, JitBackend, WgpuRuntime},
},
tensor::Shape,
};

impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
for Autodiff<JitBackend<WgpuRuntime<G>, F, I>>
{
}
impl<F: FloatElement, I: IntElement> AutodiffBackend for Autodiff<JitBackend<WgpuRuntime, F, I>> {}

// Implement our custom backend trait for any backend that also implements our custom backend trait.
//
6 changes: 3 additions & 3 deletions examples/custom-wgpu-kernel/src/forward.rs
@@ -3,8 +3,8 @@ use crate::FloatTensor;
use super::Backend;
use burn::{
backend::wgpu::{
build_info, into_contiguous, kernel_wgsl, CubeCount, CubeDim, FloatElement, GraphicsApi,
IntElement, JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime,
build_info, into_contiguous, kernel_wgsl, CubeCount, CubeDim, FloatElement, IntElement,
JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime,
},
tensor::Shape,
};
@@ -36,7 +36,7 @@ impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
}

/// Implement our custom backend trait for the existing backend `WgpuBackend`.
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
fn fused_matmul_add_relu<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
4 changes: 2 additions & 2 deletions examples/guide/src/main.rs
@@ -5,13 +5,13 @@ mod training;

use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
backend::{Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};

fn main() {
type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

let device = burn::backend::wgpu::WgpuDevice::default();
2 changes: 1 addition & 1 deletion examples/image-classification-web/src/web.rs
@@ -34,7 +34,7 @@ pub enum ModelType {
WithNdArrayBackend(Model<NdArray<f32>>),

/// The model is loaded to the Wgpu backend
WithWgpuBackend(Model<Wgpu<AutoGraphicsApi, f32, i32>>),
WithWgpuBackend(Model<Wgpu<f32, i32>>),
}

/// The image is 224x224 pixels with 3 channels (RGB)
2 changes: 1 addition & 1 deletion examples/mnist-inference-web/src/state.rs
@@ -8,7 +8,7 @@ use burn::{
use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice};

#[cfg(feature = "wgpu")]
pub type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
pub type Backend = Wgpu<f32, i32>;

#[cfg(all(feature = "ndarray", not(feature = "wgpu")))]
pub type Backend = burn::backend::ndarray::NdArray<f32>;
4 changes: 2 additions & 2 deletions examples/text-classification/examples/ag-news-infer.rs
@@ -74,10 +74,10 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::wgpu::{Wgpu, WgpuDevice};

pub fn run() {
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());
}
}

4 changes: 2 additions & 2 deletions examples/text-classification/examples/ag-news-train.rs
@@ -85,12 +85,12 @@ mod tch_cpu {
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::{
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
wgpu::{Wgpu, WgpuDevice},
Autodiff,
};

pub fn run() {
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![WgpuDevice::default()]);
}
}

4 changes: 2 additions & 2 deletions examples/text-classification/examples/db-pedia-infer.rs
@@ -82,14 +82,14 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::{
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
wgpu::{Wgpu, WgpuDevice},
Autodiff,
};

use crate::{launch, ElemType};

pub fn run() {
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
launch::<Autodiff<Wgpu<ElemType, i32>>>(WgpuDevice::default());
}
}

4 changes: 2 additions & 2 deletions examples/text-classification/examples/db-pedia-train.rs
@@ -81,14 +81,14 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::{
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
wgpu::{Wgpu, WgpuDevice},
Autodiff,
};

use crate::{launch, ElemType};

pub fn run() {
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![WgpuDevice::default()]);
}
}
