Skip to content

Commit

Permalink
Add image classification web demo with WebGPU, CPU backends (tracel-a…
Browse files Browse the repository at this point in the history
  • Loading branch information
antimora authored Oct 5, 2023
1 parent 28e2a99 commit e2a17e4
Show file tree
Hide file tree
Showing 31 changed files with 2,045 additions and 63 deletions.
75 changes: 40 additions & 35 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,32 @@
resolver = "2"

members = [
"burn",
"burn-autodiff",
"burn-common",
"burn-compute",
"burn-core",
"burn-dataset",
"burn-derive",
"burn-import",
"burn-import/onnx-tests",
"burn-ndarray",
"burn-no-std-tests",
"burn-tch",
"burn-wgpu",
"burn-candle",
"burn-tensor-testgen",
"burn-tensor",
"burn-train",
"xtask",
"examples/*",
"backend-comparison",
"burn",
"burn-autodiff",
"burn-common",
"burn-compute",
"burn-core",
"burn-dataset",
"burn-derive",
"burn-import",
"burn-import/onnx-tests",
"burn-ndarray",
"burn-no-std-tests",
"burn-tch",
"burn-wgpu",
"burn-candle",
"burn-tensor-testgen",
"burn-tensor",
"burn-train",
"xtask",
"examples/*",
"backend-comparison",
]

exclude = ["examples/notebook"]

[workspace.dependencies]
async-trait = "0.1.73"
bytemuck = "1.13"
const-random = "0.1.15"
csv = "1.2.2"
Expand All @@ -37,11 +38,12 @@ dirs = "5.0.1"
fake = "2.6.1"
flate2 = "1.0.26"
float-cmp = "0.9.0"
getrandom = { version = "0.2.10", default-features = false }
gix-tempfile = { version = "8.0.0", features = ["signals"] }
hashbrown = "0.14.0"
indicatif = "0.17.5"
libm = "0.2.7"
log = "0.4.19"
log = { default-features = false, version = "0.4.19" }
pretty_assertions = "1.3"
proc-macro2 = "1.0.60"
protobuf-codegen = "3.2"
Expand All @@ -55,15 +57,17 @@ rusqlite = { version = "0.29" }
sanitize-filename = "0.5.0"
serde_rusqlite = "0.33.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
strum = "0.24"
strum_macros = "0.24"
strum = "0.25.0"
strum_macros = "0.25.2"
syn = { version = "2.0", features = ["full", "extra-traits"] }
tempfile = "3.6.0"
thiserror = "1.0.40"
tracing-subscriber = "0.3.17"
tracing-core = "0.1.31"
tracing-appender = "0.2.2"
async-trait = "0.1.73"
tracing-core = "0.1.31"
tracing-subscriber = "0.3.17"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2.0"

# WGPU stuff
futures-intrusive = "0.5"
Expand All @@ -75,26 +79,27 @@ wgpu = "0.17.0"
# The following packages disable the "std" feature for no_std compatibility
#
bincode = { version = "2.0.0-rc.3", features = [
"alloc",
"serde",
"alloc",
"serde",
], default-features = false }
derive-new = { version = "0.5.9", default-features = false }

half = { version = "2.3.1", features = [
"alloc",
"num-traits",
"serde",
"alloc",
"num-traits",
"serde",
], default-features = false }
ndarray = { version = "0.15.6", default-features = false }
num-traits = { version = "0.2.15", default-features = false, features = [
"libm",
"libm",
] } # libm is for no_std
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
"std_rng",
] } # std_rng is for no_std
rand_distr = { version = "0.4.3", default-features = false }
serde = { version = "1.0.164", default-features = false, features = [
"derive",
"alloc",
"derive",
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = { version = "1.0.96", default-features = false }
uuid = { version = "1.3.4", default-features = false }
Expand Down
11 changes: 5 additions & 6 deletions _typos.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
[default]
extend-ignore-identifiers-re = [
"ratatui",
"NdArray*",
"ND"
]
extend-ignore-identifiers-re = ["ratatui", "NdArray*", "ND"]

[files]
extend-exclude = ["assets/ModuleSerialization.xml"]
extend-exclude = [
"assets/ModuleSerialization.xml",
"examples/image-classification-web/src/model/label.txt",
]
4 changes: 2 additions & 2 deletions burn-candle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ version = "0.10.0"

[dependencies]
derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.10.0" }
half = { workspace = true, features = ["std"] }
burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features = false }
half = { workspace = true }
# candle-core = { version = "0.1.2" }
candle-core = { git = "https://github.com/huggingface/candle", rev = "237323c" }

Expand Down
2 changes: 1 addition & 1 deletion burn-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std = ["rand/std"]

[target.'cfg(target_family = "wasm")'.dependencies]
async-trait = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
getrandom = { workspace = true, features = ["js"] }

[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
Expand Down
41 changes: 24 additions & 17 deletions burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@ version = "0.10.0"
[features]
default = ["std", "dataset-minimal"]
std = [
"burn-common/std",
"burn-tensor/std",
"flate2",
"log",
"rand/std",
"rmp-serde",
"serde/std",
"serde_json/std",
"bincode/std",
"half/std",
"derive-new/std",
"burn-common/std",
"burn-tensor/std",
"flate2",
"log",
"rand/std",
"rmp-serde",
"serde/std",
"serde_json/std",
"bincode/std",
"half/std",
]
dataset = ["burn-dataset/default"]
dataset-minimal = ["burn-dataset"]
Expand All @@ -35,10 +34,18 @@ autodiff = ["burn-autodiff"]

ndarray = ["__ndarray", "burn-ndarray/default"]
ndarray-no-std = ["__ndarray", "burn-ndarray"]
ndarray-blas-accelerate = ["__ndarray", "ndarray", "burn-ndarray/blas-accelerate"]
ndarray-blas-accelerate = [
"__ndarray",
"ndarray",
"burn-ndarray/blas-accelerate",
]
ndarray-blas-netlib = ["__ndarray", "ndarray", "burn-ndarray/blas-netlib"]
ndarray-blas-openblas = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas"]
ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas-system"]
ndarray-blas-openblas-system = [
"__ndarray",
"ndarray",
"burn-ndarray/blas-openblas-system",
]
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.

wgpu = ["burn-wgpu/default"]
Expand All @@ -48,8 +55,8 @@ tch = ["burn-tch"]
# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.

[dependencies]

Expand All @@ -66,7 +73,7 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }

derive-new = { workspace = true, default-features = false }
derive-new = { workspace = true }
libm = { workspace = true }
log = { workspace = true, optional = true }
rand = { workspace = true, features = ["std_rng"] } # Default enables std
Expand All @@ -88,7 +95,7 @@ serde_json = { workspace = true, features = ["alloc"] } #Default enables std
[dev-dependencies]
tempfile = { workspace = true }
burn-dataset = { path = "../burn-dataset", version = "0.10.0", features = [
"fake",
"fake",
] }

burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", default-features = false }
Expand Down
3 changes: 1 addition & 2 deletions burn-import/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ default = ["onnx"]
onnx = []

[dependencies]
burn = {path = "../burn", version = "0.10.0" }
burn-common = {path = "../burn-common", version = "0.10.0" }
burn = {path = "../burn", version = "0.10.0"}
burn-ndarray = {path = "../burn-ndarray", version = "0.10.0" }

bytemuck = {workspace = true}
Expand Down
35 changes: 35 additions & 0 deletions examples/image-classification-web/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
edition = "2021"
license = "MIT OR Apache-2.0"
name = "image-classification-web"
publish = false
version = "0.10.0"

[lib]
crate-type = ["cdylib"]

[features]
default = []

[dependencies]
burn = { path = "../../burn", default-features = false, features = [
"ndarray-no-std",
"wgpu",
] }

burn-candle = { path = "../../burn-candle", version = "0.10.0", default-features = false }

js-sys = "0.3.64"
log = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde-wasm-bindgen = "0.6.0"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2.0"
wasm-timer = "0.2.5"

[build-dependencies]
# Used to generate code from ONNX model
burn-import = { path = "../../burn-import" }
52 changes: 52 additions & 0 deletions examples/image-classification-web/NOTICES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided.

## Sample Images

Image Title: Domestic cat, a ten month old female.
Author: Von.grzanka
Source: https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/

Image Title: The George Washington Bridge over the Hudson River leading to New York City as seen from Fort Lee, New Jersey.
Author: John O'Connell
Source: https://commons.wikimedia.org/wiki/File:George_Washington_Bridge_from_New_Jersey-edit.jpg
License: https://creativecommons.org/licenses/by/2.0/deed.en

Image Title: Coyote from Yosemite National Park, California in snow
Author: Yathin S Krishnappa
Source https://commons.wikimedia.org/wiki/File:2009-Coyote-Yosemite.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/deed.en

Image Title: Table lamp with a lampshade illuminated by sunlight.
Author: LoMit
Source: https://commons.wikimedia.org/wiki/File:Lamp_with_a_lampshade_illuminated_by_sunlight.jpg
License: https://creativecommons.org/licenses/by-sa/4.0/deed.en

Image Title: White Pelican Pelecanus onocrotalus at Walvis Bay, Namibia
Author: Rui Ornelas
Source: https://commons.wikimedia.org/wiki/File:Pelikan_Walvis_Bay.jpg
License: https://creativecommons.org/licenses/by/2.0/deed.en

Image Title: Photo of a traditional torch to be posted at gates
Author: Faizul Latif Chowdhury
Source: https://commons.wikimedia.org/wiki/File:Torch_traditional.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/deed.en

Image Title: American Flamingo Phoenicopterus ruber at Gotomeer, Riscado, Bonaire
Author: Paul Asman and Jill Lenoble
Source: https://commons.wikimedia.org/wiki/File:Phoenicopterus_ruber_Bonaire_2.jpg
License: https://creativecommons.org/licenses/by/2.0/deed.en

## ONNX Model

SqueezeNet 1.1 model is licensed under Apache License 2.0. The model is downloaded from the [ONNX model zoo](https://github.com/onnx/models/tree/main).

Source: https://github.com/onnx/models/blob/main/vision/classification/squeezenet/model/squeezenet1.1-7.onnx
License: Apache License 2.0
License URL: https://github.com/onnx/models/blob/main/LICENSE

## ONNX Labels

The labels for the SqueezeNet 1.1 model are licensed under Apache License 2.0. The labels are downloaded from the [ONNX model zoo](https://github.com/onnx/models/blob/main/vision/classification/synset.txt)
Loading

0 comments on commit e2a17e4

Please sign in to comment.