Refactor/extract tch backend (tracel-ai#103)
nathanielsimard authored Nov 16, 2022
1 parent 23677b8 commit ab51c22
Showing 123 changed files with 1,805 additions and 1,637 deletions.
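In short, this commit extracts the `tch` backend from `burn-tensor` into a new `burn-tch` crate: it adds a dedicated CI workflow, manifest, and crate sources, and removes the `tch` feature, its type aliases, and the tch-specific code from `burn-tensor`. For code that used the backend, the visible effect is an import-path move, sketched here from the paths that appear in this diff:

```rust
// Before this commit, behind burn-tensor's `tch` feature:
// use burn_tensor::backend::{TchBackend, TchDevice};

// After this commit, the backend lives in its own crate:
use burn_tch::{TchBackend, TchDevice};

// The generic element parameter is unchanged.
type B = TchBackend<f32>;
```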
38 changes: 38 additions & 0 deletions .github/workflows/test-burn-tch.yml
@@ -0,0 +1,38 @@
name: test

on: [push]

jobs:
publish:
name: test burn tch
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2

- name: install rust nightly
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
components: rustfmt, clippy
override: true

- name: check format
run: |
cd burn-tch
cargo fmt --check --all
- name: check doc
run: |
cd burn-tch
cargo test --no-default-features --features doc --doc
- name: check tests
run: |
cd burn-tch
cargo test --tests
- name: check clippy
run: |
cargo clippy -p burn-tch -- -D warnings
15 changes: 3 additions & 12 deletions .github/workflows/test-burn-tensor.yml
@@ -26,22 +26,13 @@ jobs:
- name: check doc
run: |
cd burn-tensor
cargo test --no-default-features --features doc --doc
cargo test --doc
- name: check tests backend ndarray
run: |
cd burn-tensor
cargo test --no-default-features --features ndarray --tests
- name: check tests backend tch
run: |
cd burn-tensor
cargo test --no-default-features --features tch --tests
- name: check clippy backend tch
run: |
cargo clippy -p burn-tensor --no-default-features --features tch -- -D warnings
cargo test --no-default-features --features ndarray export_tests --tests
- name: check clippy backend ndarray
run: |
cargo clippy -p burn-tensor --no-default-features --features tch -- -D warnings
cargo clippy -p burn-tensor --no-default-features --features ndarray -- -D warnings
2 changes: 1 addition & 1 deletion .github/workflows/test-burn.yml
@@ -27,7 +27,7 @@ jobs:
- name: check doc
run: |
cd burn
cargo test --no-default-features --features doc --doc
cargo test --doc
- name: check tests
run: |
1 change: 1 addition & 0 deletions Cargo.toml
@@ -4,5 +4,6 @@ members = [
"burn-derive",
"burn-tensor",
"burn-dataset",
"burn-tch",
"examples/*",
]
27 changes: 27 additions & 0 deletions burn-tch/Cargo.toml
@@ -0,0 +1,27 @@
[package]
name = "burn-tch"
version = "0.2.3"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]

description = "Tch backend for burn"
repository = "https://github.com/burn-rs/burn/tree/main/burn-tch"
readme="README.md"
keywords = ["deep-learning", "machine-learning", "data"]
categories = ["science"]
license = "MIT/Apache-2.0"
edition = "2021"

[features]
doc = ["tch/doc-only"]

[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.2.3", default-features = false }
rand = "0.8"
num-traits = "0.2"
tch = { version = "0.8" }
serde = { version = "1.0", features = ["derive"] }
lazy_static = "1.4"
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch

[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.2.3", default-features = false, features = ["export_tests"] }
1 change: 1 addition & 0 deletions burn-tch/LICENSE-APACHE
1 change: 1 addition & 0 deletions burn-tch/LICENSE-MIT
3 changes: 3 additions & 0 deletions burn-tch/README.md
@@ -0,0 +1,3 @@
# Burn-tch

Tch backend for burn.
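As the rest of this diff shows, the crate's public surface is `TchBackend<E>` (for any `E: TchElement`) plus `TchDevice` for device selection. A hypothetical usage snippet consistent with the doc example below (the `Cpu` variant is assumed; only `Cuda` appears in this diff):

```rust
use burn_tch::{TchBackend, TchDevice};

// f32 is one of the element types implementing TchElement.
type Backend = TchBackend<f32>;

fn main() {
    let _cpu = TchDevice::Cpu;     // assumed variant, not shown in this diff
    let _gpu = TchDevice::Cuda(0); // first GPU, as in the doc example below
}
```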
@@ -1,7 +1,7 @@
use super::element::TchElement;
use super::TchTensor;
use crate::tensor::backend::Backend;
use crate::tensor::{Data, Distribution, Shape};
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Distribution, Shape};

#[derive(Clone, Copy, Debug)]
/// The device struct when using the `tch` backend.
@@ -11,7 +11,7 @@ use crate::tensor::{Data, Distribution, Shape};
/// # Example
///
/// ```rust
/// use burn_tensor::backend::TchDevice;
/// use burn_tch::TchDevice;
///
/// let device_gpu_1 = TchDevice::Cuda(0); // First GPU
/// let device_gpu_2 = TchDevice::Cuda(1); // Second GPU
14 changes: 14 additions & 0 deletions burn-tch/src/element.rs
@@ -0,0 +1,14 @@
use burn_tensor::Element;
use half::f16;

pub trait TchElement: Element + tch::kind::Element {}

impl TchElement for f64 {}
impl TchElement for f32 {}
impl TchElement for f16 {}

impl TchElement for i64 {}
impl TchElement for i32 {}
impl TchElement for i16 {}

impl TchElement for u8 {}
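`TchElement` is a marker trait that couples burn's `Element` with `tch::kind::Element`, so the backend can write one generic implementation per op and have it cover every dtype above. A small illustration of why the `tch::kind::Element` bound matters (the `zeros_of` helper is hypothetical; in tch 0.8 the trait exposes the libtorch dtype as an associated `KIND` constant):

```rust
use tch::kind::Element;
use tch::{Device, Tensor};

// Hypothetical helper: allocate a zeroed libtorch tensor whose dtype is
// derived from the generic element type, the way backend ops can once
// `E: TchElement` is in scope.
fn zeros_of<E: Element>(shape: &[i64]) -> Tensor {
    Tensor::zeros(shape, (E::KIND, Device::Cpu))
}

fn main() {
    let t = zeros_of::<f32>(&[2, 3]); // Kind::Float tensor of shape [2, 3]
    println!("{:?}", t.size());       // -> [2, 3]
}
```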
@@ -8,3 +8,10 @@ mod tensor_ops;
pub use backend::*;
pub use tensor::*;
pub use tensor_ops::*;

#[cfg(test)]
mod tests {
type TestBackend = crate::TchBackend<f32>;

burn_tensor::test_all!();
}
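The `burn_tensor::test_all!()` call expands the shared tensor test suite, which `burn-tensor` now compiles in under the new `export_tests` feature, against the locally defined `TestBackend` alias; every backend crate can thereby run the identical tests. A minimal sketch of the general pattern (macro name and body are illustrative, not `burn_tensor`'s actual definition):

```rust
// In the shared crate, gated behind the export_tests feature:
#[macro_export]
macro_rules! test_all {
    () => {
        #[test]
        fn shared_tensor_test() {
            // The call site must define `TestBackend`; each backend crate
            // then runs the same assertions against its own backend type.
            let _backend_name = core::any::type_name::<TestBackend>();
        }
    };
}
```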
@@ -1,5 +1,5 @@
use super::{element::TchElement, TchBackend, TchTensor};
use crate::{ops::ModuleOps, Shape};
use burn_tensor::{ops::ModuleOps, Shape};

impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
fn embedding(weights: &TchTensor<E, 2>, indexes: &TchTensor<i64, 2>) -> TchTensor<E, 3> {
@@ -1,4 +1,5 @@
use crate::tensor::{backend::tch::TchTensor, ops::*};
use crate::TchTensor;
use burn_tensor::ops::*;

impl<P, const D: usize> Zeros for TchTensor<P, D>
where
File renamed without changes.
@@ -1,9 +1,5 @@
use super::element::TchElement;
use crate::{
backend::{TchBackend, TchDevice},
ops::TensorOps,
tensor::{Data, Shape},
};
use crate::{element::TchElement, TchBackend, TchDevice};
use burn_tensor::{ops::TensorOps, Data, Shape};

lazy_static::lazy_static! {
static ref NO_GRAD: tch::NoGradGuard = {
@@ -53,16 +49,6 @@ impl<const D: usize> From<Shape<D>> for TchShape<D> {
}
}

impl<const D: usize> From<Vec<i64>> for Shape<D> {
fn from(shape: Vec<i64>) -> Self {
let mut dims = [1; D];
for (i, dim) in shape.into_iter().enumerate() {
dims[i] = dim as usize;
}
Self::new(dims)
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct TchKind<P: tch::kind::Element> {
_p: P,
@@ -99,12 +85,12 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
#[cfg(test)]
mod utils {
use super::*;
use crate::{backend::TchBackend, ops::TensorOps};
use crate::{backend::TchBackend, element::TchElement};

impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
impl<P: TchElement, const D: usize> TchTensor<P, D> {
pub(crate) fn into_data(self) -> Data<P, D>
where
P: TchElement,
P: tch::kind::Element,
{
<TchBackend<P> as TensorOps<TchBackend<P>>>::into_data(self)
}
@@ -131,7 +117,7 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Distribution;
use burn_tensor::Distribution;
use rand::prelude::StdRng;
use rand::SeedableRng;

@@ -1,5 +1,5 @@
use super::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
use crate::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use std::ops::{Add, Div, Mul, Range, Sub};

impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
11 changes: 2 additions & 9 deletions burn-tensor/Cargo.toml
@@ -20,23 +20,16 @@ all-features = false
no-default-features = true

[features]
default = ["tch", "ndarray"]
tch = ["dep:tch"]
default = ["ndarray"]
ndarray = ["dep:ndarray", "dep:libm"]
doc = ["dep:tch", "tch/doc-only", "dep:ndarray"]
export_tests = []

[dependencies]
num-traits = "0.2"
derive-new = "0.5"
rand = "0.8"
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch

# Backends
tch = { version = "0.8", optional = true }
lazy_static = "1.4"



# NdArray
ndarray = { version = "0.15", optional = true }
libm = { version = "0.2", optional = true }
2 changes: 1 addition & 1 deletion burn-tensor/examples/autodiff_simple.rs
@@ -43,7 +43,7 @@ fn main() {
#[cfg(feature = "ndarray")]
{
run::<backend::NdArrayBackend<f32>>(x.clone(), y.clone());
run_ad::<backend::NdArrayADBackend<f32>>(x.clone(), y.clone());
run_ad::<backend::NdArrayADBackend<f32>>(x, y);
}

#[cfg(feature = "tch")]
3 changes: 3 additions & 0 deletions burn-tensor/src/lib.rs
@@ -6,5 +6,8 @@ pub use graph::grad::Gradients;

mod tensor;

#[cfg(feature = "export_tests")]
mod tests;

pub use half::f16;
pub use tensor::*;
3 changes: 0 additions & 3 deletions burn-tensor/src/tensor/backend/autodiff/backend.rs
@@ -90,6 +90,3 @@ impl<B: Backend> ADBackend for ADBackendDecorator<B> {
#[cfg(feature = "ndarray")]
pub type ADBackendNdArray<E> =
ADBackendDecorator<crate::tensor::backend::ndarray::NdArrayBackend<E>>;

#[cfg(feature = "tch")]
pub type ADBackendTch<E> = ADBackendDecorator<crate::tensor::backend::tch::TchBackend<E>>;
10 changes: 0 additions & 10 deletions burn-tensor/src/tensor/backend/autodiff/tensor/base.rs
@@ -41,14 +41,4 @@ pub mod helper {
pub type TestADTensor<E, const D: usize> = Tensor<ADBackendNdArray<E>, D>;
}
pub use helper_impl::*;

#[cfg(feature = "tch")]
#[cfg(not(feature = "ndarray"))]
mod helper_impl {
use crate::tensor::backend::autodiff::ADBackendTch;
use crate::tensor::Tensor;

pub type TestADTensor<E, const D: usize> = Tensor<ADBackendTch<E>, D>;
}
pub use helper_impl::*;
}
12 changes: 2 additions & 10 deletions burn-tensor/src/tensor/backend/mod.rs
@@ -7,16 +7,6 @@ pub use base::*;

pub(crate) mod autodiff;

#[cfg(feature = "tch")]
pub(crate) mod tch;
#[cfg(feature = "tch")]
pub type TchADBackend<E> = self::autodiff::ADBackendTch<E>;
#[cfg(feature = "tch")]
pub type TchBackend<E> = self::tch::TchBackend<E>;

#[cfg(feature = "tch")]
pub type TchDevice = self::tch::TchDevice;

#[cfg(feature = "ndarray")]
pub(crate) mod ndarray;
#[cfg(feature = "ndarray")]
@@ -25,3 +15,5 @@ pub type NdArrayADBackend<E> = self::autodiff::ADBackendNdArray<E>;
pub type NdArrayBackend<E> = self::ndarray::NdArrayBackend<E>;
#[cfg(feature = "ndarray")]
pub type NdArrayDevice = self::ndarray::NdArrayDevice;

pub use autodiff::ADBackendDecorator;
32 changes: 0 additions & 32 deletions burn-tensor/src/tensor/backend/tch/element.rs

This file was deleted.

12 changes: 12 additions & 0 deletions burn-tensor/src/tensor/element.rs
@@ -1,4 +1,5 @@
use crate::{tensor::ops::*, Distribution};
use half::f16;
use num_traits::ToPrimitive;
use rand::prelude::StdRng;

@@ -182,3 +183,14 @@ make_element!(
convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
random |distribution: Distribution<u8>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
make_element!(
ty f16 Precision::Half,
zero <f16 as num_traits::Zero>::zero(),
one <f16 as num_traits::One>::one(),
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
random |distribution: Distribution<f16>, rng: &mut StdRng| {
let distribution: Distribution<f32> = distribution.convert();
let sample = distribution.sampler(rng).sample();
f16::from_elem(sample)
}
);
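Note how the half-precision entry samples through `f32` and narrows the result, since there is no dedicated `f16` sampler here; that narrowing conversion is what the `half` crate provides. A standalone round-trip showing the expected precision loss:

```rust
use half::f16;

fn main() {
    let x: f32 = 0.333_333_34;
    let h = f16::from_f32(x); // narrow to 16 bits
    let back = h.to_f32();    // widen again
    // `back` only approximates `x`: f16 keeps ~11 bits of mantissa.
    println!("{x} -> {h} -> {back}");
}
```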
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/mod.rs
@@ -1,4 +1,4 @@
pub(crate) mod ops;
pub mod ops;
pub(crate) mod stats;

mod base;