Skip to content

Commit

Permalink
Update model.bin mnist inference web + add cuda-jit flag for ag-news-…
Browse files Browse the repository at this point in the history
…infer (tracel-ai#2170)

* Update model.bin mnist inference web

* Add cuda-jit flag for ag-news-infer
  • Loading branch information
laggui authored Aug 19, 2024
1 parent 2755c36 commit 784f57b
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 1 deletion.
Binary file modified examples/mnist-inference-web/model.bin
Binary file not shown.
13 changes: 13 additions & 0 deletions examples/text-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,16 @@ cargo run --example ag-news-infer --release --features wgpu # Run inference on
cargo run --example db-pedia-train --release --features wgpu # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features wgpu # Run inference db pedia dataset
```

## CUDA backend

```bash
git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.

# AG News
cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset
cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset
```
12 changes: 12 additions & 0 deletions examples/text-classification/examples/ag-news-infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ mod wgpu {
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
use crate::{launch, ElemType};
use burn::backend::{cuda_jit::CudaDevice, CudaJit};

pub fn run() {
launch::<CudaJit<ElemType, i32>>(CudaDevice::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
Expand All @@ -95,4 +105,6 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
}
2 changes: 1 addition & 1 deletion examples/text-classification/examples/ag-news-train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ mod cuda_jit {
use burn::backend::{Autodiff, CudaJit};

pub fn run() {
launch::<Autodiff<CudaJit>>(vec![Default::default()]);
launch::<Autodiff<CudaJit<ElemType, i32>>>(vec![Default::default()]);
}
}

Expand Down

0 comments on commit 784f57b

Please sign in to comment.