refactor: sub ops #66

Merged: 2 commits, Nov 5, 2022
Changes from 1 commit
refactor: sub ops
nathanielsimard committed Nov 5, 2022
commit 4c2e80bb479dbf278de37432fadba88f925f1616
README.md (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
[![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn)
[![Test Status](https://github.com/burn-rs/burn/actions/workflows/test-burn.yml/badge.svg)](https://github.com/burn-rs/burn/actions/workflows/test-burn.yml)
[![Documentation](https://docs.rs/burn/badge.svg)](https://docs.rs/burn)
[![Rust Version](https://img.shields.io/badge/Rust-1.65.0-blue)](https://releases.rs/docs/unreleased/1.65.0)
[![Rust Version](https://img.shields.io/badge/Rust-1.65.0-blue)](https://releases.rs/docs/released/1.65.0)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn/blob/master/LICENSE)

> This library aims to be a complete deep learning framework with extreme flexibility written in Rust.
burn-tensor/src/tensor/backend/autodiff/ops/mod.rs (1 change: 0 additions & 1 deletion)
@@ -18,7 +18,6 @@ mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
mod tensor;
mod transpose;

burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs (70 changes: 69 additions & 1 deletion)
@@ -5,7 +5,7 @@ use crate::{
Backend,
},
graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
ops::TensorOps,
ops::{TensorOps, TensorOpsNeg},
Data, Shape,
};

@@ -65,6 +65,38 @@ impl<B: Backend, const D: usize>
}
}

#[derive(Default, Debug)]
struct SubBackward<B: Backend, const D: usize> {
_b: B,
}

impl<B: Backend, const D: usize>
BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for SubBackward<B, D>
{
fn partial_left(
&self,
state: &BinaryOpsNodeState<
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
>,
) -> B::TensorPrimitive<D> {
state.output.grad()
}

fn partial_right(
&self,
state: &BinaryOpsNodeState<
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
>,
) -> B::TensorPrimitive<D> {
state.output.grad().neg()
}
}

#[derive(Default, Debug)]
struct AddScalarBackward<B: Backend, const D: usize> {
_b: B,
}

@@ -81,6 +113,22 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
}
}

#[derive(Default, Debug)]
struct SubScalarBackward<B: Backend, const D: usize> {
_b: B,
}

impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for SubScalarBackward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
state.output.grad()
}
}

impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn shape<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
@@ -162,4 +210,24 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

unary_ops_wrapper(lhs.node.clone(), output, ops)
}

fn sub<const D: usize>(
lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
rhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
let output = B::sub(lhs.tensor_ref(), rhs.tensor_ref());
let ops = SubBackward::<B, D>::default();

binary_ops_wrapper(lhs.node.clone(), rhs.node.clone(), output, ops)
}

fn sub_scalar<const D: usize>(
lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
rhs: &<ADBackendDecorator<B> as Backend>::Elem,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
let output = B::sub_scalar(lhs.tensor_ref(), rhs);
let ops = SubScalarBackward::<B, D>::default();

unary_ops_wrapper(lhs.node.clone(), output, ops)
}
}
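Note on the math: for y = a - b, the partials are dy/da = 1 and dy/db = -1, which is exactly why SubBackward's partial_left passes the output gradient through unchanged while partial_right negates it (the reason tensor.rs now imports TensorOpsNeg). A minimal standalone sketch of that rule, checked against finite differences; all names here are illustrative, not part of burn:

```rust
// Sketch: the gradient rule implemented by SubBackward, on plain f64.
// For y = a - b: dy/da = 1.0, dy/db = -1.0.
fn sub_partials(grad_output: f64) -> (f64, f64) {
    (grad_output, -grad_output) // (partial_left, partial_right)
}

fn main() {
    let y = |a: f64, b: f64| a - b;
    let (a, b, eps) = (3.0_f64, 5.0_f64, 1e-6);
    // Finite-difference estimates of the two partials.
    let da = (y(a + eps, b) - y(a, b)) / eps; // close to 1.0
    let db = (y(a, b + eps) - y(a, b)) / eps; // close to -1.0
    let (pl, pr) = sub_partials(1.0);
    assert!((da - pl).abs() < 1e-4);
    assert!((db - pr).abs() < 1e-4);
}
```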
burn-tensor/src/tensor/backend/base.rs (1 change: 0 additions & 1 deletion)
@@ -27,7 +27,6 @@ pub trait Backend:
+ TensorOpsMul<Self::Elem, D>
+ TensorOpsDiv<Self::Elem, D>
+ TensorOpsNeg<Self::Elem, D>
+ TensorOpsSub<Self::Elem, D>
+ TensorOpsDetach<Self::Elem, D>
+ Zeros<Self::TensorPrimitive<D>>
+ Ones<Self::TensorPrimitive<D>>
burn-tensor/src/tensor/backend/ndarray/ops/map_comparison.rs (18 changes: 5 additions & 13 deletions)
@@ -8,7 +8,7 @@ where
E: NdArrayElement,
{
fn equal(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.equal_scalar(&zero)
}
@@ -26,7 +26,7 @@
}

fn greater(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.greater_scalar(&zero)
}
@@ -47,7 +47,7 @@
&self,
other: &Self,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.greater_equal_scalar(&zero)
}
@@ -65,7 +65,7 @@
}

fn lower(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.lower_scalar(&zero)
}
@@ -83,7 +83,7 @@
}

fn lower_equal(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.lower_equal_scalar(&zero)
}
@@ -100,11 +100,3 @@
}
}
}

#[cfg(tests)]
mod tests {
use super::*;

#[test]
fn test_greater() {}
}
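Every comparison kernel above relies on the same reduction: a > b exactly when a - b > 0, and likewise for the other operators, so the refactor only had to re-point the subtraction at the backend-level NdArrayBackend::<E>::sub. A small illustrative sketch of that reduction (plain Rust, not burn's API):

```rust
// Sketch: elementwise comparison expressed through subtraction, as in
// map_comparison.rs: a > b  <=>  (a - b) > 0.
fn greater(lhs: &[f32], rhs: &[f32]) -> Vec<bool> {
    lhs.iter()
        .zip(rhs)
        .map(|(a, b)| (a - b) > 0.0) // subtract, then compare to zero
        .collect()
}

fn main() {
    assert_eq!(greater(&[1.0, 5.0], &[2.0, 3.0]), vec![false, true]);
}
```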
burn-tensor/src/tensor/backend/ndarray/ops/mod.rs (1 change: 0 additions & 1 deletion)
@@ -16,5 +16,4 @@ mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
mod transpose;
burn-tensor/src/tensor/backend/ndarray/ops/sub.rs (43 changes: 0 additions & 43 deletions)

This file was deleted.

burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs (19 changes: 19 additions & 0 deletions)
@@ -80,6 +80,25 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
let array = lhs.array.clone() + *rhs;
let shape = lhs.shape;

NdArrayTensor { array, shape }
}
fn sub<const D: usize>(
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() - rhs.array.clone();
let array = array.into_shared();
let shape = lhs.shape.higher(&rhs.shape);

NdArrayTensor { array, shape }
}
fn sub_scalar<const D: usize>(
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &E,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() - *rhs;
let shape = lhs.shape;

NdArrayTensor { array, shape }
}
}
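The ndarray implementation leans on the ndarray crate's built-in elementwise and scalar arithmetic; the diff's extra steps (into_shared, shape.higher) only adapt the result to burn's tensor representation. A standalone sketch of the underlying ndarray behaviour (assumes the ndarray crate as a dependency; NdArrayTensor itself is not shown):

```rust
use ndarray::array;

fn main() {
    let a = array![[3.0, 5.0], [7.0, 9.0]];
    let b = array![[1.0, 2.0], [3.0, 4.0]];
    // Elementwise subtraction, the operation behind `lhs.array.clone() - rhs.array.clone()`.
    let c = &a - &b;
    assert_eq!(c, array![[2.0, 3.0], [4.0, 5.0]]);
    // Scalar subtraction broadcasts the scalar, as in sub_scalar.
    let d = &a - 1.0;
    assert_eq!(d, array![[2.0, 4.0], [6.0, 8.0]]);
}
```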
burn-tensor/src/tensor/backend/tch/ops/mod.rs (1 change: 0 additions & 1 deletion)
@@ -16,5 +16,4 @@ mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
mod transpose;
burn-tensor/src/tensor/backend/tch/ops/sub.rs (52 changes: 0 additions & 52 deletions)

This file was deleted.

burn-tensor/src/tensor/backend/tch/tensor_ops.rs (26 changes: 25 additions & 1 deletion)
@@ -1,4 +1,4 @@
use std::ops::Add;
use std::ops::{Add, Sub};

use super::{TchBackend, TchDevice, TchKind, TchTensor};
use crate::{backend::Backend, ops::TensorOps, Data, Shape, TchElement};
@@ -93,6 +93,30 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
let kind = lhs.kind;
let shape = lhs.shape;

TchTensor {
tensor,
shape,
kind,
}
}
fn sub<const D: usize>(lhs: &TchTensor<E, D>, rhs: &TchTensor<E, D>) -> TchTensor<E, D> {
let tensor = (&lhs.tensor).sub(&rhs.tensor);
let kind = lhs.kind;
let shape = lhs.shape.higher(&rhs.shape);

TchTensor {
tensor,
shape,
kind,
}
}

fn sub_scalar<const D: usize>(lhs: &TchTensor<E, D>, rhs: &E) -> TchTensor<E, D> {
let other: f64 = (rhs.clone()).to_elem();
let tensor = (&lhs.tensor).sub(other).to_kind(lhs.kind.kind());
let kind = lhs.kind;
let shape = lhs.shape;

TchTensor {
tensor,
shape,
kind,
}
}
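One subtlety in the tch sub_scalar above: the scalar is widened to f64 (the scalar type tch takes here), so the result is cast back with to_kind to keep the tensor's original element type. A plain-Rust analogue of that widen-then-narrow round trip (illustrative only, not tch's API):

```rust
// Sketch: compute in f64, then narrow back to the original element type,
// mirroring the `.to_kind(lhs.kind.kind())` cast in sub_scalar.
fn sub_scalar_f32(values: &[f32], scalar: f64) -> Vec<f32> {
    values
        .iter()
        .map(|v| ((*v as f64) - scalar) as f32) // widen, subtract, narrow
        .collect()
}

fn main() {
    assert_eq!(sub_scalar_f32(&[3.0, 5.0], 1.0), vec![2.0, 4.0]);
}
```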
burn-tensor/src/tensor/base.rs (4 changes: 2 additions & 2 deletions)
@@ -179,14 +179,14 @@
///
/// `y = x2 - x1`
pub fn sub(&self, other: &Self) -> Self {
Self::new(self.value.sub(&other.value))
Self::new(B::sub(&self.value, &other.value))
}

/// Applies element-wise subtraction with a scalar.
///
/// `y = x - s`
pub fn sub_scalar<E: ElementConversion>(&self, other: E) -> Self {
Self::new(self.value.sub_scalar(&other.to_elem()))
Self::new(B::sub_scalar(&self.value, &other.to_elem()))
}

/// Applies the transpose operation.
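This hunk is the heart of the refactor: Tensor::sub no longer calls a method provided by a dedicated TensorOpsSub trait on the primitive; it delegates to an associated function on the backend. A minimal sketch of that pattern under illustrative types (not burn's real ones):

```rust
// After the refactor: ops are associated functions on one backend-wide
// trait, and Tensor methods simply delegate to them.
trait TensorOps {
    type Primitive;
    fn sub(lhs: &Self::Primitive, rhs: &Self::Primitive) -> Self::Primitive;
}

struct CpuBackend;

impl TensorOps for CpuBackend {
    type Primitive = Vec<f32>;
    fn sub(lhs: &Vec<f32>, rhs: &Vec<f32>) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(a, b)| a - b).collect()
    }
}

struct Tensor<B: TensorOps> {
    value: B::Primitive,
}

impl<B: TensorOps> Tensor<B> {
    // Mirrors `Self::new(B::sub(&self.value, &other.value))` in the diff.
    fn sub(&self, other: &Self) -> Self {
        Tensor { value: B::sub(&self.value, &other.value) }
    }
}

fn main() {
    let a = Tensor::<CpuBackend> { value: vec![3.0, 5.0] };
    let b = Tensor::<CpuBackend> { value: vec![1.0, 2.0] };
    assert_eq!(a.sub(&b).value, vec![2.0, 3.0]);
}
```

The payoff, visible in base.rs above and ops/base.rs below, is one fewer per-op trait bound on Backend and a single TensorOps trait that each backend implements in one place.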
burn-tensor/src/tensor/ops/base.rs (13 changes: 8 additions & 5 deletions)
@@ -73,11 +73,14 @@ pub trait TensorOps<B: Backend> {
lhs: &B::TensorPrimitive<D>,
rhs: &B::Elem,
) -> B::TensorPrimitive<D>;
}

pub trait TensorOpsSub<E, const D: usize> {
fn sub(&self, other: &Self) -> Self;
fn sub_scalar(&self, other: &E) -> Self;
fn sub<const D: usize>(
lhs: &B::TensorPrimitive<D>,
rhs: &B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
fn sub_scalar<const D: usize>(
lhs: &B::TensorPrimitive<D>,
rhs: &B::Elem,
) -> B::TensorPrimitive<D>;
}

pub trait TensorOpsTranspose<E, const D: usize> {
burn-tensor/tests/tensor/grad/mod.rs (1 change: 1 addition & 0 deletions)
@@ -3,3 +3,4 @@ mod aggregation;
mod cross_entropy;
mod div;
mod softmax;
mod sub;