Commit

fix integer backend
nathanielsimard committed Aug 27, 2022
1 parent 9ed0b0b commit 2c4288d
Showing 16 changed files with 131 additions and 118 deletions.
38 changes: 38 additions & 0 deletions burn-tensor/src/tensor/api/bool_tensor.rs
@@ -0,0 +1,38 @@
use super::Tensor;
use crate::tensor::backend::Backend;
use crate::tensor::ops::*;
use crate::tensor::{Data, Shape};

pub struct BoolTensor<B: Backend, const D: usize> {
pub(crate) value: B::BoolTensorPrimitive<D>,
}

impl<B, const D: usize> BoolTensor<B, D>
where
B: Backend,
{
pub fn new(tensor: B::BoolTensorPrimitive<D>) -> Self {
Self { value: tensor }
}

pub fn shape(&self) -> &Shape<D> {
self.value.shape()
}

pub fn into_data(self) -> Data<bool, D> {
self.value.into_data()
}

pub fn to_data(&self) -> Data<bool, D> {
self.value.to_data()
}

pub fn from_data(data: Data<bool, D>) -> Self {
let value = B::from_data_bool(data, B::Device::default());
Self::new(value)
}

pub fn to_int(&self) -> Tensor<B::IntegerBackend, D> {
Tensor::from_data(self.value.to_data().convert())
}
}
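
As a usage sketch (not part of the commit): the relocated BoolTensor API can be exercised roughly as follows. The import paths and the bool overload of Data::from are assumptions; only from_data, to_data, and to_int are taken from the code above.

    use burn_tensor::backend::NdArrayBackend; // path assumed
    use burn_tensor::{BoolTensor, Data, Tensor};

    fn main() {
        // Build a bool tensor from raw data (assumes Data::from accepts bool arrays).
        let data = Data::<bool, 2>::from([[true, false], [false, true]]);
        let mask = BoolTensor::<NdArrayBackend<f32>, 2>::from_data(data);

        // `to_int` round-trips through `Data::convert`, producing a regular
        // tensor on the integer backend (NdArrayBackend<i64> here).
        let ints: Tensor<NdArrayBackend<i64>, 2> = mask.to_int();
        assert_eq!(ints.to_data(), Data::from([[1, 0], [0, 1]]));
    }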
2 changes: 2 additions & 0 deletions burn-tensor/src/tensor/api/mod.rs
@@ -3,7 +3,9 @@ pub mod back;
pub mod losses;

mod ad;
mod bool_tensor;
mod ops;
mod tensor;

pub use bool_tensor::*;
pub use tensor::*;
68 changes: 5 additions & 63 deletions burn-tensor/src/tensor/api/tensor.rs
@@ -1,3 +1,4 @@
+use super::BoolTensor;
use crate::tensor::activation::*;
use crate::tensor::backend::Backend;
use crate::tensor::ops::*;
@@ -9,65 +10,6 @@ pub struct Tensor<B: Backend, const D: usize> {
pub(crate) value: B::TensorPrimitive<D>,
}

-pub struct BoolTensor<B: Backend, const D: usize> {
-pub(crate) value: B::BoolTensorPrimitive<D>,
-}
-
-pub struct IndexTensor<B: Backend, const D: usize> {
-pub(crate) value: B::IndexTensorPrimitive<D>,
-}
-
-impl<B, const D: usize> IndexTensor<B, D>
-where
-B: Backend,
-{
-pub fn new(tensor: B::IndexTensorPrimitive<D>) -> Self {
-Self { value: tensor }
-}
-
-pub fn shape(&self) -> &Shape<D> {
-self.value.shape()
-}
-
-pub fn into_data(self) -> Data<i64, D> {
-self.value.into_data()
-}
-
-pub fn to_data(&self) -> Data<i64, D> {
-self.value.to_data()
-}
-
-pub fn mul(&self, other: &Self) -> Self {
-Self::new(self.value.mul(&other.value))
-}
-}
-
-impl<B, const D: usize> BoolTensor<B, D>
-where
-B: Backend,
-{
-pub fn new(tensor: B::BoolTensorPrimitive<D>) -> Self {
-Self { value: tensor }
-}
-
-pub fn shape(&self) -> &Shape<D> {
-self.value.shape()
-}
-
-pub fn into_data(self) -> Data<bool, D> {
-self.value.into_data()
-}
-
-pub fn to_data(&self) -> Data<bool, D> {
-self.value.to_data()
-}
-
-pub fn from_data(data: Data<bool, D>) -> Self {
-let value = B::from_data_bool(data, B::Device::default());
-Self::new(value)
-}
-}
-
impl<const D: usize, B> Tensor<B, D>
where
B: Backend,
@@ -288,12 +230,12 @@ where
Tensor::new(value)
}

-pub fn argmax(&self, dim: usize) -> IndexTensor<B, D> {
-IndexTensor::new(self.value.argmax(dim))
+pub fn argmax(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
+Tensor::new(self.value.argmax(dim))
}

-pub fn argmin(&self, dim: usize) -> IndexTensor<B, D> {
-IndexTensor::new(self.value.argmin(dim))
+pub fn argmin(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
+Tensor::new(self.value.argmin(dim))
}

pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
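A sketch of how the new signatures read at a call site; the helper name and import paths are assumptions, but argmax returning Tensor<B::IntegerBackend, D> is exactly what the diff above introduces.

    use burn_tensor::{backend::Backend, Data, Tensor};

    // With `IndexTensor` removed, arg-reduction results are ordinary tensors
    // on `B::IntegerBackend`, so the whole tensor API applies to them.
    fn predicted_classes<B: Backend>(logits: &Tensor<B, 2>) -> Data<i64, 2> {
        let indices: Tensor<B::IntegerBackend, 2> = logits.argmax(1);
        indices.to_data()
    }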
3 changes: 1 addition & 2 deletions burn-tensor/src/tensor/backend/autodiff/backend.rs
@@ -24,12 +24,11 @@ macro_rules! define_impl {
type Device = <$backend as Backend>::Device;
type Elem = E;
type FullPrecisionElem = f32;
+type IntegerBackend = <$backend as Backend>::IntegerBackend;
type FullPrecisionBackend = $name<<$backend as Backend>::FullPrecisionElem>;
type TensorPrimitive<const D: usize> = ADTensor<D, $backend>;
type BoolTensorPrimitive<const D: usize> =
<$backend as Backend>::BoolTensorPrimitive<D>;
-type IndexTensorPrimitive<const D: usize> =
-<$backend as Backend>::IndexTensorPrimitive<D>;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
10 changes: 8 additions & 2 deletions burn-tensor/src/tensor/backend/autodiff/ops/arg.rs
@@ -10,11 +10,17 @@ macro_rules! define_impl {
impl<E: $element, const D: usize> TensorOpsArg<$backend, D>
for <$backend as Backend>::TensorPrimitive<D>
{
-fn argmax(&self, dim: usize) -> <$backend as Backend>::IndexTensorPrimitive<D> {
+fn argmax(
+&self,
+dim: usize,
+) -> <<$backend as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
TensorOpsArg::argmax(&self.tensor(), dim)
}

-fn argmin(&self, dim: usize) -> <$backend as Backend>::IndexTensorPrimitive<D> {
+fn argmin(
+&self,
+dim: usize,
+) -> <<$backend as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
TensorOpsArg::argmin(&self.tensor(), dim)
}
}
9 changes: 2 additions & 7 deletions burn-tensor/src/tensor/backend/backend.rs
@@ -12,7 +12,8 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
type Device: Copy + Clone + Default + std::fmt::Debug + Send + Sync;
type Elem: Element;
type FullPrecisionElem: Element;
-type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem>;
+type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
+type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
type TensorPrimitive<const D: usize>: TensorTrait<Self::Elem, D>
+ TensorOpsReshape<Self, D>
+ TensorOpsPrecision<Self, D>
@@ -36,12 +37,6 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'st
+ Sync
+ 'static
+ std::fmt::Debug;
-type IndexTensorPrimitive<const D: usize>: TensorTrait<i64, D>
-+ Clone
-+ Send
-+ Sync
-+ 'static
-+ std::fmt::Debug;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
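The Device = Self::Device bounds keep full-precision and integer tensors on the same device as their parent backend. A minimal sketch of what the new IntegerBackend associated type enables generically; the helper name and import paths are assumptions, only Tensor::from_data appears in the commit itself.

    use burn_tensor::{backend::Backend, Data, Tensor};

    // `IntegerBackend` is itself a full `Backend` with `Elem = i64`, so i64
    // tensors can be created through the ordinary tensor API instead of a
    // dedicated `IndexTensorPrimitive`.
    fn index_tensor<B: Backend, const D: usize>(data: Data<i64, D>) -> Tensor<B::IntegerBackend, D> {
        Tensor::from_data(data)
    }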
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/ndarray/backend.rs
@@ -23,9 +23,9 @@ impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
type Elem = E;
type FullPrecisionElem = f32;
type FullPrecisionBackend = NdArrayBackend<f32>;
+type IntegerBackend = NdArrayBackend<i64>;
type TensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
-type IndexTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
51 changes: 26 additions & 25 deletions burn-tensor/src/tensor/backend/ndarray/ops/arg.rs
@@ -1,19 +1,18 @@
use crate::backend::ndarray::NdArrayBackend;
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*};
-use crate::ElementValue;
-use crate::NdArrayElement;
+use crate::{Data, NdArrayElement};
use std::cmp::Ordering;

impl<E, const D: usize> TensorOpsArg<NdArrayBackend<E>, D> for NdArrayTensor<E, D>
where
E: NdArrayElement,
{
fn argmax(&self, dim: usize) -> NdArrayTensor<i64, D> {
-arg(self, dim, cmp_max)
+arg(self, dim, cmp_min)
}

fn argmin(&self, dim: usize) -> NdArrayTensor<i64, D> {
-arg(self, dim, cmp_min)
+arg(self, dim, cmp_max)
}
}

@@ -26,32 +25,34 @@ where
F: Fn(&f64, &f64) -> Ordering,
{
let mut data = tensor.to_data();
-let mut start = 1;
+let batch_size = tensor.shape.dims[dim];
+let mut start = 0;
+let mut end = tensor.shape.dims[dim];
+let mut output = Vec::new();

-for i in 0..dim {
-start = start * tensor.shape.dims[i];
-}
-let end = start + tensor.shape.dims[dim];
-
-let data_dim = &mut data.value[start..end];
-let mut sorted: Vec<f64> = data_dim.iter().map(|a| a.to_elem()).collect();
-sorted.sort_by(cmp);
+while end <= data.value.len() {
+let data_dim = &mut data.value[start..end];
+let mut sorted: Vec<f64> = data_dim.iter().map(|a| a.to_elem()).collect();
+sorted.sort_by(&cmp);

-let max = sorted[0];
-for elem in data_dim {
-*elem = <E as ElementValue>::zero();
-}
+let max = sorted[0];

-let data_dim = &mut data.value[start..end];
-for elem in data_dim {
-let as_float: f64 = elem.to_elem();
-if as_float == max {
-*elem = <E as ElementValue>::one();
-break;
+let data_dim = &mut data.value[start..end];
+let mut index: i64 = 0;
+for elem in data_dim {
+let as_float: f64 = elem.to_elem();
+if as_float == max {
+break;
+}
+index += 1;
}
+output.push(index);
+start = start + batch_size;
+end = end + batch_size;
}

-NdArrayTensor::from_data(data.convert())
+let mut shape = tensor.shape.clone();
+shape.dims[dim] = 1;
+NdArrayTensor::from_data(Data::new(output, shape))
}

fn cmp_max(a: &f64, b: &f64) -> Ordering {
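For reference, a self-contained sketch of the windowed arg-reduction the new arg performs, assuming the reduced dimension is contiguous in the flattened buffer (row-major, last dimension). Names here are illustrative, not from the crate.

    use std::cmp::Ordering;

    // Walk the flattened buffer in windows of `batch_size`, sort a copy of
    // each window with `cmp`, and record the offset of the winning element.
    fn arg_indices(values: &[f64], batch_size: usize, cmp: fn(&f64, &f64) -> Ordering) -> Vec<i64> {
        let mut output = Vec::new();
        let (mut start, mut end) = (0, batch_size);
        while end <= values.len() {
            let window = &values[start..end];
            let mut sorted = window.to_vec();
            sorted.sort_by(cmp);
            let best = sorted[0];
            // Index of the first element equal to the winner, as in the diff.
            let index = window.iter().position(|v| *v == best).unwrap() as i64;
            output.push(index);
            start += batch_size;
            end += batch_size;
        }
        output
    }

    fn main() {
        // Two rows of three; sorting descending makes `sorted[0]` the maximum,
        // so this computes argmax along the last dimension: index 2 per row.
        let values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let desc = |a: &f64, b: &f64| b.partial_cmp(a).unwrap();
        assert_eq!(arg_indices(&values, 3, desc), vec![2, 2]);
    }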
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/tch/backend.rs
@@ -24,9 +24,9 @@ impl<E: TchElement> Backend for TchBackend<E> {
type Elem = E;
type FullPrecisionElem = f32;
type FullPrecisionBackend = TchBackend<f32>;
+type IntegerBackend = TchBackend<i64>;
type TensorPrimitive<const D: usize> = TchTensor<E, D>;
type BoolTensorPrimitive<const D: usize> = TchTensor<bool, D>;
-type IndexTensorPrimitive<const D: usize> = TchTensor<i64, D>;

fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
6 changes: 4 additions & 2 deletions burn-tensor/src/tensor/backend/tch/ops/arg.rs
@@ -11,7 +11,8 @@ where
{
fn argmax(&self, dim: usize) -> TchTensor<i64, D> {
let tensor = self.tensor.argmax(dim as i64, true);
-let shape = self.shape.clone();
+let mut shape = self.shape.clone();
+shape.dims[dim] = 1;

TchTensor {
tensor,
@@ -22,7 +23,8 @@

fn argmin(&self, dim: usize) -> TchTensor<i64, D> {
let tensor = self.tensor.argmin(dim as i64, true);
-let shape = self.shape.clone();
+let mut shape = self.shape.clone();
+shape.dims[dim] = 1;

TchTensor {
tensor,
16 changes: 15 additions & 1 deletion burn-tensor/src/tensor/data.rs
@@ -1,5 +1,5 @@
use super::ops::{Ones, Zeros};
-use crate::{tensor::Shape, Element};
+use crate::{tensor::Shape, Element, ElementConversion};
use rand::{distributions::Standard, prelude::StdRng, Rng, SeedableRng};

#[derive(serde::Serialize, serde::Deserialize, Debug)]
@@ -94,6 +94,20 @@ impl<const D: usize, P: Element> Data<P, D> {
}
}

impl<const D: usize> Data<bool, D> {
pub fn convert<E: Element>(self) -> Data<E, D> {
let value: Vec<E> = self
.value
.into_iter()
.map(|a| (a as i64).to_elem())
.collect();

Data {
value,
shape: self.shape,
}
}
}
impl<P: Element, const D: usize> Data<P, D> {
pub fn random(shape: Shape<D>, distribution: Distribution<P>) -> Self {
let num_elements = shape.num_elements();
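A quick sketch of the new Data<bool, D>::convert in isolation; it assumes Data::from accepts bool arrays the way it accepts numeric ones, and that Data implements PartialEq as the tests suggest.

    use burn_tensor::Data;

    fn main() {
        // Each bool goes through `as i64` and then `to_elem`, so true -> 1, false -> 0.
        let flags = Data::<bool, 1>::from([true, false, true]);
        let ints: Data<i64, 1> = flags.convert();
        assert_eq!(ints, Data::from([1, 0, 1]));
    }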
4 changes: 2 additions & 2 deletions burn-tensor/src/tensor/ops.rs
@@ -89,8 +89,8 @@ pub trait TensorOpsPrecision<B: Backend, const D: usize> {
}

pub trait TensorOpsArg<B: Backend, const D: usize> {
-fn argmax(&self, dim: usize) -> B::IndexTensorPrimitive<D>;
-fn argmin(&self, dim: usize) -> B::IndexTensorPrimitive<D>;
+fn argmax(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
+fn argmin(&self, dim: usize) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
}

pub trait TensorOpsExp<E, const D: usize> {
13 changes: 13 additions & 0 deletions burn-tensor/tests/tensor/ops/arg.rs
@@ -0,0 +1,13 @@
use super::super::TestBackend;
use burn_tensor::{Data, Tensor};

#[test]
fn test_argmax_2d() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);

let data_actual = tensor.argmax(1);

let data_expected = Data::from([[2], [2]]);
assert_eq!(data_expected, data_actual.to_data());
}
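
A companion argmin case, not part of this commit but following the same pattern and reusing the test module's imports; the expected indices assume the minimum sits at position 0 of each row.

    #[test]
    fn test_argmin_2d() {
        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
        let tensor = Tensor::<TestBackend, 2>::from_data(data);

        let data_actual = tensor.argmin(1);

        let data_expected = Data::from([[0], [0]]);
        assert_eq!(data_expected, data_actual.to_data());
    }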
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/ops/mod.rs
@@ -1,5 +1,6 @@
mod add;
mod aggregation;
mod arg;
mod div;
mod exp;
mod index;
4 changes: 3 additions & 1 deletion examples/mnist.rs
@@ -146,7 +146,7 @@ fn run<B: ad::Backend>(device: B::Device) {
let num_epochs = 10;
let num_workers = 8;
let num_layers = 4;
-let hidden_dim = 5560;
+let hidden_dim = 1024;
let seed = 42;
let metrics = || -> Vec<Box<dyn Metric<ClassificationOutput<B>>>> {
vec![
@@ -207,4 +207,6 @@ fn run<B: ad::Backend>(device: B::Device) {
fn main() {
let device = burn::tensor::back::TchDevice::Cuda(0);
run::<ad::Tch<burn::tensor::f16>>(device);
// let device = burn::tensor::back::NdArrayDevice::Cpu;
// run::<ad::NdArray<f32>>(device);
}