Skip to content

Commit

Permalink
Enable cuda-jit in burn-core + in text classification example (tracel…
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Aug 12, 2024
1 parent 7c17e84 commit ff8d030
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 15 deletions.
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ sysinfo = "0.30.13"
systemstat = "0.2.3"

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
Expand Down
1 change: 1 addition & 0 deletions crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ record-backward-compat = []

test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.

[dependencies]

Expand Down
10 changes: 9 additions & 1 deletion crates/burn-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ pub mod backend;

extern crate alloc;

#[cfg(all(test, not(feature = "test-tch"), not(feature = "test-wgpu"),))]
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(all(test, feature = "test-tch"))]
Expand All @@ -53,6 +58,9 @@ pub type TestBackend = burn_tch::LibTorch<f32>;
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

#[cfg(feature = "std")]
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
Expand Down
6 changes: 4 additions & 2 deletions crates/burn-jit/src/fusion/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ pub enum OutputRuntimeInfo {
impl<R: JitRuntime> ExecutableKernel<R> {
/// Execute the kernel.
pub fn execute(self) {
self.client
.execute(self.kernel, self.cube_count, self.bindings)
unsafe {
self.client
.execute_unchecked(self.kernel, self.cube_count, self.bindings)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/text-classification/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ cuda-jit = ["burn/cuda-jit"]

[dependencies]
# Burn
burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "tui", "metrics", "autotune", "fusion", "default"], default-features = false}
burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "metrics", "autotune", "fusion", "default"], default-features = false}

# Tokenizer
tokenizers = { version = "0.19.1", default-features = false, features = [
Expand Down
2 changes: 1 addition & 1 deletion examples/text-classification/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use std::sync::Arc;
pub struct ExperimentConfig {
pub transformer: TransformerEncoderConfig,
pub optimizer: AdamConfig,
#[config(default = 512)]
#[config(default = 256)]
pub max_seq_length: usize,
#[config(default = 32)]
pub batch_size: usize,
Expand Down

0 comments on commit ff8d030

Please sign in to comment.