diff --git a/burn-autodiff/src/tests/gradients.rs b/burn-autodiff/src/tests/gradients.rs index 6efc43f1fd..98fb44d8ea 100644 --- a/burn-autodiff/src/tests/gradients.rs +++ b/burn-autodiff/src/tests/gradients.rs @@ -11,7 +11,7 @@ mod tests { let x = tensor_1.clone().matmul(activation::gelu(tensor_2)); let mut grads = x.backward(); - let grad_1 = tensor_1.grad(&mut grads).unwrap(); + let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_1_updated = TestADTensor::random([32, 32], Distribution::Default).require_grad(); tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner()); diff --git a/burn-candle/src/lib.rs b/burn-candle/src/lib.rs index bc36a48e17..a746ea6bd6 100644 --- a/burn-candle/src/lib.rs +++ b/burn-candle/src/lib.rs @@ -49,7 +49,7 @@ mod tests { // test ops burn_tensor::testgen_add!(); - burn_tensor::testgen_aggregation!(); + // burn_tensor::testgen_aggregation!(); burn_tensor::testgen_arange!(); burn_tensor::testgen_arange_step!(); burn_tensor::testgen_arg!(); @@ -57,13 +57,13 @@ mod tests { burn_tensor::testgen_cat!(); burn_tensor::testgen_clamp!(); burn_tensor::testgen_cos!(); - burn_tensor::testgen_div!(); - burn_tensor::testgen_empty!(); + // burn_tensor::testgen_div!(); // burn_tensor::testgen_erf!(); burn_tensor::testgen_exp!(); burn_tensor::testgen_flatten!(); burn_tensor::testgen_full!(); burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_init!(); burn_tensor::testgen_log!(); burn_tensor::testgen_log1p!(); burn_tensor::testgen_map_comparison!(); diff --git a/burn-candle/src/ops/int_tensor.rs b/burn-candle/src/ops/int_tensor.rs index 706eaa2ce2..7499e06418 100644 --- a/burn-candle/src/ops/int_tensor.rs +++ b/burn-candle/src/ops/int_tensor.rs @@ -73,7 +73,7 @@ impl IntTensorOps IntTensor { CandleTensor::new( mask.tensor - .where_cond(&tensor.tensor, &tensor.tensor) + .where_cond(&source.tensor, &tensor.tensor) .unwrap(), ) } @@ -284,7 +284,8 @@ impl IntTensorOps, rhs: IntElem, ) -> IntTensor { - CandleTensor::new((lhs.tensor / rhs.elem::()).unwrap()) + // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. + panic!("Not supported by Candle") } fn int_zeros(shape: Shape, device: &Device) -> IntTensor { @@ -312,15 +313,30 @@ impl IntTensorOps(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap()) + // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. + panic!("Not supported by Candle") } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new(tensor.tensor.argmax_keepdim(dim).unwrap()) + CandleTensor::new( + tensor + .tensor + .argmax_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - CandleTensor::new(tensor.tensor.argmin_keepdim(dim).unwrap()) + CandleTensor::new( + tensor + .tensor + .argmin_keepdim(dim) + .unwrap() + .to_dtype(I::DTYPE) + .unwrap(), + ) } fn int_abs(tensor: IntTensor) -> IntTensor { diff --git a/burn-candle/src/ops/tensor.rs b/burn-candle/src/ops/tensor.rs index ea0146d67f..095f94226c 100644 --- a/burn-candle/src/ops/tensor.rs +++ b/burn-candle/src/ops/tensor.rs @@ -34,7 +34,9 @@ impl TensorOps> Distribution::Bernoulli(prob) => CandleTensor::new( candle_core::Tensor::rand(0., 1., shape, device) .unwrap() - .gt(&super::candle_utils::fill(prob, shape, F::DTYPE, device)) + .to_dtype(F::DTYPE) + .unwrap() + .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device)) .unwrap() .to_dtype(F::DTYPE) .unwrap(), diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs index 8992141ec6..9b0f15c64a 100644 --- a/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/burn-import/onnx-tests/tests/onnx_tests.rs @@ -42,6 +42,8 @@ include_models!( #[cfg(test)] mod tests { + use std::f64::consts; + use super::*; use burn::tensor::{Data, Int, Shape, Tensor}; @@ -154,8 +156,8 @@ mod tests { // Initialize the model with weights (loaded from the exported file) let model: conv1d::Model = conv1d::Model::default(); - // Run the model with 3.1416 as input for easier testing - let input = Tensor::::full([6, 4, 10], 3.1416); + // Run the model with pi as input for easier testing + let input = Tensor::::full([6, 4, 10], consts::PI); let output = model.forward(input); @@ -460,18 +462,24 @@ mod tests { // Run the model let input = Tensor::::from_floats([ - 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535, + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, ]); let (output1, output2, output3) = model.forward(input); let expected1 = Data::from([ - 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535, - ]); - let expected2 = Data::from([ - 0.69999999, 0.69999999, 0.50000000, 0.69999999, 0.50000000, 0.60089535, - ]); - let expected3 = Data::from([ - 0.80000001, 0.80000001, 0.38286376, 0.80000001, 0.39044821, 0.60089535, + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, ]); + let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); + let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); assert_eq!(output1.to_data(), expected1); assert_eq!(output2.to_data(), expected2); @@ -485,18 +493,24 @@ mod tests { // Run the model let input = Tensor::::from_floats([ - 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535, + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, ]); let (output1, output2, output3) = model.forward(input); let expected1 = Data::from([ - 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535, - ]); - let expected2 = Data::from([ - 0.69999999, 0.69999999, 0.50000000, 0.69999999, 0.50000000, 0.60089535, - ]); - let expected3 = Data::from([ - 0.80000001, 0.80000001, 0.38286376, 0.80000001, 0.39044821, 0.60089535, + 0.88226926, + 0.91500396, + 0.38286376, + 0.95930564, + 0.390_448_2, + 0.60089535, ]); + let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]); + let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); assert_eq!(output1.to_data(), expected1); assert_eq!(output2.to_data(), expected2); diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs index 505dc19578..a8d2f5a53a 100644 --- a/burn-tch/src/ops/int_tensor.rs +++ b/burn-tch/src/ops/int_tensor.rs @@ -195,10 +195,13 @@ impl IntTensorOps> for TchBackend { } fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor { - lhs.unary_ops( + let lhs: TchTensor = + TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false)); + let output: TchTensor = lhs.unary_ops( |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), |tensor| tensor.f_div_scalar(rhs).unwrap(), - ) + ); + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) } fn int_neg(tensor: TchTensor) -> TchTensor { @@ -249,11 +252,20 @@ impl IntTensorOps> for TchBackend { } fn int_mean(tensor: TchTensor) -> TchTensor { - TchOps::mean(tensor) + let tensor: TchTensor = + TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); + let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor); + + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) } fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { - TchOps::mean_dim(tensor, dim) + let tensor: TchTensor = + TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); + + let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor); + + TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false)) } fn int_gather( @@ -298,9 +310,9 @@ impl IntTensorOps> for TchBackend { TchTensor::binary_ops_tensor( tensor, source, - |tensor, source| tensor.f_masked_scatter_(&mask.tensor, source).unwrap(), - |tensor, source| tensor.f_masked_scatter(&mask.tensor, source).unwrap(), - |tensor, source| tensor.f_masked_scatter(&mask.tensor, source).unwrap(), + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), + |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), ) } diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs index 282cf33b3f..d137b8a764 100644 --- a/burn-tensor/src/tests/mod.rs +++ b/burn-tensor/src/tests/mod.rs @@ -39,12 +39,12 @@ macro_rules! testgen_all { burn_tensor::testgen_clamp!(); burn_tensor::testgen_cos!(); burn_tensor::testgen_div!(); - burn_tensor::testgen_empty!(); burn_tensor::testgen_erf!(); burn_tensor::testgen_exp!(); burn_tensor::testgen_flatten!(); burn_tensor::testgen_full!(); burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_init!(); burn_tensor::testgen_log!(); burn_tensor::testgen_log1p!(); burn_tensor::testgen_map_comparison!(); diff --git a/burn-tensor/src/tests/ops/add.rs b/burn-tensor/src/tests/ops/add.rs index d2b3f60d6a..bd45b4376d 100644 --- a/burn-tensor/src/tests/ops/add.rs +++ b/burn-tensor/src/tests/ops/add.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(add)] mod tests { use super::*; - use burn_tensor::{Data, Tensor}; + use burn_tensor::{Data, Int, Tensor}; #[test] fn test_add_d2() { @@ -41,4 +41,43 @@ mod tests { let data_expected = Data::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn test_add_d2_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[6, 8, 10], [12, 14, 16]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_add_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[3, 5, 7], [6, 8, 10]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_add_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor + scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[2, 3, 4], [5, 6, 7]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/aggregation.rs b/burn-tensor/src/tests/ops/aggregation.rs index 2bb82a24a5..a260453d47 100644 --- a/burn-tensor/src/tests/ops/aggregation.rs +++ b/burn-tensor/src/tests/ops/aggregation.rs @@ -12,6 +12,15 @@ mod tests { assert_eq!(data_actual, Data::from([15.0 / 6.0])); } + #[test] + fn test_should_mean_int() { + let tensor = TestTensorInt::from_data([[2, 2, 2], [3, 4, 5]]); + + let data_actual = tensor.mean().to_data(); + + assert_eq!(data_actual, Data::from([3])); + } + #[test] fn test_should_sum() { let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); @@ -21,6 +30,15 @@ mod tests { assert_eq!(data_actual, Data::from([15.0])); } + #[test] + fn test_should_sum_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + + let data_actual = tensor.sum().to_data(); + + assert_eq!(data_actual, Data::from([15])); + } + #[test] fn test_should_mean_last_dim() { let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); @@ -39,6 +57,24 @@ mod tests { assert_eq!(data_actual, Data::from([[3.0], [12.0]])); } + #[test] + fn test_should_mean_last_dim_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + + let data_actual = tensor.mean_dim(1).to_data(); + + assert_eq!(data_actual, Data::from([[1], [4]])); + } + + #[test] + fn test_should_sum_last_dim_int() { + let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]); + + let data_actual = tensor.sum_dim(1).to_data(); + + assert_eq!(data_actual, Data::from([[3], [12]])); + } + #[test] fn test_should_sum_first_dim() { let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); diff --git a/burn-tensor/src/tests/ops/arg.rs b/burn-tensor/src/tests/ops/arg.rs index cfc4006edf..fd6f282b76 100644 --- a/burn-tensor/src/tests/ops/arg.rs +++ b/burn-tensor/src/tests/ops/arg.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(arg)] mod tests { use super::*; - use burn_tensor::{Data, Tensor}; + use burn_tensor::{Data, Int, Tensor}; #[test] fn test_argmax_2d_dim0() { @@ -25,6 +25,28 @@ mod tests { assert_eq!(data_expected, data_actual.to_data()); } + #[test] + fn test_argmax_2d_dim0_int() { + let data = Data::from([[10, 11, 2], [3, 4, 5]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.argmax(0); + + let data_expected = Data::from([[0, 0, 1]]); + assert_eq!(data_expected, data_actual.to_data()); + } + + #[test] + fn test_argmin_2d_dim0_int() { + let data = Data::from([[10, 11, 2], [30, 4, 5]]); + let tensor = Tensor::::from_data(data); + + let data_actual = tensor.argmin(0); + + let data_expected = Data::from([[0, 1, 0]]); + assert_eq!(data_expected, data_actual.to_data()); + } + #[test] fn test_argmax_2d_dim1() { let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); diff --git a/burn-tensor/src/tests/ops/div.rs b/burn-tensor/src/tests/ops/div.rs index 20992deb69..ab48f74ed9 100644 --- a/burn-tensor/src/tests/ops/div.rs +++ b/burn-tensor/src/tests/ops/div.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(div)] mod tests { use super::*; - use burn_tensor::{Data, Tensor}; + use burn_tensor::{Data, Int, Tensor}; #[test] fn should_support_div_ops() { @@ -42,4 +42,44 @@ mod tests { let data_expected = Data::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn should_support_div_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[1, 1, 2], [1, 1, 2]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 / tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 1, 1], [3, 4, 2]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_div_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[1, 1, 2], [3, 4, 5]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 / tensor_2).into_data(); + + let data_expected = Data::from([[0, 1, 1], [0, 0, 0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_div_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor / scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 0, 1], [1, 2, 2]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/empty.rs b/burn-tensor/src/tests/ops/empty.rs deleted file mode 100644 index 1dacbf6c2d..0000000000 --- a/burn-tensor/src/tests/ops/empty.rs +++ /dev/null @@ -1,26 +0,0 @@ -#[burn_tensor_testgen::testgen(empty)] -mod tests { - use super::*; - use burn_tensor::{Bool, Int, Tensor}; - - #[test] - fn should_support_float_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } - - #[test] - fn should_support_int_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } - - #[test] - fn should_support_bool_empty() { - let shape = [2, 2]; - let tensor = Tensor::::empty(shape); - assert_eq!(tensor.shape(), shape.into()) - } -} diff --git a/burn-tensor/src/tests/ops/gather_scatter.rs b/burn-tensor/src/tests/ops/gather_scatter.rs index 597981796d..e133588739 100644 --- a/burn-tensor/src/tests/ops/gather_scatter.rs +++ b/burn-tensor/src/tests/ops/gather_scatter.rs @@ -13,6 +13,16 @@ mod tests { assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); } + #[test] + fn should_gather_1d_dim0_int() { + let tensor = TestTensorInt::from_ints([5, 6, 7]); + let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]); + + let output = tensor.gather(0, indices); + + assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); + } + #[test] fn should_gather_2d_dim0() { let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); @@ -79,6 +89,17 @@ mod tests { assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0])); } + #[test] + fn should_scatter_1d_int() { + let tensor = TestTensorInt::from_ints([0, 0, 0]); + let values = TestTensorInt::from_ints([5, 4, 3]); + let indices = TestTensorInt::from_ints([1, 0, 2]); + + let output = tensor.scatter(0, indices, values); + + assert_eq!(output.into_data(), Data::from([4, 5, 3])); + } + #[test] fn should_scatter_2d_dim0() { let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); diff --git a/burn-tensor/src/tests/ops/init.rs b/burn-tensor/src/tests/ops/init.rs new file mode 100644 index 0000000000..7a89527cd3 --- /dev/null +++ b/burn-tensor/src/tests/ops/init.rs @@ -0,0 +1,58 @@ +#[burn_tensor_testgen::testgen(init)] +mod tests { + use super::*; + use burn_tensor::{Bool, Data, Int, Tensor}; + + #[test] + fn should_support_float_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } + + #[test] + fn should_support_int_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } + + #[test] + fn should_support_float_zeros() { + let shape = [2, 2]; + let tensor = Tensor::::zeros(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[0., 0.], [0., 0.]])) + } + + #[test] + fn should_support_int_zeros() { + let shape = [2, 2]; + let tensor = Tensor::::zeros(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[0, 0], [0, 0]])) + } + + #[test] + fn should_support_float_ones() { + let shape = [2, 2]; + let tensor = Tensor::::ones(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[1., 1.], [1., 1.]])) + } + + #[test] + fn should_support_int_ones() { + let shape = [2, 2]; + let tensor = Tensor::::ones(shape); + assert_eq!(tensor.shape(), shape.into()); + assert_eq!(tensor.to_data(), Data::from([[1, 1], [1, 1]])) + } + + #[test] + fn should_support_bool_empty() { + let shape = [2, 2]; + let tensor = Tensor::::empty(shape); + assert_eq!(tensor.shape(), shape.into()) + } +} diff --git a/burn-tensor/src/tests/ops/mask.rs b/burn-tensor/src/tests/ops/mask.rs index f4e32f8cd4..735c31127e 100644 --- a/burn-tensor/src/tests/ops/mask.rs +++ b/burn-tensor/src/tests/ops/mask.rs @@ -1,18 +1,18 @@ #[burn_tensor_testgen::testgen(mask)] mod tests { use super::*; - use burn_tensor::{Bool, Data, Tensor}; + use burn_tensor::{Bool, Data, Int, Tensor}; #[test] fn should_support_mask_where_ops() { let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]); let mask = Tensor::::from_bool(Data::from([[true, false], [false, true]])); - let value = Tensor::::from_data(Data::from([[8.8, 8.8], [8.8, 8.8]])); + let value = Tensor::::from_data(Data::from([[1.8, 2.8], [3.8, 4.8]])); let data_actual = tensor.mask_where(mask, value).into_data(); - let data_expected = Data::from([[8.8, 7.0], [2.0, 8.8]]); + let data_expected = Data::from([[1.8, 7.0], [2.0, 4.8]]); assert_eq!(data_expected, data_actual); } @@ -27,4 +27,29 @@ mod tests { let data_expected = Data::from([[2.0, 7.0], [2.0, 2.0]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn should_support_int_mask_where_ops() { + let tensor = Tensor::::from_data([[1, 7], [2, 3]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + let value = Tensor::::from_data(Data::from([[8, 9], [10, 11]])); + + let data_actual = tensor.mask_where(mask, value).into_data(); + + let data_expected = Data::from([[8, 7], [2, 11]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_int_mask_fill_ops() { + let tensor = Tensor::::from_data([[1, 7], [2, 3]]); + let mask = + Tensor::::from_bool(Data::from([[true, false], [false, true]])); + + let data_actual = tensor.mask_fill(mask, 9).to_data(); + + let data_expected = Data::from([[9, 7], [2, 9]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/mod.rs b/burn-tensor/src/tests/ops/mod.rs index 6fadd5b072..d7f143f00b 100644 --- a/burn-tensor/src/tests/ops/mod.rs +++ b/burn-tensor/src/tests/ops/mod.rs @@ -9,12 +9,12 @@ mod cat; mod clamp; mod cos; mod div; -mod empty; mod erf; mod exp; mod flatten; mod full; mod gather_scatter; +mod init; mod log; mod log1p; mod map_comparison; diff --git a/burn-tensor/src/tests/ops/mul.rs b/burn-tensor/src/tests/ops/mul.rs index 11656d8b1f..81337b808f 100644 --- a/burn-tensor/src/tests/ops/mul.rs +++ b/burn-tensor/src/tests/ops/mul.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(mul)] mod tests { use super::*; - use burn_tensor::{Data, Tensor}; + use burn_tensor::{Data, Int, Tensor}; #[test] fn should_support_mul_ops() { @@ -42,4 +42,44 @@ mod tests { let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn should_support_mul_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[0, 1, 2], [3, 4, 5]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let output = tensor_1 * tensor_2; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 1, 4], [9, 16, 25]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_mul_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 * tensor_2).into_data(); + + let data_expected = Data::from([[0, 4, 10], [0, 7, 16]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mul_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor * scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[0, 2, 4], [6, 8, 10]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/random.rs b/burn-tensor/src/tests/ops/random.rs index 6bf591a1dc..5aeccfa7a4 100644 --- a/burn-tensor/src/tests/ops/random.rs +++ b/burn-tensor/src/tests/ops/random.rs @@ -10,4 +10,18 @@ mod tests { // check that the tensor is within the range of [0..1) (1 is exclusive) tensor.into_data().assert_within_range(0.0..1.0); } + + #[test] + fn rand_uniform() { + let tensor = Tensor::::random([20], Distribution::Uniform(4., 5.)); + + tensor.into_data().assert_within_range(4.0..5.0); + } + + #[test] + fn rand_bernoulli() { + let tensor = Tensor::::random([20], Distribution::Bernoulli(1.)); + + assert_eq!(tensor.into_data(), [1.; 20].into()); + } } diff --git a/burn-tensor/src/tests/ops/select.rs b/burn-tensor/src/tests/ops/select.rs index f2ff024126..fefe8c7601 100644 --- a/burn-tensor/src/tests/ops/select.rs +++ b/burn-tensor/src/tests/ops/select.rs @@ -13,6 +13,16 @@ mod tests { assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); } + #[test] + fn should_select_1d_int() { + let tensor = TestTensorInt::from_data([5, 6, 7]); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]); + + let output = tensor.select(0, indices); + + assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7])); + } + #[test] fn should_select_2d_dim0_same_num_dim() { let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); @@ -68,6 +78,17 @@ mod tests { assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0])); } + #[test] + fn should_select_assign_1d_int() { + let tensor = TestTensorInt::from_data([7, 8, 9]); + let values = TestTensorInt::from_data([5, 4, 3, 2, 1]); + let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); + + let output = tensor.select_assign(0, indices, values); + + assert_eq!(output.into_data(), Data::from([10, 19, 10])); + } + #[test] fn should_select_assign_2d_dim0() { let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); diff --git a/burn-tensor/src/tests/ops/sub.rs b/burn-tensor/src/tests/ops/sub.rs index 43da19c6d9..3293379abd 100644 --- a/burn-tensor/src/tests/ops/sub.rs +++ b/burn-tensor/src/tests/ops/sub.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(sub)] mod tests { use super::*; - use burn_tensor::{Data, Tensor}; + use burn_tensor::{Data, Int, Tensor}; #[test] fn should_support_sub_ops() { @@ -41,4 +41,43 @@ mod tests { let data_expected = Data::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn should_support_sub_ops_int() { + let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]); + let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]); + let data_expected = Data::from([[-6, -6, -6], [-6, -6, -6]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_sub_broadcast_int() { + let data_1 = Data::from([[0, 1, 2]]); + let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + let data_expected = Data::from([[-3, -3, -3], [-6, -6, -6]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_sub_scalar_ops_int() { + let data = Data::from([[0, 1, 2], [3, 4, 5]]); + let scalar = 2; + let tensor = Tensor::::from_data(data); + + let output = tensor - scalar; + + let data_actual = output.into_data(); + let data_expected = Data::from([[-2, -1, 0], [1, 2, 3]]); + assert_eq!(data_expected, data_actual); + } }