Skip to content

Commit

Permalink
make candle available (#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Oct 23, 2023
1 parent 07c0cf1 commit e4d9d67
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 7 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ resolver = "2"
members = [
"burn",
"burn-autodiff",
"burn-candle",
"burn-common",
"burn-compute",
"burn-core",
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ simplifying the process of experimenting, training, and deploying models.
[`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌
- [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform,
browser-inclusive, GPU-based computations 🌐
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend (alpha) 🕯️
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend 🕯️
- [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables
differentiability across all backends 🌟
- [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range
Expand Down
1 change: 1 addition & 0 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ version = "0.10.0"
[features]
default = ["std"]
std = []
candle = ["burn/candle"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
Expand Down
9 changes: 9 additions & 0 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,14 @@ macro_rules! bench_on_backend {
let device = NdArrayDevice::Cpu;
bench::<NdArrayBackend>(&device);
}

#[cfg(feature = "candle")]
{
use burn::backend::candle::CandleDevice;
use burn::backend::CandleBackend;

let device = CandleDevice::Cpu;
bench::<CandleBackend>(&device);
}
};
}
3 changes: 3 additions & 0 deletions burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ wgpu = ["burn-wgpu/default"]

tch = ["burn-tch"]

candle = ["burn-candle"]

# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

Expand All @@ -72,6 +74,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", optional = true,
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
burn-candle = { path = "../burn-candle", version = "0.10.0", optional = true }

derive-new = { workspace = true }
libm = { workspace = true }
Expand Down
12 changes: 12 additions & 0 deletions burn-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ pub type WgpuBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> = wgpu::WgpuBa
pub type WgpuAutodiffBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> =
crate::autodiff::ADBackendDecorator<WgpuBackend<G, F, I>>;

#[cfg(feature = "candle")]
/// Candle module.
pub use burn_candle as candle;

#[cfg(feature = "candle")]
/// A CandleBackend with a default type of f32/i64.
pub type CandleBackend = candle::CandleBackend<f32, i64>;

#[cfg(all(feature = "candle", feature = "autodiff"))]
/// A CandleBackend with autodiffing enabled.
pub type CandleAutodiffBackend = crate::autodiff::ADBackendDecorator<CandleBackend>;

#[cfg(feature = "tch")]
/// Tch module.
pub use burn_tch as tch;
Expand Down
2 changes: 1 addition & 1 deletion burn-wgpu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md)

This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) utilizing the
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) using the
[wgpu](https://github.com/gfx-rs/wgpu).

The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.
Expand Down
5 changes: 3 additions & 2 deletions burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]

wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"]
candle = ["burn-core/candle"]

# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
Expand All @@ -53,8 +54,8 @@ experimental-named-tensor = ["burn-core/experimental-named-tensor"]

# ** Please make sure all dependencies support no_std when std is disabled **

burn-core = {path = "../burn-core", version = "0.10.0", default-features = false}
burn-train = {path = "../burn-train", version = "0.10.0", optional = true, default-features = false }
burn-core = { path = "../burn-core", version = "0.10.0", default-features = false }
burn-train = { path = "../burn-train", version = "0.10.0", optional = true, default-features = false }

[package.metadata.docs.rs]
all-features = true
6 changes: 3 additions & 3 deletions examples/mnist-inference-web/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ ndarray = ["burn/ndarray-no-std"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = {path = "../../burn", default-features = false}
serde = {workspace = true}
wasm-bindgen = { version = "0.2.87" }
burn = { path = "../../burn", default-features = false }
serde = { workspace = true }
wasm-bindgen = { version = "0.2.87" }
wasm-bindgen-futures = "0.4"
js-sys = "0.3.64"

Expand Down

0 comments on commit e4d9d67

Please sign in to comment.