diff --git a/Cargo.lock b/Cargo.lock index 6405a4706e..d86155642d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1379,7 +1379,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1390,7 +1390,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "derive-new", "getrandom", @@ -1404,7 +1404,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "bytemuck", "cubecl-macros", @@ -1419,7 +1419,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "bytemuck", "cubecl-common", @@ -1434,7 +1434,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "bytemuck", "cubecl-core", @@ -1445,7 +1445,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "derive-new", "proc-macro2", @@ -1456,7 +1456,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "async-channel", "cubecl-common", @@ -1475,7 +1475,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=546f588eb0f29e23b8f5dca2c2858253dd19bb5e#546f588eb0f29e23b8f5dca2c2858253dd19bb5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9" dependencies = [ "async-channel", "bytemuck", diff --git a/Cargo.toml b/Cargo.toml index e145288849..f268ee4b81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -143,8 +143,8 @@ sysinfo = "0.30.13" systemstat = "0.2.3" ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "546f588eb0f29e23b8f5dca2c2858253dd19bb5e" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl" } # cubecl-common = { path = "../cubecl/crates/cubecl-common" } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index a13700084a..e969ee7d41 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -99,6 +99,7 @@ record-backward-compat = [] test-tch = ["tch"] # To use tch during testing, default uses ndarray. test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. +test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. [dependencies] diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index 4bfb8f900f..de80e8a25b 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -44,7 +44,12 @@ pub mod backend; extern crate alloc; -#[cfg(all(test, not(feature = "test-tch"), not(feature = "test-wgpu"),))] +#[cfg(all( + test, + not(feature = "test-tch"), + not(feature = "test-wgpu"), + not(feature = "test-cuda") +))] pub type TestBackend = burn_ndarray::NdArray; #[cfg(all(test, feature = "test-tch"))] @@ -53,6 +58,9 @@ pub type TestBackend = burn_tch::LibTorch; #[cfg(all(test, feature = "test-wgpu"))] pub type TestBackend = burn_wgpu::Wgpu; +#[cfg(all(test, feature = "test-cuda"))] +pub type TestBackend = burn_cuda::Cuda; + #[cfg(feature = "std")] #[cfg(test)] pub type TestAutodiffBackend = burn_autodiff::Autodiff; diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs index 5f7cf791d2..da19180ed6 100644 --- a/crates/burn-jit/src/fusion/kernel.rs +++ b/crates/burn-jit/src/fusion/kernel.rs @@ -70,8 +70,10 @@ pub enum OutputRuntimeInfo { impl ExecutableKernel { /// Execute the kernel. pub fn execute(self) { - self.client - .execute(self.kernel, self.cube_count, self.bindings) + unsafe { + self.client + .execute_unchecked(self.kernel, self.cube_count, self.bindings) + } } } diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index fe9eff6707..80205bf65f 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -20,7 +20,7 @@ cuda-jit = ["burn/cuda-jit"] [dependencies] # Burn -burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "tui", "metrics", "autotune", "fusion", "default"], default-features = false} +burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "metrics", "autotune", "fusion", "default"], default-features = false} # Tokenizer tokenizers = { version = "0.19.1", default-features = false, features = [ diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs index 05b2da12c2..f2b31ae4d1 100644 --- a/examples/text-classification/src/training.rs +++ b/examples/text-classification/src/training.rs @@ -29,7 +29,7 @@ use std::sync::Arc; pub struct ExperimentConfig { pub transformer: TransformerEncoderConfig, pub optimizer: AdamConfig, - #[config(default = 512)] + #[config(default = 256)] pub max_seq_length: usize, #[config(default = 32)] pub batch_size: usize,