Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Aug 8, 2022
1 parent 5d6c866 commit 127b7cb
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 33 deletions.
3 changes: 2 additions & 1 deletion burn-tensor/src/tensor/api/ad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ mod tch {
use super::*;
use crate::tensor::backend::autodiff::ADBackendTch;
use crate::tensor::backend::tch::TchBackend;
use crate::TchElement;

impl<E: Element, const D: usize> Tensor<D, TchBackend<E>>
impl<E: TchElement, const D: usize> Tensor<D, TchBackend<E>>
where
Standard: rand::distributions::Distribution<E>,
{
Expand Down
12 changes: 8 additions & 4 deletions burn-tensor/src/tensor/backend/autodiff/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ use rand::distributions::Standard;
macro_rules! define_impl {
(
name: $name:ident,
backend: $backend:ty
backend: $backend:ty,
element: $element:ty
) => {
#[derive(Clone, Copy, Debug, Default)]
pub struct $name<E> {
_b: $backend,
}

impl<E: Element> Backend for $name<E>
impl<E> Backend for $name<E>
where
E: $element,
Standard: rand::distributions::Distribution<E>,
{
type Device = <$backend as Backend>::Device;
Expand Down Expand Up @@ -107,10 +109,12 @@ macro_rules! define_impl {
#[cfg(feature = "ndarray")]
define_impl!(
name: ADBackendNdArray,
backend: crate::tensor::backend::ndarray::NdArrayBackend<E>
backend: crate::tensor::backend::ndarray::NdArrayBackend<E>,
element: crate::NdArrayElement
);
#[cfg(feature = "tch")]
define_impl!(
name: ADBackendTch,
backend: crate::tensor::backend::tch::TchBackend<E>
backend: crate::tensor::backend::tch::TchBackend<E>,
element: crate::TchElement
);
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/backend/tch/backend.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::TchTensor;
use crate::tensor::{backend::Backend, Element};
use crate::tensor::{backend::Backend, TchElement};
use crate::tensor::{Data, Distribution, Shape};
use rand::distributions::Standard;

Expand All @@ -20,7 +20,7 @@ pub struct TchBackend<E> {
_e: E,
}

impl<E: Element> Backend for TchBackend<E>
impl<E: TchElement> Backend for TchBackend<E>
where
Standard: rand::distributions::Distribution<E>,
{
Expand Down
5 changes: 2 additions & 3 deletions burn-tensor/src/tensor/backend/tch/ops/device.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use crate::{
backend::tch::{TchBackend, TchDevice, TchTensor},
ops::TensorOpsDevice,
Element,
TchElement,
};
use rand::distributions::Standard;

impl<E: Element, const D: usize> TensorOpsDevice<TchBackend<E>, D> for TchTensor<E, D>
impl<E: TchElement, const D: usize> TensorOpsDevice<TchBackend<E>, D> for TchTensor<E, D>
where
E: Element,
Standard: rand::distributions::Distribution<E>,
{
fn device(&self) -> <TchBackend<E> as crate::back::Backend>::Device {
Expand Down
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/backend/tch/ops/map_comparison.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::backend::tch::TchBackend;
use crate::tensor::Element;
use crate::tensor::TchElement;
use crate::tensor::{
backend::tch::{TchKind, TchTensor},
ops::*,
Expand All @@ -8,7 +8,7 @@ use rand::distributions::Standard;

impl<E, const D: usize> TensorOpsMapComparison<TchBackend<E>, D> for TchTensor<E, D>
where
E: Element,
E: TchElement,
Standard: rand::distributions::Distribution<E>,
{
fn greater(
Expand Down
6 changes: 3 additions & 3 deletions burn-tensor/src/tensor/backend/tch/ops/mask.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{
back::Backend, backend::tch::TchBackend, backend::tch::TchTensor, ops::TensorOpsMask, Element,
back::Backend, backend::tch::TchBackend, backend::tch::TchTensor, ops::TensorOpsMask,
TchElement,
};
use rand::distributions::Standard;

impl<E: Element, const D: usize> TensorOpsMask<TchBackend<E>, D> for TchTensor<E, D>
impl<E: TchElement, const D: usize> TensorOpsMask<TchBackend<E>, D> for TchTensor<E, D>
where
E: Element,
Standard: rand::distributions::Distribution<E>,
{
fn mask_fill(
Expand Down
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/backend/tch/ops/reshape.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::tensor::{
backend::tch::{TchBackend, TchShape, TchTensor},
ops::*,
Element, Shape,
Shape, TchElement,
};
use rand::distributions::Standard;

impl<P: Element, const D: usize> TensorOpsReshape<TchBackend<P>, D> for TchTensor<P, D>
impl<P: TchElement, const D: usize> TensorOpsReshape<TchBackend<P>, D> for TchTensor<P, D>
where
Standard: rand::distributions::Distribution<P>,
{
Expand Down
22 changes: 6 additions & 16 deletions burn-tensor/src/tensor/tensor_trait.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::tensor::ops::*;
use rand::distributions::uniform::SampleUniform;

pub trait BasicElement:
pub trait Element:
Zeros<Self>
+ Ones<Self>
+ std::fmt::Debug
Expand All @@ -14,21 +14,12 @@ pub trait BasicElement:
+ std::cmp::PartialOrd<Self>
{
}
#[cfg(all(feature = "tch", feature = "ndarray"))]
pub trait Element:
Sized
+ BasicElement
+ ndarray::LinalgScalar
+ ndarray::ScalarOperand
+ tch::kind::Element
+ Into<f64>
{
}
#[cfg(all(feature = "tch", not(feature = "ndarray")))]
pub trait Element: BasicElement + tch::kind::Element + Into<f64> {}

#[cfg(all(feature = "ndarray", not(feature = "tch")))]
pub trait Element: BasicElement + ndarray::LinalgScalar + ndarray::ScalarOperand {}
#[cfg(feature = "tch")]
pub trait TchElement: Element + tch::kind::Element + Into<f64> {}

#[cfg(feature = "ndarray")]
pub trait NdArrayElement: Element + tch::kind::Element + Into<f64> {}

pub trait TensorTrait<P: Element, const D: usize>:
TensorOpsUtilities<P, D>
Expand All @@ -53,7 +44,6 @@ macro_rules! ad_items {
zero $zero:expr,
one $one:expr
) => {
impl BasicElement for $float {}
impl Element for $float {}

impl Zeros<$float> for $float {
Expand Down

0 comments on commit 127b7cb

Please sign in to comment.