From 4c2e80bb479dbf278de37432fadba88f925f1616 Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Sat, 5 Nov 2022 09:48:00 -0400
Subject: [PATCH 1/2] refactor: sub ops

---
 README.md                                          |  2 +-
 .../src/tensor/backend/autodiff/ops/mod.rs         |  1 -
 .../src/tensor/backend/autodiff/ops/tensor.rs      | 70 ++++++++++++++++++-
 burn-tensor/src/tensor/backend/base.rs             |  1 -
 .../backend/ndarray/ops/map_comparison.rs          | 18 ++---
 .../src/tensor/backend/ndarray/ops/mod.rs          |  1 -
 .../src/tensor/backend/ndarray/ops/sub.rs          | 43 ------------
 .../src/tensor/backend/ndarray/tensor_ops.rs       | 19 +++++
 burn-tensor/src/tensor/backend/tch/ops/mod.rs      |  1 -
 burn-tensor/src/tensor/backend/tch/ops/sub.rs      | 52 --------------
 .../src/tensor/backend/tch/tensor_ops.rs           | 26 ++++++-
 burn-tensor/src/tensor/base.rs                     |  4 +-
 burn-tensor/src/tensor/ops/base.rs                 | 13 ++--
 burn-tensor/tests/tensor/grad/mod.rs               |  1 +
 burn-tensor/tests/tensor/grad/sub.rs               | 57 +++++++++++++++
 burn-tensor/tests/tensor/mod.rs                    |  2 +
 16 files changed, 189 insertions(+), 122 deletions(-)
 delete mode 100644 burn-tensor/src/tensor/backend/ndarray/ops/sub.rs
 delete mode 100644 burn-tensor/src/tensor/backend/tch/ops/sub.rs
 create mode 100644 burn-tensor/tests/tensor/grad/sub.rs

diff --git a/README.md b/README.md
index 924b91ac20..b05821f821 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 [![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn)
 [![Test Status](https://github.com/burn-rs/burn/actions/workflows/test-burn.yml/badge.svg)](https://github.com/burn-rs/burn/actions/workflows/test-burn.yml)
 [![Documentation](https://docs.rs/burn/badge.svg)](https://docs.rs/burn)
-[![Rust Version](https://img.shields.io/badge/Rust-1.65.0-blue)](https://releases.rs/docs/unreleased/1.65.0)
+[![Rust Version](https://img.shields.io/badge/Rust-1.65.0-blue)](https://releases.rs/docs/released/1.65.0)
 [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn/blob/master/LICENSE)
 
 > This library aims to be a complete deep learning framework with extreme flexibility written in Rust.
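The refactor in this patch moves subtraction out of the per-op `TensorOpsSub` trait, which each tensor type implemented method-style, and onto the backend-wide `TensorOps` trait as associated functions. A minimal, self-contained sketch of the before/after shape, using simplified stand-in types rather than the crate's actual generics:

```rust
// Hypothetical, simplified stand-ins -- not burn-tensor's real definitions.

// Before: one trait per operation, implemented method-style on the tensor type.
#[allow(dead_code)]
trait TensorOpsSub {
    fn sub(&self, other: &Self) -> Self;
    fn sub_scalar(&self, other: &f32) -> Self;
}

// After: subtraction is an associated function on a single backend-wide trait,
// taking the tensor primitives as explicit arguments instead of `self`.
trait TensorOps {
    type Primitive;
    fn sub(lhs: &Self::Primitive, rhs: &Self::Primitive) -> Self::Primitive;
    fn sub_scalar(lhs: &Self::Primitive, rhs: &f32) -> Self::Primitive;
}

struct CpuBackend;

impl TensorOps for CpuBackend {
    type Primitive = Vec<f32>;

    fn sub(lhs: &Vec<f32>, rhs: &Vec<f32>) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(a, b)| a - b).collect()
    }

    fn sub_scalar(lhs: &Vec<f32>, rhs: &f32) -> Vec<f32> {
        lhs.iter().map(|a| a - rhs).collect()
    }
}

fn main() {
    // Same values as the new grad test: [2.0, 5.0] - [4.0, 1.0] == [-2.0, 4.0].
    let y = CpuBackend::sub(&vec![2.0, 5.0], &vec![4.0, 1.0]);
    assert_eq!(y, vec![-2.0, 4.0]);
}
```

The associated-function form is what lets `Tensor::sub` later in the patch delegate with `B::sub(&self.value, &other.value)` instead of calling a method on the primitive.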
diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
index b9d98fda75..514a63995c 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -18,7 +18,6 @@ mod neg;
 mod pow;
 mod precision;
 mod reshape;
-mod sub;
 mod tensor;
 mod transpose;

diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
index 4636e7d717..f95d17c6a2 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
@@ -5,7 +5,7 @@ use crate::{
         Backend,
     },
     graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
-    ops::TensorOps,
+    ops::{TensorOps, TensorOpsNeg},
     Data, Shape,
 };

@@ -65,6 +65,38 @@ impl<B: Backend, const D: usize>
     }
 }

+#[derive(Default, Debug)]
+struct SubBackward<B: Backend, const D: usize> {
+    _b: B,
+}
+
+impl<B: Backend, const D: usize>
+    BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>
+    for SubBackward<B, D>
+{
+    fn partial_left(
+        &self,
+        state: &BinaryOpsNodeState<
+            B::TensorPrimitive<D>,
+            B::TensorPrimitive<D>,
+            B::TensorPrimitive<D>,
+        >,
+    ) -> B::TensorPrimitive<D> {
+        state.output.grad()
+    }
+
+    fn partial_right(
+        &self,
+        state: &BinaryOpsNodeState<
+            B::TensorPrimitive<D>,
+            B::TensorPrimitive<D>,
+            B::TensorPrimitive<D>,
+        >,
+    ) -> B::TensorPrimitive<D> {
+        state.output.grad().neg()
+    }
+}
+
 #[derive(Default, Debug)]
 struct AddScalarBackward<B: Backend, const D: usize> {
     _b: B,
 }

@@ -81,6 +113,22 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
     }
 }

+#[derive(Default, Debug)]
+struct SubScalarBackward<B: Backend, const D: usize> {
+    _b: B,
+}
+
+impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
+    for SubScalarBackward<B, D>
+{
+    fn partial(
+        &self,
+        state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
+    ) -> B::TensorPrimitive<D> {
+        state.output.grad()
+    }
+}
+
 impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     fn shape<const D: usize>(
         tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,

@@ -162,4 +210,24 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

         unary_ops_wrapper(lhs.node.clone(), output, ops)
     }
+
+    fn sub<const D: usize>(
+        lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+        rhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        let output = B::sub(lhs.tensor_ref(), rhs.tensor_ref());
+        let ops = SubBackward::<B, D>::default();
+
+        binary_ops_wrapper(lhs.node.clone(), rhs.node.clone(), output, ops)
+    }
+
+    fn sub_scalar<const D: usize>(
+        lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+        rhs: &<ADBackendDecorator<B> as Backend>::Elem,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        let output = B::sub_scalar(lhs.tensor_ref(), rhs);
+        let ops = SubScalarBackward::<B, D>::default();
+
+        unary_ops_wrapper(lhs.node.clone(), output, ops)
+    }
 }
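`SubBackward` and `SubScalarBackward` above encode the derivative of subtraction for reverse mode: for `y = a - b`, `dy/da = 1`, so `partial_left` forwards the output gradient unchanged, and `dy/db = -1`, so `partial_right` negates it (hence the new `TensorOpsNeg` import). The same rule stripped of the crate's node/graph machinery, as a sketch with plain slices (hypothetical helper, not part of burn-tensor):

```rust
// Reverse-mode rule for y = a - b: grad_a = +grad_y, grad_b = -grad_y.
fn sub_backward(output_grad: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let grad_a = output_grad.to_vec();                       // dy/da = +1
    let grad_b = output_grad.iter().map(|g| -g).collect();   // dy/db = -1
    (grad_a, grad_b)
}

fn main() {
    // Matches the expectations in tests/tensor/grad/sub.rs for a seed
    // gradient of ones: grad_1 == [1, 1], grad_2 == [-1, -1].
    let (grad_a, grad_b) = sub_backward(&[1.0, 1.0]);
    assert_eq!(grad_a, vec![1.0, 1.0]);
    assert_eq!(grad_b, vec![-1.0, -1.0]);
}
```

For the scalar variant, `y = x - s` has `dy/dx = 1`, which is why `SubScalarBackward::partial` simply passes the output gradient through.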
diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs
index b09f09e081..c9c5045c13 100644
--- a/burn-tensor/src/tensor/backend/base.rs
+++ b/burn-tensor/src/tensor/backend/base.rs
@@ -27,7 +27,6 @@ pub trait Backend:
     + TensorOpsMul<Self::Elem, D>
     + TensorOpsDiv<Self::Elem, D>
     + TensorOpsNeg<Self::Elem, D>
-    + TensorOpsSub<Self::Elem, D>
     + TensorOpsDetach<Self::Elem, D>
     + Zeros<Self::TensorPrimitive<D>>
     + Ones<Self::TensorPrimitive<D>>

diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/map_comparison.rs b/burn-tensor/src/tensor/backend/ndarray/ops/map_comparison.rs
index cac469e043..26e2bc1825 100644
--- a/burn-tensor/src/tensor/backend/ndarray/ops/map_comparison.rs
+++ b/burn-tensor/src/tensor/backend/ndarray/ops/map_comparison.rs
@@ -8,7 +8,7 @@ where
     E: NdArrayElement,
 {
     fn equal(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
-        let tensor = self.sub(other);
+        let tensor = NdArrayBackend::<E>::sub(self, other);
         let zero = E::zeros(&E::default());
         tensor.equal_scalar(&zero)
     }
@@ -26,7 +26,7 @@ where
     }

     fn greater(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
-        let tensor = self.sub(other);
+        let tensor = NdArrayBackend::<E>::sub(self, other);
         let zero = E::zeros(&E::default());
         tensor.greater_scalar(&zero)
     }
@@ -47,7 +47,7 @@ where
         &self,
         other: &Self,
     ) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
-        let tensor = self.sub(other);
+        let tensor = NdArrayBackend::<E>::sub(self, other);
         let zero = E::zeros(&E::default());
         tensor.greater_equal_scalar(&zero)
     }
@@ -65,7 +65,7 @@ where
     }

     fn lower(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
-        let tensor = self.sub(other);
+        let tensor = NdArrayBackend::<E>::sub(self, other);
         let zero = E::zeros(&E::default());
         tensor.lower_scalar(&zero)
     }
@@ -83,7 +83,7 @@ where
     }

     fn lower_equal(&self, other: &Self) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
-        let tensor = self.sub(other);
+        let tensor = NdArrayBackend::<E>::sub(self, other);
         let zero = E::zeros(&E::default());
         tensor.lower_equal_scalar(&zero)
     }
@@ -100,11 +100,3 @@ where
         }
     }
 }
-
-#[cfg(tests)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_greater() {}
-}

diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
index 6c685f7aa3..ffe550952f 100644
--- a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
+++ b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -16,5 +16,4 @@ mod neg;
 mod pow;
 mod precision;
 mod reshape;
-mod sub;
 mod transpose;

diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/sub.rs b/burn-tensor/src/tensor/backend/ndarray/ops/sub.rs
deleted file mode 100644
index 90ebbf06e5..0000000000
--- a/burn-tensor/src/tensor/backend/ndarray/ops/sub.rs
+++ /dev/null
@@ -1,43 +0,0 @@
-use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*};
-use ndarray::{LinalgScalar, ScalarOperand};
-
-impl<P, const D: usize> TensorOpsSub<P, D> for NdArrayTensor<P, D>
-where
-    P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
-{
-    fn sub(&self, other: &Self) -> Self {
-        let array = self.array.clone() - other.array.clone();
-        let array = array.into_shared();
-        let shape = self.shape.higher(&other.shape);
-
-        Self { array, shape }
-    }
-    fn sub_scalar(&self, other: &P) -> Self {
-        let array = self.array.clone() - *other;
-        let shape = self.shape;
-
-        Self { array, shape }
-    }
-}
-
-impl<P, const D: usize> std::ops::Sub for NdArrayTensor<P, D>
-where
-    P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
-{
-    type Output = Self;
-
-    fn sub(self, rhs: Self) -> Self::Output {
-        TensorOpsSub::sub(&self, &rhs)
-    }
-}
-
-impl<P, const D: usize> std::ops::Sub<P>
-    for NdArrayTensor<P, D>
-where
-    P: Clone + LinalgScalar + Default + std::fmt::Debug + ScalarOperand,
-{
-    type Output = Self;
-
-    fn sub(self, rhs: P) -> Self::Output {
-        TensorOpsSub::sub_scalar(&self, &rhs)
-    }
-}

diff --git a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
index 2923005d85..1331023845 100644
--- a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
+++ b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
@@ -80,6 +80,25 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
         let array = lhs.array.clone() + *rhs;
         let shape = lhs.shape;

+        NdArrayTensor { array, shape }
+    }
+    fn sub<const D: usize>(
+        lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+        rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+    ) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
+        let array = lhs.array.clone() - rhs.array.clone();
+        let array = array.into_shared();
+        let shape = lhs.shape.higher(&rhs.shape);
+
+        NdArrayTensor { array, shape }
+    }
+    fn sub_scalar<const D: usize>(
+        lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+        rhs: &E,
+    ) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
+        let array = lhs.array.clone() - *rhs;
+        let shape = lhs.shape;
+
         NdArrayTensor { array, shape }
     }
 }
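Note how `map_comparison.rs` consumes the relocated `sub`: every comparison reduces to a subtraction followed by a compare-against-zero, i.e. `a > b` iff `a - b > 0` for finite values. A sketch of that identity with plain `f64` slices (hypothetical function, not the backend's API):

```rust
// The pattern behind map_comparison.rs: subtract, then compare to zero.
fn greater(lhs: &[f64], rhs: &[f64]) -> Vec<bool> {
    lhs.iter()
        .zip(rhs)
        .map(|(a, b)| (a - b) > 0.0) // a > b  <=>  a - b > 0 (finite values)
        .collect()
}

fn main() {
    assert_eq!(greater(&[2.0, 5.0], &[4.0, 1.0]), vec![false, true]);
}
```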
diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs
index 6c685f7aa3..ffe550952f 100644
--- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs
+++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs
@@ -16,5 +16,4 @@ mod neg;
 mod pow;
 mod precision;
 mod reshape;
-mod sub;
 mod transpose;

diff --git a/burn-tensor/src/tensor/backend/tch/ops/sub.rs b/burn-tensor/src/tensor/backend/tch/ops/sub.rs
deleted file mode 100644
index 17b1947826..0000000000
--- a/burn-tensor/src/tensor/backend/tch/ops/sub.rs
+++ /dev/null
@@ -1,52 +0,0 @@
-use crate::tensor::{backend::tch::TchTensor, ops::*, Data};
-use std::ops::Sub;
-
-impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorOpsSub<P, D>
-    for TchTensor<P, D>
-{
-    fn sub(&self, other: &Self) -> Self {
-        let tensor = (&self.tensor).sub(&other.tensor);
-        let kind = self.kind;
-        let shape = self.shape.higher(&other.shape);
-
-        Self {
-            tensor,
-            shape,
-            kind,
-        }
-    }
-    fn sub_scalar(&self, other: &P) -> Self {
-        let elems: [P; D] = [*other; D];
-        let data = Data::from(elems);
-        let other = TchTensor::from_data(data, self.tensor.device());
-        let tensor = (&self.tensor).sub(&other.tensor);
-        let kind = self.kind;
-        let shape = self.shape;
-
-        Self {
-            tensor,
-            shape,
-            kind,
-        }
-    }
-}
-
-impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> std::ops::Sub
-    for TchTensor<P, D>
-{
-    type Output = Self;
-
-    fn sub(self, rhs: Self) -> Self::Output {
-        TensorOpsSub::sub(&self, &rhs)
-    }
-}
-
-impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> std::ops::Sub<P>
-    for TchTensor<P, D>
-{
-    type Output = Self;
-
-    fn sub(self, rhs: P) -> Self::Output {
-        TensorOpsSub::sub_scalar(&self, &rhs)
-    }
-}

diff --git a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs
index 54dc58b88c..9287730003 100644
--- a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs
+++ b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs
@@ -1,4 +1,4 @@
-use std::ops::Add;
+use std::ops::{Add, Sub};

 use super::{TchBackend, TchDevice, TchKind, TchTensor};
 use crate::{backend::Backend, ops::TensorOps, Data, Shape, TchElement};
@@ -93,6 +93,30 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
         let kind = lhs.kind;
         let shape = lhs.shape;

+        TchTensor {
+            tensor,
+            shape,
+            kind,
+        }
+    }
+    fn sub<const D: usize>(lhs: &TchTensor<E, D>, rhs: &TchTensor<E, D>) -> TchTensor<E, D> {
+        let tensor = (&lhs.tensor).sub(&rhs.tensor);
+        let kind = lhs.kind;
+        let shape = lhs.shape.higher(&rhs.shape);
+
+        TchTensor {
+            tensor,
+            shape,
+            kind,
+        }
+    }
+
+    fn sub_scalar<const D: usize>(lhs: &TchTensor<E, D>, rhs: &E) -> TchTensor<E, D> {
+        let other: f64 = (rhs.clone()).to_elem();
+        let tensor = (&lhs.tensor).sub(other).to_kind(lhs.kind.kind());
+        let kind = lhs.kind;
+        let shape = lhs.shape;
+
         TchTensor {
             tensor,
             shape,

diff --git a/burn-tensor/src/tensor/base.rs b/burn-tensor/src/tensor/base.rs
index f7cbcfc10f..42c3f75eae 100644
--- a/burn-tensor/src/tensor/base.rs
+++ b/burn-tensor/src/tensor/base.rs
@@ -179,14 +179,14 @@ where
     ///
     /// `y = x2 - x1`
     pub fn sub(&self, other: &Self) -> Self {
-        Self::new(self.value.sub(&other.value))
+        Self::new(B::sub(&self.value, &other.value))
     }

     /// Applies element wise substraction operation with a scalar.
     ///
     /// `y = x - s`
     pub fn sub_scalar(&self, other: E) -> Self {
-        Self::new(self.value.sub_scalar(&other.to_elem()))
+        Self::new(B::sub_scalar(&self.value, &other.to_elem()))
     }

     /// Applies the transpose operation.
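One behavioral detail in the tch backend: the deleted `sub_scalar` materialized the scalar into a full tensor (`[*other; D]` into `Data` into `TchTensor`) before subtracting, while the relocated `TensorOps::sub_scalar` converts the scalar to `f64`, subtracts it directly, and re-applies the tensor's kind, since mixing in an `f64` scalar could otherwise promote the dtype. The trade-off, sketched with `Vec<f64>` stand-ins (hypothetical helpers, not the tch API):

```rust
// Old shape of the approach: build a whole tensor just to hold the scalar.
fn sub_scalar_materialized(lhs: &[f64], s: f64) -> Vec<f64> {
    let rhs = vec![s; lhs.len()]; // extra allocation the size of the input
    lhs.iter().zip(&rhs).map(|(a, b)| a - b).collect()
}

// New shape: subtract the scalar directly, no intermediate tensor.
fn sub_scalar_direct(lhs: &[f64], s: f64) -> Vec<f64> {
    lhs.iter().map(|a| a - s).collect()
}

fn main() {
    let x = [2.0, 10.0];
    // Both agree with should_diff_sub_scalar's expected output [-3.0, 5.0].
    assert_eq!(sub_scalar_materialized(&x, 5.0), sub_scalar_direct(&x, 5.0));
    assert_eq!(sub_scalar_direct(&x, 5.0), vec![-3.0, 5.0]);
}
```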
diff --git a/burn-tensor/src/tensor/ops/base.rs b/burn-tensor/src/tensor/ops/base.rs
index c3a0a0a3a6..0c029eda07 100644
--- a/burn-tensor/src/tensor/ops/base.rs
+++ b/burn-tensor/src/tensor/ops/base.rs
@@ -73,11 +73,14 @@ pub trait TensorOps<B: Backend> {
         lhs: &B::TensorPrimitive<D>,
         rhs: &B::Elem,
     ) -> B::TensorPrimitive<D>;
-}
-
-pub trait TensorOpsSub<E, const D: usize> {
-    fn sub(&self, other: &Self) -> Self;
-    fn sub_scalar(&self, other: &E) -> Self;
+    fn sub<const D: usize>(
+        lhs: &B::TensorPrimitive<D>,
+        rhs: &B::TensorPrimitive<D>,
+    ) -> B::TensorPrimitive<D>;
+    fn sub_scalar<const D: usize>(
+        lhs: &B::TensorPrimitive<D>,
+        rhs: &B::Elem,
+    ) -> B::TensorPrimitive<D>;
 }

 pub trait TensorOpsTranspose<E, const D: usize> {

diff --git a/burn-tensor/tests/tensor/grad/mod.rs b/burn-tensor/tests/tensor/grad/mod.rs
index d8adb58f11..f200bc1ccf 100644
--- a/burn-tensor/tests/tensor/grad/mod.rs
+++ b/burn-tensor/tests/tensor/grad/mod.rs
@@ -3,3 +3,4 @@ mod aggregation;
 mod cross_entropy;
 mod div;
 mod softmax;
+mod sub;

diff --git a/burn-tensor/tests/tensor/grad/sub.rs b/burn-tensor/tests/tensor/grad/sub.rs
new file mode 100644
index 0000000000..4350ccbf82
--- /dev/null
+++ b/burn-tensor/tests/tensor/grad/sub.rs
@@ -0,0 +1,57 @@
+use crate::tensor::TestADTensor;
+use burn_tensor::Data;
+
+#[test]
+fn should_diff_sub() {
+    let data_1 = Data::from([2.0, 5.0]);
+    let data_2 = Data::from([4.0, 1.0]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+
+    let tensor_3 = tensor_1.sub(&tensor_2);
+    let grads = tensor_3.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
+    assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0]));
+    assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
+}
+
+#[test]
+fn should_diff_sub_scalar() {
+    let data = Data::from([2.0, 10.0]);
+    let tensor = TestADTensor::from_data(data);
+    let tensor_out = tensor.sub_scalar(5.0);
+    let grads = tensor_out.backward();
+
+    let grad = tensor.grad(&grads).unwrap();
+
+    assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
+    assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
+}
+
+#[test]
+fn test_sub_complex_1() {
+    let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
+    let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+    let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+    let tensor_3 = TestADTensor::from_data(data_3);
+
+    let tensor_4 = tensor_1.sub(&tensor_2);
+    let tensor_5 = tensor_4.sub(&tensor_3).sub_scalar(5.0);
+    let tensor_6 = tensor_1.sub(&tensor_5);
+
+    let grads = tensor_6.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
+    assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
+}

diff --git a/burn-tensor/tests/tensor/mod.rs b/burn-tensor/tests/tensor/mod.rs
index 39711ecb0e..43039e5e18 100644
--- a/burn-tensor/tests/tensor/mod.rs
+++ b/burn-tensor/tests/tensor/mod.rs
@@ -10,6 +10,8 @@ pub type TestADBackend = burn_tensor::backend::NdArrayADBackend<f32>;
 #[cfg(all(feature = "tch", not(any(feature = "ndarray"))))]
 pub type TestADBackend = burn_tensor::backend::TchADBackend<f32>;

+pub type TestADTensor<const D: usize> = burn_tensor::Tensor<TestADBackend, D>;
+
 mod activation;
 mod grad;
 mod module;
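Why `test_sub_complex_1` expects a zero gradient for `tensor_1`: with `t4 = t1 - t2`, `t5 = t4 - t3 - 5`, and `t6 = t1 - t5`, the value `t1` enters once with sign +1 and once with sign -1, so `dt6/dt1 = 1 - 1 = 0`, while `dt6/dt2 = -(-1) = 1`. A quick finite-difference check of those partials, independent of the crate, using the first element of each test tensor:

```rust
// Finite-difference check of the partials asserted in test_sub_complex_1.
fn t6(t1: f64, t2: f64, t3: f64) -> f64 {
    let t4 = t1 - t2;
    let t5 = t4 - t3 - 5.0;
    t1 - t5 // algebraically t2 + t3 + 5: t1 cancels out entirely
}

fn main() {
    let (t1, t2, t3, h) = (1.0, 4.0, 2.0, 1e-6);
    let d_t1 = (t6(t1 + h, t2, t3) - t6(t1 - h, t2, t3)) / (2.0 * h);
    let d_t2 = (t6(t1, t2 + h, t3) - t6(t1, t2 - h, t3)) / (2.0 * h);
    assert!((d_t1 - 0.0).abs() < 1e-6); // matches grad_1 == [[0, 0], [0, 0]]
    assert!((d_t2 - 1.0).abs() < 1e-6); // matches grad_2 == [[1, 1], [1, 1]]
}
```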
From 051134c934c3f1a843a5e2ddc1439c1d0a4e244a Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Sat, 5 Nov 2022 09:51:11 -0400
Subject: [PATCH 2/2] Remove unused file

---
 .../src/tensor/backend/autodiff/ops/sub.rs         | 104 ------------------
 1 file changed, 104 deletions(-)
 delete mode 100644 burn-tensor/src/tensor/backend/autodiff/ops/sub.rs

diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs b/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs
deleted file mode 100644
index e74c6489dc..0000000000
--- a/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs
+++ /dev/null
@@ -1,104 +0,0 @@
-use crate::graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState};
-use crate::tensor::backend::Backend;
-use crate::{
-    execute_ops, register_ops,
-    tensor::{backend::autodiff::ADTensor, ops::*},
-};
-
-register_ops!(
-    ops BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
-    name ADTensorSubOps,
-    partial_left |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
-        state.output.grad()
-    },
-    partial_right |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
-        state.output.grad().neg()
-    },
-);
-
-register_ops!(
-    ops UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
-    name ADTensorSubScalarOps state B::Elem,
-    partial |_state, state_recorded: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
-        state_recorded.output.grad()
-    },
-);
-
-impl<B: Backend, const D: usize> TensorOpsSub<B::Elem, D> for ADTensor<D, B> {
-    fn sub(&self, other: &Self) -> Self {
-        execute_ops!(
-            lhs self.node.clone(),
-            rhs other.node.clone(),
-            out TensorOpsSub::sub(&self.tensor(), &other.tensor()),
-            ops ADTensorSubOps::<B, D>::new(),
-        )
-    }
-
-    fn sub_scalar(&self, other: &B::Elem) -> Self {
-        execute_ops!(
-            input self.node.clone(),
-            out TensorOpsSub::sub_scalar(&self.tensor(), other),
-            ops ADTensorSubScalarOps::<B, D>::new(*other),
-        )
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};
-
-    #[test]
-    fn should_diff_sub() {
-        let data_1 = Data::from([2.0, 5.0]);
-        let data_2 = Data::from([4.0, 1.0]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-
-        let tensor_3 = tensor_1.sub(&tensor_2);
-        let grads = tensor_3.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(grad_1.to_data(), Data::from([1.0, 1.0]));
-        assert_eq!(grad_2.to_data(), Data::from([-1.0, -1.0]));
-        assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
-    }
-
-    #[test]
-    fn should_diff_sub_scalar() {
-        let data = Data::from([2.0, 10.0]);
-        let tensor = TestADTensor::from_data(data);
-        let tensor_out = tensor.sub_scalar(5.0);
-        let grads = tensor_out.backward();
-
-        let grad = tensor.grad(&grads).unwrap();
-
-        assert_eq!(grad.to_data(), Data::from([1.0, 1.0]));
-        assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
-    }
-
-    #[test]
-    fn test_sub_complex_1() {
-        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
-        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
-        let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-        let tensor_3 = TestADTensor::from_data(data_3);
-
-        let tensor_4 = tensor_1.sub(&tensor_2);
-        let tensor_5 = tensor_4.sub(&tensor_3).sub_scalar(5.0);
-        let tensor_6 = tensor_1.sub(&tensor_5);
-
-        let grads = tensor_6.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(grad_1.to_data(), Data::from([[0.0, 0.0], [0.0, 0.0]]));
-        assert_eq!(grad_2.to_data(), Data::from([[1.0, 1.0], [1.0, 1.0]]));
-    }
-}