Feat/argmax (tracel-ai#22)
nathanielsimard authored Aug 23, 2022
1 parent 0aa486f commit d62f2b0
Showing 21 changed files with 283 additions and 48 deletions.
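
For orientation, a minimal usage sketch of the API this commit adds (not part of the diff; the tensor `t` and its construction are assumed, while `argmax`, `IndexTensor`, and `into_data` come from the changes below):

    // Hypothetical: `t` is some Tensor<B, 2> built elsewhere.
    let indices = t.argmax(1);      // IndexTensor<B, 2>: i64 positions along dim 1
    let data = indices.into_data(); // Data<i64, 2>, per the new IndexTensor API
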
37 changes: 37 additions & 0 deletions burn-tensor/src/tensor/api/tensor.rs
@@ -12,6 +12,35 @@ pub struct BoolTensor<B: Backend, const D: usize> {
pub(crate) value: B::BoolTensorPrimitive<D>,
}

pub struct IndexTensor<B: Backend, const D: usize> {
pub(crate) value: B::IndexTensorPrimitive<D>,
}

impl<B, const D: usize> IndexTensor<B, D>
where
B: Backend,
{
pub fn new(tensor: B::IndexTensorPrimitive<D>) -> Self {
Self { value: tensor }
}

pub fn shape(&self) -> &Shape<D> {
self.value.shape()
}

pub fn into_data(self) -> Data<i64, D> {
self.value.into_data()
}

pub fn to_data(&self) -> Data<i64, D> {
self.value.to_data()
}

pub fn mul(&self, other: &Self) -> Self {
Self::new(self.value.mul(&other.value))
}
}

impl<B, const D: usize> BoolTensor<B, D>
where
B: Backend,
@@ -244,6 +273,14 @@ where
Tensor::new(value)
}

pub fn argmax(&self, dim: usize) -> IndexTensor<B, D> {
IndexTensor::new(self.value.argmax(dim))
}

pub fn argmin(&self, dim: usize) -> IndexTensor<B, D> {
IndexTensor::new(self.value.argmin(dim))
}

pub fn cat(tensors: Vec<&Self>, dim: usize) -> Self {
let tensors: Vec<B::TensorPrimitive<D>> =
tensors.into_iter().map(|a| a.value.clone()).collect();
2 changes: 2 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/backend.rs
@@ -28,6 +28,8 @@ macro_rules! define_impl {
type TensorPrimitive<const D: usize> = ADTensor<D, $backend>;
type BoolTensorPrimitive<const D: usize> =
<$backend as Backend>::BoolTensorPrimitive<D>;
type IndexTensorPrimitive<const D: usize> =
<$backend as Backend>::IndexTensorPrimitive<D>;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
25 changes: 25 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/arg.rs
@@ -0,0 +1,25 @@
use crate::backend::backend::Backend;
use crate::tensor::ops::*;

macro_rules! define_impl {
(
$backend:ty,
$backend_inner:ty,
$element:ident
) => {
impl<E: $element, const D: usize> TensorOpsArg<$backend, D>
for <$backend as Backend>::TensorPrimitive<D>
{
fn argmax(&self, dim: usize) -> <$backend as Backend>::IndexTensorPrimitive<D> {
TensorOpsArg::argmax(&self.tensor(), dim)
}

fn argmin(&self, dim: usize) -> <$backend as Backend>::IndexTensorPrimitive<D> {
TensorOpsArg::argmin(&self.tensor(), dim)
}
}
};
}

crate::register_tch!();
crate::register_ndarray!();
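
For context (not stated in the diff): `argmax` and `argmin` return integer index tensors and are not differentiable, so the autodiff implementation above simply unwraps the inner tensor with `self.tensor()` and delegates to the wrapped backend; no node is recorded on the autodiff graph. The `register_tch!` and `register_ndarray!` macros presumably instantiate `define_impl` for the tch and ndarray backends.
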
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -1,5 +1,6 @@
mod add;
mod aggregation;
mod arg;
mod cat;
mod creation;
mod device;
10 changes: 9 additions & 1 deletion burn-tensor/src/tensor/backend/backend.rs
@@ -1,7 +1,7 @@
use crate::activation::ReLU;
use crate::graph::grad::Gradients;
use crate::ops::{
- TensorOpsAggregation, TensorOpsCat, TensorOpsDevice, TensorOpsExp, TensorOpsLog,
+ TensorOpsAggregation, TensorOpsArg, TensorOpsCat, TensorOpsDevice, TensorOpsExp, TensorOpsLog,
TensorOpsMapComparison, TensorOpsMask, TensorOpsPrecision, TensorOpsUtilities,
};
use crate::tensor::ops::{TensorOpsIndex, TensorOpsReshape};
@@ -20,6 +20,7 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
+ TensorOpsIndex<Self::Elem, D>
+ TensorOpsAggregation<Self, D>
+ TensorOpsExp<Self::Elem, D>
+ TensorOpsArg<Self, D>
+ TensorOpsCat<Self::Elem, D>
+ TensorOpsLog<Self::Elem, D>
+ TensorOpsMask<Self, D>
@@ -28,12 +29,19 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
+ Send
+ Sync
+ 'static;

type BoolTensorPrimitive<const D: usize>: TensorOpsUtilities<bool, D>
+ Clone
+ Send
+ Sync
+ 'static
+ std::fmt::Debug;
type IndexTensorPrimitive<const D: usize>: TensorTrait<i64, D>
+ Clone
+ Send
+ Sync
+ 'static
+ std::fmt::Debug;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/ndarray/backend.rs
@@ -25,6 +25,7 @@ impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
type FullPrecisionBackend = NdArrayBackend<f32>;
type TensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
type IndexTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
73 changes: 73 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/arg.rs
@@ -0,0 +1,73 @@
use crate::backend::ndarray::NdArrayBackend;
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*};
use crate::ElementValue;
use crate::NdArrayElement;
use std::cmp::Ordering;

impl<E, const D: usize> TensorOpsArg<NdArrayBackend<E>, D> for NdArrayTensor<E, D>
where
E: NdArrayElement,
{
fn argmax(&self, dim: usize) -> NdArrayTensor<i64, D> {
arg(self, dim, cmp_max)
}

fn argmin(&self, dim: usize) -> NdArrayTensor<i64, D> {
arg(self, dim, cmp_min)
}
}

fn arg<E: NdArrayElement, F, const D: usize>(
tensor: &NdArrayTensor<E, D>,
dim: usize,
cmp: F,
) -> NdArrayTensor<i64, D>
where
F: Fn(&f64, &f64) -> Ordering,
{
let mut data = tensor.to_data();
let mut start = 1;

for i in 0..dim {
start = start * tensor.shape.dims[i];
}
let end = start + tensor.shape.dims[dim];

let data_dim = &mut data.value[start..end];
let mut sorted: Vec<f64> = data_dim.iter().map(|a| a.to_elem()).collect();
sorted.sort_by(cmp);

let max = sorted[0];
for elem in data_dim {
*elem = <E as ElementValue>::zero();
}

let data_dim = &mut data.value[start..end];
for elem in data_dim {
let as_float: f64 = elem.to_elem();
if as_float == max {
*elem = <E as ElementValue>::one();
break;
}
}

NdArrayTensor::from_data(data.convert())
}

fn cmp_max(a: &f64, b: &f64) -> Ordering {
if a < b {
return Ordering::Less;
} else if a > b {
return Ordering::Greater;
}
return Ordering::Equal;
}

fn cmp_min(a: &f64, b: &f64) -> Ordering {
if a > b {
return Ordering::Less;
} else if a < b {
return Ordering::Greater;
}
return Ordering::Equal;
}
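
A self-contained sketch (not part of the commit) of how `sort_by` behaves with comparators shaped like the ones above, using only the standard library:

    use std::cmp::Ordering;

    fn main() {
        let mut values = vec![3.0_f64, 1.0, 2.0];
        // An `a < b => Less` comparator (the shape of `cmp_max` above) sorts
        // ascending, so the first element after sorting is the smallest value.
        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
        // Swapping the operands (the shape of `cmp_min` above) sorts descending.
        values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal));
        assert_eq!(values, vec![3.0, 2.0, 1.0]);
    }
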
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -1,5 +1,6 @@
mod add;
mod aggregation;
mod arg;
mod cat;
mod creation;
mod device;
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/tch/backend.rs
@@ -26,6 +26,7 @@ impl<E: TchElement> Backend for TchBackend<E> {
type FullPrecisionBackend = TchBackend<f32>;
type TensorPrimitive<const D: usize> = TchTensor<E, D>;
type BoolTensorPrimitive<const D: usize> = TchTensor<bool, D>;
type IndexTensorPrimitive<const D: usize> = TchTensor<i64, D>;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
33 changes: 33 additions & 0 deletions burn-tensor/src/tensor/backend/tch/ops/arg.rs
@@ -0,0 +1,33 @@
use crate::backend::tch::TchBackend;
use crate::tensor::TchElement;
use crate::tensor::{
backend::tch::{TchKind, TchTensor},
ops::*,
};

impl<E, const D: usize> TensorOpsArg<TchBackend<E>, D> for TchTensor<E, D>
where
E: TchElement,
{
fn argmax(&self, dim: usize) -> TchTensor<i64, D> {
let tensor = self.tensor.argmax(dim as i64, true);
let shape = self.shape.clone();

TchTensor {
tensor,
kind: TchKind::<i64>::new(),
shape,
}
}

fn argmin(&self, dim: usize) -> TchTensor<i64, D> {
let tensor = self.tensor.argmin(dim as i64, true);
let shape = self.shape.clone();

TchTensor {
tensor,
kind: TchKind::<i64>::new(),
shape,
}
}
}
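
For reference, a standalone sketch (not part of the commit) of the underlying `tch` call, assuming the `tch` crate's `Tensor::of_slice`, `reshape`, and `argmax(dim, keepdim)` API as it existed at the time. `keepdim = true` keeps the reduced dimension with size 1, which is how the result can stay at rank `D`:

    use tch::Tensor;

    fn main() {
        // A 2x3 Int64 tensor; rows are [1, 5, 3] and [2, 4, 6].
        let t = Tensor::of_slice(&[1i64, 5, 3, 2, 4, 6]).reshape(&[2, 3]);
        let idx = t.argmax(1i64, true); // keepdim = true -> shape [2, 1]
        assert_eq!(idx.size(), vec![2, 1]); // argmax positions: [[1], [2]]
    }
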
13 changes: 8 additions & 5 deletions burn-tensor/src/tensor/backend/tch/ops/div.rs
@@ -1,7 +1,10 @@
-use crate::tensor::{backend::tch::TchTensor, ops::*, Shape};
+use crate::{
+    tensor::{backend::tch::TchTensor, ops::*, Shape},
+    Element,
+};
use std::ops::Div;

-impl<P: tch::kind::Element + Into<f64>, const D: usize> TensorOpsDiv<P, D> for TchTensor<P, D> {
+impl<P: Element + tch::kind::Element, const D: usize> TensorOpsDiv<P, D> for TchTensor<P, D> {
fn div(&self, other: &Self) -> Self {
let tensor = (&self.tensor) / &other.tensor;
let shape = self.shape.higher(&other.shape);
@@ -14,7 +17,7 @@ impl<P: tch::kind::Element + Into<f64>, const D: usize> TensorOpsDiv<P, D> for T
}
}
fn div_scalar(&self, other: &P) -> Self {
-let other: f64 = (other.clone()).into();
+let other: f64 = (other.clone()).to_elem();
let tensor = (&self.tensor).div(other);
let shape = Shape::from(tensor.size());
let kind = self.kind.clone();
@@ -27,15 +30,15 @@ impl<P: tch::kind::Element + Into<f64>, const D: usize> TensorOpsDiv<P, D> for T
}
}

-impl<P: tch::kind::Element + Into<f64>, const D: usize> std::ops::Div<P> for TchTensor<P, D> {
+impl<P: Element + tch::kind::Element, const D: usize> std::ops::Div<P> for TchTensor<P, D> {
type Output = TchTensor<P, D>;

fn div(self, rhs: P) -> Self::Output {
TensorOpsDiv::div_scalar(&self, &rhs)
}
}

-impl<P: tch::kind::Element + Into<f64>, const D: usize> std::ops::Div<TchTensor<P, D>>
+impl<P: Element + tch::kind::Element, const D: usize> std::ops::Div<TchTensor<P, D>>
for TchTensor<P, D>
{
type Output = TchTensor<P, D>;
10 changes: 5 additions & 5 deletions burn-tensor/src/tensor/backend/tch/ops/map_comparison.rs
@@ -26,7 +26,7 @@
&self,
other: &<TchBackend<E> as crate::back::Backend>::Elem,
) -> <TchBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
-let other: f64 = (*other).into();
+let other: f64 = (*other).to_elem();
let tensor = self.tensor.eq(other);

TchTensor {
@@ -53,7 +53,7 @@ where
&self,
other: &<TchBackend<E> as crate::back::Backend>::Elem,
) -> <TchBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
-let other: f64 = (*other).into();
+let other: f64 = (*other).to_elem();
let tensor = self.tensor.greater(other);

TchTensor {
@@ -80,7 +80,7 @@ where
&self,
other: &<TchBackend<E> as crate::back::Backend>::Elem,
) -> <TchBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
-let other: f64 = (*other).into();
+let other: f64 = (*other).to_elem();
let tensor = self.tensor.greater_equal(other);

TchTensor {
@@ -107,7 +107,7 @@ where
&self,
other: &<TchBackend<E> as crate::back::Backend>::Elem,
) -> <TchBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
-let other: f64 = (*other).into();
+let other: f64 = (*other).to_elem();
let tensor = self.tensor.less(other);

TchTensor {
@@ -134,7 +134,7 @@ where
&self,
other: &<TchBackend<E> as crate::back::Backend>::Elem,
) -> <TchBackend<E> as crate::back::Backend>::BoolTensorPrimitive<D> {
-let other: f64 = (*other).into();
+let other: f64 = (*other).to_elem();
let tensor = self.tensor.less_equal(other);

TchTensor {
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/tch/ops/mask.rs
@@ -9,7 +9,7 @@ impl<E: TchElement, const D: usize> TensorOpsMask<TchBackend<E>, D> for TchTenso
mask: &<TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
value: E,
) -> Self {
-let value: f64 = value.into();
+let value: f64 = value.to_elem();
let tensor = self.tensor.f_masked_fill(&mask.tensor, value).unwrap();

Self {
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/tch/ops/mod.rs
@@ -1,5 +1,6 @@
mod add;
mod aggregation;
mod arg;
mod cat;
mod creation;
mod device;
(The remaining changed files were not loaded in this view.)
