forked from tracel-ai/burn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor/extract tch backend (tracel-ai#103)
- Loading branch information
1 parent
23677b8
commit ab51c22
Showing
123 changed files
with
1,805 additions
and
1,637 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
name: test | ||
|
||
on: [push] | ||
|
||
jobs: | ||
publish: | ||
name: test burn tch | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: checkout | ||
uses: actions/checkout@v2 | ||
|
||
- name: install rust nightly | ||
uses: actions-rs/toolchain@v1 | ||
with: | ||
profile: minimal | ||
toolchain: stable | ||
components: rustfmt, clippy | ||
override: true | ||
|
||
- name: check format | ||
run: | | ||
cd burn-tch | ||
cargo fmt --check --all | ||
- name: check doc | ||
run: | | ||
cd burn-tch | ||
cargo test --no-default-features --features doc --doc | ||
- name: check tests | ||
run: | | ||
cd burn-tch | ||
cargo test --tests | ||
- name: check clippy | ||
run: | | ||
cargo clippy -p burn-tch -- -D warnings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,5 +4,6 @@ members = [ | |
"burn-derive", | ||
"burn-tensor", | ||
"burn-dataset", | ||
"burn-tch", | ||
"examples/*", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
[package] | ||
name = "burn-tch" | ||
version = "0.2.3" | ||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"] | ||
|
||
description = "Tch backend for burn" | ||
repository = "https://github.com/burn-rs/burn/tree/main/burn-tch" | ||
readme="README.md" | ||
keywords = ["deep-learning", "machine-learning", "data"] | ||
categories = ["science"] | ||
license = "MIT/Apache-2.0" | ||
edition = "2021" | ||
|
||
[features] | ||
doc = ["tch/doc-only"] | ||
|
||
[dependencies] | ||
burn-tensor = { path = "../burn-tensor", version = "0.2.3", default-features = false } | ||
rand = "0.8" | ||
num-traits = "0.2" | ||
tch = { version = "0.8" } | ||
serde = { version = "1.0", features = ["derive"] } | ||
lazy_static = "1.4" | ||
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch | ||
|
||
[dev-dependencies] | ||
burn-tensor = { path = "../burn-tensor", version = "0.2.3", default-features = false, features = ["export_tests"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../LICENSE-APACHE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../LICENSE-MIT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Burn-tch | ||
|
||
Tch backend for burn. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use burn_tensor::Element; | ||
use half::f16; | ||
|
||
pub trait TchElement: Element + tch::kind::Element {} | ||
|
||
impl TchElement for f64 {} | ||
impl TchElement for f32 {} | ||
impl TchElement for f16 {} | ||
|
||
impl TchElement for i64 {} | ||
impl TchElement for i32 {} | ||
impl TchElement for i16 {} | ||
|
||
impl TchElement for u8 {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...nsor/src/tensor/backend/tch/module_ops.rs → burn-tch/src/module_ops.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 2 additions & 1 deletion
3
...or/src/tensor/backend/tch/ops/creation.rs → burn-tch/src/ops/creation.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...nsor/src/tensor/backend/tch/tensor_ops.rs → burn-tch/src/tensor_ops.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
pub(crate) mod ops; | ||
pub mod ops; | ||
pub(crate) mod stats; | ||
|
||
mod base; | ||
|
Oops, something went wrong.