refactor: sub ops (tracel-ai#66)
nathanielsimard authored Nov 5, 2022
1 parent 0f4c1e4 commit 2bdad6f
Showing 17 changed files with 189 additions and 226 deletions.
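In short, this commit moves subtraction from the standalone TensorOpsSub trait into the backend-level TensorOps trait: the TensorOpsSub bound is dropped from Backend, the per-backend sub modules and their trait impls are deleted, and sub / sub_scalar are added as TensorOps methods, with SubBackward / SubScalarBackward providing the autodiff rules and a direct ndarray implementation providing the forward computation. Call sites change accordingly, for example from self.sub(other) to NdArrayBackend::<E>::sub(self, other) in the comparison ops.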
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@
[![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn)
[![Test Status](https://github.com/burn-rs/burn/actions/workflows/test-burn.yml/badge.svg)](https://github.com/burn-rs/burn/actions/workflows/test-burn.yml)
[![Documentation](https://docs.rs/burn/badge.svg)](https://docs.rs/burn)
[![Rust Version](https://img.shields.io/badge/Rust-1.65.0-blue)](https://releases.rs/docs/unreleased/1.65.0)
[![Rust Version](https://img.shields.io/badge/Rust-1.65.0-blue)](https://releases.rs/docs/released/1.65.0)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn/blob/master/LICENSE)

> This library aims to be a complete deep learning framework with extreme flexibility written in Rust.
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -18,7 +18,6 @@ mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
mod tensor;
mod transpose;

104 changes: 0 additions & 104 deletions burn-tensor/src/tensor/backend/autodiff/ops/sub.rs

This file was deleted.

70 changes: 69 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
@@ -5,7 +5,7 @@ use crate::{
Backend,
},
graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
ops::TensorOps,
ops::{TensorOps, TensorOpsNeg},
Data, Shape,
};

@@ -65,6 +65,38 @@ impl<B: Backend, const D: usize>
}
}

#[derive(Default, Debug)]
struct SubBackward<B: Backend, const D: usize> {
_b: B,
}

impl<B: Backend, const D: usize>
BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for SubBackward<B, D>
{
fn partial_left(
&self,
state: &BinaryOpsNodeState<
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
>,
) -> B::TensorPrimitive<D> {
state.output.grad()
}

fn partial_right(
&self,
state: &BinaryOpsNodeState<
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
B::TensorPrimitive<D>,
>,
) -> B::TensorPrimitive<D> {
state.output.grad().neg()
}
}

#[derive(Default, Debug)]
struct AddScalarBackward<B: Backend, const D: usize> {
_b: B,
@@ -81,6 +113,22 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
}
}

#[derive(Default, Debug)]
struct SubScalarBackward<B: Backend, const D: usize> {
_b: B,
}

impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for SubScalarBackward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
state.output.grad()
}
}

impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn shape<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
@@ -162,4 +210,24 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

unary_ops_wrapper(lhs.node.clone(), output, ops)
}

fn sub<const D: usize>(
lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
rhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
let output = B::sub(lhs.tensor_ref(), rhs.tensor_ref());
let ops = SubBackward::<B, D>::default();

binary_ops_wrapper(lhs.node.clone(), rhs.node.clone(), output, ops)
}

fn sub_scalar<const D: usize>(
lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
rhs: &<ADBackendDecorator<B> as Backend>::Elem,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
let output = B::sub_scalar(lhs.tensor_ref(), rhs);
let ops = SubScalarBackward::<B, D>::default();

unary_ops_wrapper(lhs.node.clone(), output, ops)
}
}
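
As a quick sanity check on the backward rules encoded by SubBackward and SubScalarBackward above (the partial derivative of lhs - rhs is +1 with respect to lhs and -1 with respect to rhs, while subtracting a scalar leaves the tensor's partial at +1), here is a minimal, self-contained sketch on plain f64 values instead of burn tensors; the function names are hypothetical and exist only to illustrate the rule:

// Hypothetical helpers mirroring partial_left / partial_right / partial above.
fn sub_partials(grad_output: f64) -> (f64, f64) {
    // d(lhs - rhs)/d(lhs) = 1  -> pass the output gradient through unchanged
    // d(lhs - rhs)/d(rhs) = -1 -> negate the output gradient
    (grad_output, -grad_output)
}

fn sub_scalar_partial(grad_output: f64) -> f64 {
    // d(lhs - c)/d(lhs) = 1 for a constant scalar c
    grad_output
}

fn main() {
    assert_eq!(sub_partials(2.5), (2.5, -2.5));
    assert_eq!(sub_scalar_partial(2.5), 2.5);
}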
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/base.rs
@@ -27,7 +27,6 @@ pub trait Backend:
+ TensorOpsMul<Self::Elem, D>
+ TensorOpsDiv<Self::Elem, D>
+ TensorOpsNeg<Self::Elem, D>
+ TensorOpsSub<Self::Elem, D>
+ TensorOpsDetach<Self::Elem, D>
+ Zeros<Self::TensorPrimitive<D>>
+ Ones<Self::TensorPrimitive<D>>
18 changes: 5 additions & 13 deletions burn-tensor/src/tensor/backend/ndarray/ops/map_comparison.rs
@@ -8,7 +8,7 @@
E: NdArrayElement,
{
fn equal(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.equal_scalar(&zero)
}
@@ -26,7 +26,7 @@
}

fn greater(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.greater_scalar(&zero)
}
@@ -47,7 +47,7 @@
&self,
other: &Self,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.greater_equal_scalar(&zero)
}
@@ -65,7 +65,7 @@
}

fn lower(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.lower_scalar(&zero)
}
@@ -83,7 +83,7 @@
}

fn lower_equal(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let tensor = self.sub(other);
let tensor = NdArrayBackend::<E>::sub(self, other);
let zero = E::zeros(&E::default());
tensor.lower_equal_scalar(&zero)
}
@@ -100,11 +100,3 @@
}
}
}

#[cfg(tests)]
mod tests {
use super::*;

#[test]
fn test_greater() {}
}
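
Every comparison in map_comparison.rs follows the same pattern, now routed through the backend-level sub: subtract the two tensors, then compare the result against zero. A dependency-free illustration of that equivalence on plain f64 slices (the names below are hypothetical, not the burn API):

// a > b elementwise is equivalent to (a - b) > 0, which is what
// greater() does above via NdArrayBackend::<E>::sub and greater_scalar.
fn greater(a: &[f64], b: &[f64]) -> Vec<bool> {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y) > 0.0) // subtract, then compare against zero
        .collect()
}

fn main() {
    let a = [1.0, 2.0, 3.0];
    let b = [2.0, 2.0, 1.0];
    assert_eq!(greater(&a, &b), vec![false, false, true]);
}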
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -16,5 +16,4 @@ mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
mod transpose;
43 changes: 0 additions & 43 deletions burn-tensor/src/tensor/backend/ndarray/ops/sub.rs

This file was deleted.

19 changes: 19 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
@@ -80,6 +80,25 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
let array = lhs.array.clone() + *rhs;
let shape = lhs.shape;

NdArrayTensor { array, shape }
}
fn sub<const D: usize>(
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() - rhs.array.clone();
let array = array.into_shared();
let shape = lhs.shape.higher(&rhs.shape);

NdArrayTensor { array, shape }
}
fn sub_scalar<const D: usize>(
lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
rhs: &E,
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
let array = lhs.array.clone() - *rhs;
let shape = lhs.shape;

NdArrayTensor { array, shape }
}
}
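
For reference, a standalone sketch of the elementwise and scalar subtraction that the NdArrayBackend methods above delegate to, written directly against the ndarray crate rather than the burn wrappers. Note that the real tensor-tensor implementation also computes the broadcast output shape via lhs.shape.higher(&rhs.shape), which this sketch omits:

use ndarray::arr1;

fn main() {
    let lhs = arr1(&[3.0_f64, 5.0, 7.0]);
    let rhs = arr1(&[1.0_f64, 2.0, 3.0]);

    // Tensor - tensor: elementwise subtraction, as in lhs.array.clone() - rhs.array.clone().
    let diff = lhs.clone() - rhs;
    assert_eq!(diff, arr1(&[2.0, 3.0, 4.0]));

    // Tensor - scalar: the scalar is broadcast across the array, as in lhs.array.clone() - *rhs.
    let shifted = lhs - 1.0;
    assert_eq!(shifted, arr1(&[2.0, 4.0, 6.0]));
}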
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/tch/ops/mod.rs
@@ -16,5 +16,4 @@ mod neg;
mod pow;
mod precision;
mod reshape;
mod sub;
mod transpose;
(The remaining changed file diffs are not shown.)
