Perf/tensor ops/more tests (tracel-ai#718)
louisfd authored Aug 30, 2023
1 parent bb71f17 commit 7c34e21
Showing 20 changed files with 444 additions and 71 deletions.
burn-autodiff/src/tests/gradients.rs (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ mod tests {
         let x = tensor_1.clone().matmul(activation::gelu(tensor_2));
         let mut grads = x.backward();
 
-        let grad_1 = tensor_1.grad(&mut grads).unwrap();
+        let grad_1 = tensor_1.grad(&grads).unwrap();
 
         let grad_1_updated = TestADTensor::random([32, 32], Distribution::Default).require_grad();
         tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner());
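Worth noting about the one-line fix above: Tensor::grad only reads from the gradient container, so a shared borrow is enough, while grad_replace two lines below genuinely mutates it and rightly keeps &mut. A minimal sketch of that borrow split, using a hypothetical HashMap-backed stand-in for burn's Gradients (not the real implementation):

    use std::collections::HashMap;

    // Hypothetical stand-in for burn's `Gradients` container, keyed by tensor id.
    struct Gradients(HashMap<u64, Vec<f32>>);

    impl Gradients {
        // Reading a gradient is a lookup: a shared borrow suffices.
        fn get(&self, id: u64) -> Option<&Vec<f32>> {
            self.0.get(&id)
        }
        // Replacing a gradient mutates the container: this needs `&mut`.
        fn replace(&mut self, id: u64, grad: Vec<f32>) {
            self.0.insert(id, grad);
        }
    }

    fn main() {
        let mut grads = Gradients(HashMap::new());
        grads.replace(0, vec![1.0, 2.0]);
        let g1 = grads.get(0); // `&grads` is enough here,
        let g2 = grads.get(0); // and shared reads can coexist
        assert_eq!(g1, g2);
    }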
burn-candle/src/lib.rs (6 changes: 3 additions & 3 deletions)
@@ -49,21 +49,21 @@ mod tests {

     // test ops
     burn_tensor::testgen_add!();
-    burn_tensor::testgen_aggregation!();
+    // burn_tensor::testgen_aggregation!();
     burn_tensor::testgen_arange!();
     burn_tensor::testgen_arange_step!();
     burn_tensor::testgen_arg!();
     burn_tensor::testgen_cast!();
     burn_tensor::testgen_cat!();
     burn_tensor::testgen_clamp!();
     burn_tensor::testgen_cos!();
-    burn_tensor::testgen_div!();
-    burn_tensor::testgen_empty!();
+    // burn_tensor::testgen_div!();
+    // burn_tensor::testgen_erf!();
     burn_tensor::testgen_exp!();
     burn_tensor::testgen_flatten!();
     burn_tensor::testgen_full!();
     burn_tensor::testgen_gather_scatter!();
     burn_tensor::testgen_init!();
     burn_tensor::testgen_log!();
     burn_tensor::testgen_log1p!();
     burn_tensor::testgen_map_comparison!();
burn-candle/src/ops/int_tensor.rs (26 changes: 21 additions & 5 deletions)
@@ -73,7 +73,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<CandleBackend<F, I
     ) -> IntTensor<Self, D> {
         CandleTensor::new(
             mask.tensor
-                .where_cond(&tensor.tensor, &tensor.tensor)
+                .where_cond(&source.tensor, &tensor.tensor)
                 .unwrap(),
         )
     }
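This was a real bug rather than a style fix: where_cond selects its first argument wherever the mask is set, so passing tensor.tensor twice made int_mask_where return its input unchanged. A plain-Rust sketch of the intended elementwise semantics (illustrative only, not the Candle API):

    // Elementwise mask_where: take `source` where the mask is set, else keep `tensor`.
    fn mask_where(mask: &[bool], source: &[i64], tensor: &[i64]) -> Vec<i64> {
        mask.iter()
            .zip(source.iter().zip(tensor))
            .map(|(&m, (&s, &t))| if m { s } else { t })
            .collect()
    }

    fn main() {
        let out = mask_where(&[true, false, true], &[10, 20, 30], &[1, 2, 3]);
        assert_eq!(out, vec![10, 2, 30]); // the buggy version returned [1, 2, 3]
    }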
@@ -284,7 +284,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<CandleBackend<F, I
         lhs: IntTensor<Self, D>,
         rhs: IntElem<Self>,
     ) -> IntTensor<Self, D> {
-        CandleTensor::new((lhs.tensor / rhs.elem::<f64>()).unwrap())
+        // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
+        panic!("Not supported by Candle")
     }
 
     fn int_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
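The comment is the whole story: if scalar division is lowered to multiplication by a reciprocal, integer truncation turns that reciprocal into zero for any divisor whose absolute value exceeds one. A quick illustration in plain Rust:

    fn main() {
        let (a, b): (i64, i64) = (10, 3);

        // Direct integer division truncates, but is correct:
        assert_eq!(a / b, 3);

        // Division lowered to multiplication by a reciprocal, as the comment
        // above describes: 1/3 truncates to 0 in integer arithmetic...
        let reciprocal = 1 / b;
        assert_eq!(reciprocal, 0);

        // ...so the product is 0 no matter what `a` is.
        assert_eq!(a * reciprocal, 0);
    }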
@@ -312,15 +313,30 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<CandleBackend<F, I
     }
 
     fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap())
+        // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
+        panic!("Not supported by Candle")
     }
 
     fn int_argmax<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        CandleTensor::new(tensor.tensor.argmax_keepdim(dim).unwrap())
+        CandleTensor::new(
+            tensor
+                .tensor
+                .argmax_keepdim(dim)
+                .unwrap()
+                .to_dtype(I::DTYPE)
+                .unwrap(),
+        )
     }
 
     fn int_argmin<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        CandleTensor::new(tensor.tensor.argmin_keepdim(dim).unwrap())
+        CandleTensor::new(
+            tensor
+                .tensor
+                .argmin_keepdim(dim)
+                .unwrap()
+                .to_dtype(I::DTYPE)
+                .unwrap(),
+        )
     }
 
     fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
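The added to_dtype(I::DTYPE) matters because Candle's argmax_keepdim/argmin_keepdim return indices in Candle's own index dtype (u32, if memory serves; treat that as an assumption), which need not match the backend's integer element type I. A plain-Rust sketch of the same cast-the-index pattern, with a hypothetical helper rather than the burn API:

    // Argmax over a slice yields a positional index (usize here); the tensor's
    // integer element type may differ, so the index is converted before it is
    // returned, mirroring the `to_dtype(I::DTYPE)` in the patch.
    fn argmax_as<I: TryFrom<usize>>(xs: &[i32]) -> Option<I> {
        let (idx, _) = xs.iter().enumerate().max_by_key(|&(_, &v)| v)?;
        I::try_from(idx).ok()
    }

    fn main() {
        let idx: i64 = argmax_as(&[3, 9, 1]).unwrap();
        assert_eq!(idx, 1); // index converted into the requested integer type
    }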
burn-candle/src/ops/tensor.rs (4 changes: 3 additions & 1 deletion)
@@ -34,7 +34,9 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<CandleBackend<F, I>>
             Distribution::Bernoulli(prob) => CandleTensor::new(
                 candle_core::Tensor::rand(0., 1., shape, device)
                     .unwrap()
-                    .gt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))
+                    .to_dtype(F::DTYPE)
+                    .unwrap()
+                    .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))
                     .unwrap()
                     .to_dtype(F::DTYPE)
                     .unwrap(),
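Two things change in the Bernoulli path: the uniform sample is cast to F::DTYPE before comparing, and gt becomes lt. The comparison direction is the semantic fix: for u ~ Uniform[0, 1), P(u < p) = p, so lt yields Bernoulli(p) while the old gt yielded Bernoulli(1 - p). A minimal sketch of the check:

    // Bernoulli(p) from a uniform draw: P(u < p) = p for u ~ Uniform[0, 1).
    fn bernoulli_sample(u: f32, p: f32) -> f32 {
        if u < p { 1.0 } else { 0.0 } // `lt`, as in the fix; `gt` would sample 1 - p
    }

    fn main() {
        assert_eq!(bernoulli_sample(0.2, 0.5), 1.0);
        assert_eq!(bernoulli_sample(0.9, 0.5), 0.0);
    }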
burn-import/onnx-tests/tests/onnx_tests.rs (50 changes: 32 additions & 18 deletions)
@@ -42,6 +42,8 @@ include_models!(

 #[cfg(test)]
 mod tests {
+    use std::f64::consts;
+
     use super::*;
 
     use burn::tensor::{Data, Int, Shape, Tensor};
@@ -154,8 +156,8 @@ mod tests {
         // Initialize the model with weights (loaded from the exported file)
         let model: conv1d::Model<Backend> = conv1d::Model::default();
 
-        // Run the model with 3.1416 as input for easier testing
-        let input = Tensor::<Backend, 3>::full([6, 4, 10], 3.1416);
+        // Run the model with pi as input for easier testing
+        let input = Tensor::<Backend, 3>::full([6, 4, 10], consts::PI);
 
         let output = model.forward(input);
 
@@ -460,18 +462,24 @@

         // Run the model
         let input = Tensor::<Backend, 1>::from_floats([
-            0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
+            0.88226926,
+            0.91500396,
+            0.38286376,
+            0.95930564,
+            0.390_448_2,
+            0.60089535,
         ]);
         let (output1, output2, output3) = model.forward(input);
         let expected1 = Data::from([
-            0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
-        ]);
-        let expected2 = Data::from([
-            0.69999999, 0.69999999, 0.50000000, 0.69999999, 0.50000000, 0.60089535,
-        ]);
-        let expected3 = Data::from([
-            0.80000001, 0.80000001, 0.38286376, 0.80000001, 0.39044821, 0.60089535,
+            0.88226926,
+            0.91500396,
+            0.38286376,
+            0.95930564,
+            0.390_448_2,
+            0.60089535,
         ]);
+        let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]);
+        let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]);
 
         assert_eq!(output1.to_data(), expected1);
         assert_eq!(output2.to_data(), expected2);
@@ -485,18 +493,24 @@

         // Run the model
         let input = Tensor::<Backend, 1>::from_floats([
-            0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
+            0.88226926,
+            0.91500396,
+            0.38286376,
+            0.95930564,
+            0.390_448_2,
+            0.60089535,
         ]);
         let (output1, output2, output3) = model.forward(input);
         let expected1 = Data::from([
-            0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
-        ]);
-        let expected2 = Data::from([
-            0.69999999, 0.69999999, 0.50000000, 0.69999999, 0.50000000, 0.60089535,
-        ]);
-        let expected3 = Data::from([
-            0.80000001, 0.80000001, 0.38286376, 0.80000001, 0.39044821, 0.60089535,
+            0.88226926,
+            0.91500396,
+            0.38286376,
+            0.95930564,
+            0.390_448_2,
+            0.60089535,
         ]);
+        let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]);
+        let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]);
 
         assert_eq!(output1.to_data(), expected1);
         assert_eq!(output2.to_data(), expected2);
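The literal rewrites in these two clip tests (0.69999999 to 0.7, 0.39044821 to 0.390_448_2) are cosmetic: an f32 cannot tell the long and short spellings apart, which is exactly what clippy's excessive_precision lint flags. A quick check:

    fn main() {
        // f32 rounds each pair of literals to the same bits, so the shorter
        // spelling is purely cosmetic (clippy: excessive_precision).
        assert_eq!(0.69999999f32, 0.7f32);
        assert_eq!(0.80000001f32, 0.8f32);
        assert_eq!(0.39044821f32, 0.390_448_2f32);
    }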
burn-tch/src/ops/int_tensor.rs (26 changes: 19 additions & 7 deletions)
@@ -195,10 +195,13 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
     }
 
     fn int_div_scalar<const D: usize>(lhs: TchTensor<i64, D>, rhs: i64) -> TchTensor<i64, D> {
-        lhs.unary_ops(
+        let lhs: TchTensor<f64, D> =
+            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false));
+        let output: TchTensor<i64, D> = lhs.unary_ops(
             |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
             |tensor| tensor.f_div_scalar(rhs).unwrap(),
-        )
+        );
+        TchTensor::<i64, D>::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
     }
 
     fn int_neg<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, D> {
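The tch workaround routes integer scalar division through f64 and casts back; the final to_dtype(Kind::Int64) truncates the quotient toward zero (an assumption about tch/PyTorch integer casting, matching its documented behavior). A plain-Rust sketch of the round trip:

    // Integer scalar division via a float round trip, like the patched
    // int_div_scalar: cast to f64, divide, cast back. `as i64` truncates
    // toward zero, which is the behavior assumed of the final dtype cast.
    fn int_div_scalar(lhs: &[i64], rhs: i64) -> Vec<i64> {
        lhs.iter()
            .map(|&x| (x as f64 / rhs as f64) as i64)
            .collect()
    }

    fn main() {
        assert_eq!(int_div_scalar(&[7, -7, 9], 2), vec![3, -3, 4]);
    }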
@@ -249,11 +252,20 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
     }
 
     fn int_mean<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
-        TchOps::mean(tensor)
+        let tensor: TchTensor<f64, D> =
+            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
+        let output: TchTensor<i64, 1> = TchTensor::new(TchOps::mean(tensor).tensor);
+
+        TchTensor::<i64, 1>::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
     }
 
     fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
-        TchOps::mean_dim(tensor, dim)
+        let tensor: TchTensor<f64, D> =
+            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
+
+        let output: TchTensor<i64, D> = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
+
+        TchTensor::<i64, D>::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
     }
 
     fn int_gather<const D: usize>(
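int_mean and int_mean_dim get the same treatment: accumulate the mean in f64, then cast the result back to i64. This lines up with the new aggregation tests further down, where the mean of [2, 2, 2, 3, 4, 5] comes out as 18 / 6 = 3. A sketch:

    // Integer mean via f64 accumulation, mirroring the patched int_mean.
    fn int_mean(xs: &[i64]) -> i64 {
        let sum: i64 = xs.iter().sum();
        (sum as f64 / xs.len() as f64) as i64
    }

    fn main() {
        assert_eq!(int_mean(&[2, 2, 2, 3, 4, 5]), 3); // 18 / 6, as in the new test
        assert_eq!(int_mean(&[0, 1, 2]), 1);          // 3 / 3, as in the mean_dim tests
    }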
@@ -298,9 +310,9 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
         TchTensor::binary_ops_tensor(
             tensor,
             source,
-            |tensor, source| tensor.f_masked_scatter_(&mask.tensor, source).unwrap(),
-            |tensor, source| tensor.f_masked_scatter(&mask.tensor, source).unwrap(),
-            |tensor, source| tensor.f_masked_scatter(&mask.tensor, source).unwrap(),
+            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
+            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
+            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
         )
     }
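The closure swap changes semantics, not just mutability: torch's masked_scatter_ fills masked slots with consecutive elements drawn from source, while where_self selects elementwise at the same position, which is what burn's mask_where specifies (and what the Candle where_cond fix above does). A plain-Rust contrast of the two behaviors (simplified 1-D semantics, not the tch API):

    // masked_scatter-style: masked slots consume `source` values in order.
    fn masked_scatter(mask: &[bool], tensor: &[i64], source: &[i64]) -> Vec<i64> {
        let mut src = source.iter();
        mask.iter()
            .zip(tensor)
            .map(|(&m, &t)| if m { *src.next().unwrap() } else { t })
            .collect()
    }

    // where_self-style: masked slots take `source` at the same position.
    fn mask_where(mask: &[bool], tensor: &[i64], source: &[i64]) -> Vec<i64> {
        mask.iter()
            .zip(tensor.iter().zip(source))
            .map(|(&m, (&t, &s))| if m { s } else { t })
            .collect()
    }

    fn main() {
        let mask = [false, true, true];
        let (t, s) = ([1, 2, 3], [10, 20, 30]);
        assert_eq!(masked_scatter(&mask, &t, &s), vec![1, 10, 20]); // consumes 10, then 20
        assert_eq!(mask_where(&mask, &t, &s), vec![1, 20, 30]);     // positional select
    }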

burn-tensor/src/tests/mod.rs (2 changes: 1 addition & 1 deletion)
@@ -39,12 +39,12 @@ macro_rules! testgen_all {
     burn_tensor::testgen_clamp!();
     burn_tensor::testgen_cos!();
     burn_tensor::testgen_div!();
-    burn_tensor::testgen_empty!();
+    burn_tensor::testgen_erf!();
     burn_tensor::testgen_exp!();
     burn_tensor::testgen_flatten!();
     burn_tensor::testgen_full!();
     burn_tensor::testgen_gather_scatter!();
     burn_tensor::testgen_init!();
     burn_tensor::testgen_log!();
     burn_tensor::testgen_log1p!();
     burn_tensor::testgen_map_comparison!();
burn-tensor/src/tests/ops/add.rs (41 changes: 40 additions & 1 deletion)
@@ -1,7 +1,7 @@
 #[burn_tensor_testgen::testgen(add)]
 mod tests {
     use super::*;
-    use burn_tensor::{Data, Tensor};
+    use burn_tensor::{Data, Int, Tensor};
 
     #[test]
     fn test_add_d2() {
@@ -41,4 +41,43 @@ mod tests {
         let data_expected = Data::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]);
         assert_eq!(data_expected, data_actual);
     }
+
+    #[test]
+    fn test_add_d2_int() {
+        let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]);
+        let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]);
+        let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data(data_1);
+        let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data(data_2);
+
+        let data_actual = (tensor_1 + tensor_2).into_data();
+
+        let data_expected = Data::from([[6, 8, 10], [12, 14, 16]]);
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn test_add_broadcast_int() {
+        let data_1 = Data::from([[0, 1, 2]]);
+        let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]);
+        let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data(data_1);
+        let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data(data_2);
+
+        let data_actual = (tensor_1 + tensor_2).into_data();
+
+        let data_expected = Data::from([[3, 5, 7], [6, 8, 10]]);
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_add_scalar_ops_int() {
+        let data = Data::from([[0, 1, 2], [3, 4, 5]]);
+        let scalar = 2;
+        let tensor = Tensor::<TestBackend, 2, Int>::from_data(data);
+
+        let output = tensor + scalar;
+
+        let data_actual = output.into_data();
+        let data_expected = Data::from([[2, 3, 4], [5, 6, 7]]);
+        assert_eq!(data_expected, data_actual);
+    }
 }
burn-tensor/src/tests/ops/aggregation.rs (36 changes: 36 additions & 0 deletions)
@@ -12,6 +12,15 @@ mod tests {
         assert_eq!(data_actual, Data::from([15.0 / 6.0]));
     }
 
+    #[test]
+    fn test_should_mean_int() {
+        let tensor = TestTensorInt::from_data([[2, 2, 2], [3, 4, 5]]);
+
+        let data_actual = tensor.mean().to_data();
+
+        assert_eq!(data_actual, Data::from([3]));
+    }
+
     #[test]
     fn test_should_sum() {
         let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
@@ -21,6 +30,15 @@
         assert_eq!(data_actual, Data::from([15.0]));
     }
 
+    #[test]
+    fn test_should_sum_int() {
+        let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]);
+
+        let data_actual = tensor.sum().to_data();
+
+        assert_eq!(data_actual, Data::from([15]));
+    }
+
     #[test]
     fn test_should_mean_last_dim() {
         let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
@@ -39,6 +57,24 @@
assert_eq!(data_actual, Data::from([[3.0], [12.0]]));
}

#[test]
fn test_should_mean_last_dim_int() {
let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]);

let data_actual = tensor.mean_dim(1).to_data();

assert_eq!(data_actual, Data::from([[1], [4]]));
}

#[test]
fn test_should_sum_last_dim_int() {
let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]);

let data_actual = tensor.sum_dim(1).to_data();

assert_eq!(data_actual, Data::from([[3], [12]]));
}

#[test]
fn test_should_sum_first_dim() {
let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);
Expand Down
burn-tensor/src/tests/ops/arg.rs (24 changes: 23 additions & 1 deletion)
@@ -1,7 +1,7 @@
 #[burn_tensor_testgen::testgen(arg)]
 mod tests {
     use super::*;
-    use burn_tensor::{Data, Tensor};
+    use burn_tensor::{Data, Int, Tensor};
 
     #[test]
     fn test_argmax_2d_dim0() {
@@ -25,6 +25,28 @@
         assert_eq!(data_expected, data_actual.to_data());
     }
 
+    #[test]
+    fn test_argmax_2d_dim0_int() {
+        let data = Data::from([[10, 11, 2], [3, 4, 5]]);
+        let tensor = Tensor::<TestBackend, 2, Int>::from_data(data);
+
+        let data_actual = tensor.argmax(0);
+
+        let data_expected = Data::from([[0, 0, 1]]);
+        assert_eq!(data_expected, data_actual.to_data());
+    }
+
+    #[test]
+    fn test_argmin_2d_dim0_int() {
+        let data = Data::from([[10, 11, 2], [30, 4, 5]]);
+        let tensor = Tensor::<TestBackend, 2, Int>::from_data(data);
+
+        let data_actual = tensor.argmin(0);
+
+        let data_expected = Data::from([[0, 1, 0]]);
+        assert_eq!(data_expected, data_actual.to_data());
+    }
+
     #[test]
     fn test_argmax_2d_dim1() {
         let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
[Diff truncated: the remaining changed files are not shown here.]
