Skip to content

Commit

Permalink
feat: add comparison (tracel-ai#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Aug 22, 2022
1 parent 5560bae commit f8ab29b
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 3 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/test-burn-dataset.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: test

on: [push]

jobs:
publish:
name: test burn dataset
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2

- name: install rust nightly
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
components: rustfmt
override: true

- name: check format
run: |
cd burn-dataset
cargo fmt --check --all
- name: check tests
run: |
cd burn-dataset
cargo test
27 changes: 27 additions & 0 deletions .github/workflows/test-burn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: test

on: [push]

jobs:
publish:
name: test burn
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2

- name: install rust nightly
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
components: rustfmt
override: true

- name: check format
run: |
cargo fmt --check --all
- name: check tests
run: |
cargo test
8 changes: 8 additions & 0 deletions burn-tensor/src/tensor/api/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ where
Self::new(self.value.sum_dim(dim))
}

pub fn equal(&self, other: &Self) -> BoolTensor<B, D> {
BoolTensor::new(self.value.equal(&other.value))
}

pub fn equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
BoolTensor::new(self.value.equal_scalar(other))
}

pub fn greater(&self, other: &Self) -> BoolTensor<B, D> {
BoolTensor::new(self.value.greater(&other.value))
}
Expand Down
11 changes: 11 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/map_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ macro_rules! define_impl {
impl<E: $element, const D: usize> TensorOpsMapComparison<$backend, D>
for <$backend as Backend>::TensorPrimitive<D>
{
fn equal(&self, other: &Self) -> <$backend as Backend>::BoolTensorPrimitive<D> {
TensorOpsMapComparison::equal(&self.tensor(), &other.tensor())
}

fn equal_scalar(
&self,
other: &<$backend as Backend>::Elem,
) -> <$backend as Backend>::BoolTensorPrimitive<D> {
TensorOpsMapComparison::equal_scalar(&self.tensor(), other)
}

fn greater(&self, other: &Self) -> <$backend as Backend>::BoolTensorPrimitive<D> {
TensorOpsMapComparison::greater(&self.tensor(), &other.tensor())
}
Expand Down
21 changes: 21 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/map_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@ impl<E, const D: usize> TensorOpsMapComparison<NdArrayBackend<E>, D> for NdArray
where
E: NdArrayElement,
{
fn equal(
&self,
other: &Self,
) -> <NdArrayBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let zero = E::zeros(&E::default());
tensor.equal_scalar(&zero)
}

fn equal_scalar(
&self,
other: &<NdArrayBackend<E> as crate::back::Backend>::Elem,
) -> <NdArrayBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
let array = self.array.mapv(|a| a == *other).into_shared();

NdArrayTensor {
shape: self.shape,
array,
}
}

fn greater(
&self,
other: &Self,
Expand Down
27 changes: 27 additions & 0 deletions burn-tensor/src/tensor/backend/tch/ops/map_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,33 @@ impl<E, const D: usize> TensorOpsMapComparison<TchBackend<E>, D> for TchTensor<E
where
E: TchElement,
{
fn equal(
&self,
other: &Self,
) -> <TchBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
let tensor = self.tensor.eq_tensor(&other.tensor);

TchTensor {
shape: self.shape,
tensor,
kind: TchKind::<bool>::new(),
}
}

fn equal_scalar(
&self,
other: &<TchBackend<E> as crate::back::Backend>::Elem,
) -> <TchBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
let other: f64 = (*other).into();
let tensor = self.tensor.eq(other);

TchTensor {
shape: self.shape,
tensor,
kind: TchKind::<bool>::new(),
}
}

fn greater(
&self,
other: &Self,
Expand Down
2 changes: 2 additions & 0 deletions burn-tensor/src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub trait TensorOpsIndex<E, const D1: usize> {
}

pub trait TensorOpsMapComparison<B: Backend, const D: usize> {
fn equal(&self, other: &Self) -> B::BoolTensorPrimitive<D>;
fn equal_scalar(&self, other: &B::Elem) -> B::BoolTensorPrimitive<D>;
fn greater(&self, other: &Self) -> B::BoolTensorPrimitive<D>;
fn greater_scalar(&self, other: &B::Elem) -> B::BoolTensorPrimitive<D>;
fn greater_equal(&self, other: &Self) -> B::BoolTensorPrimitive<D>;
Expand Down
33 changes: 30 additions & 3 deletions burn-tensor/src/tensor/tensor_trait.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use crate::{tensor::ops::*, Distribution};
use half::f16;
use num_traits::{One, ToPrimitive, Zero};
use num_traits::ToPrimitive;
use rand::prelude::StdRng;

pub trait Element:
Zeros<Self>
+ ToPrimitive
+ ElementRandom<Self>
+ ElementConversion
+ ElementValue
+ Ones<Self>
+ std::ops::Mul<Self, Output = Self>
+ std::fmt::Debug
Expand Down Expand Up @@ -37,6 +38,14 @@ pub trait ElementRandom<T> {
fn random(distribution: Distribution<T>, rng: &mut StdRng) -> T;
}

pub trait ElementValue {
fn inf() -> Self;
fn inf_neg() -> Self;
fn nan() -> Self;
fn zero() -> Self;
fn one() -> Self;
}

#[cfg(feature = "ndarray")]
pub trait NdArrayElement:
Element + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement + num_traits::FromPrimitive
Expand Down Expand Up @@ -87,6 +96,24 @@ macro_rules! ad_items {
}
}

impl ElementValue for $float {
fn inf() -> Self {
Self::from_elem(f64::INFINITY)
}
fn inf_neg() -> Self {
Self::from_elem(std::ops::Neg::neg(f64::INFINITY))
}
fn nan() -> Self {
Self::from_elem(f64::NAN)
}
fn zero() -> Self {
$zero
}
fn one() -> Self {
$one
}
}

impl ElementRandom<$float> for $float {
fn random(distribution: Distribution<$float>, rng: &mut StdRng) -> $float {
$random(distribution, rng)
Expand Down Expand Up @@ -135,8 +162,8 @@ ad_items!(

ad_items!(
ty f16,
zero f16::zero(),
one f16::one(),
zero <f16 as num_traits::Zero>::zero(),
one <f16 as num_traits::One>::one(),
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
random |distribution: Distribution<f16>, rng: &mut StdRng| {
let distribution: Distribution<f32> = distribution.convert();
Expand Down

0 comments on commit f8ab29b

Please sign in to comment.