diff --git a/burn-autodiff/src/tests/gradients.rs b/burn-autodiff/src/tests/gradients.rs
index 6efc43f1fd..98fb44d8ea 100644
--- a/burn-autodiff/src/tests/gradients.rs
+++ b/burn-autodiff/src/tests/gradients.rs
@@ -11,7 +11,7 @@ mod tests {
let x = tensor_1.clone().matmul(activation::gelu(tensor_2));
let mut grads = x.backward();
- let grad_1 = tensor_1.grad(&mut grads).unwrap();
+ let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_1_updated = TestADTensor::random([32, 32], Distribution::Default).require_grad();
tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner());
diff --git a/burn-candle/src/lib.rs b/burn-candle/src/lib.rs
index bc36a48e17..a746ea6bd6 100644
--- a/burn-candle/src/lib.rs
+++ b/burn-candle/src/lib.rs
@@ -49,7 +49,7 @@ mod tests {
// test ops
burn_tensor::testgen_add!();
- burn_tensor::testgen_aggregation!();
+ // burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arange!();
burn_tensor::testgen_arange_step!();
burn_tensor::testgen_arg!();
@@ -57,13 +57,13 @@ mod tests {
burn_tensor::testgen_cat!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
- burn_tensor::testgen_div!();
- burn_tensor::testgen_empty!();
+ // burn_tensor::testgen_div!();
// burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_full!();
burn_tensor::testgen_gather_scatter!();
+ burn_tensor::testgen_init!();
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_map_comparison!();
diff --git a/burn-candle/src/ops/int_tensor.rs b/burn-candle/src/ops/int_tensor.rs
index 706eaa2ce2..7499e06418 100644
--- a/burn-candle/src/ops/int_tensor.rs
+++ b/burn-candle/src/ops/int_tensor.rs
@@ -73,7 +73,7 @@ impl IntTensorOps IntTensor {
CandleTensor::new(
mask.tensor
- .where_cond(&tensor.tensor, &tensor.tensor)
+ .where_cond(&source.tensor, &tensor.tensor)
.unwrap(),
)
}
@@ -284,7 +284,8 @@ impl IntTensorOps,
rhs: IntElem,
) -> IntTensor {
- CandleTensor::new((lhs.tensor / rhs.elem::()).unwrap())
+ // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
+ panic!("Not supported by Candle")
}
fn int_zeros(shape: Shape, device: &Device) -> IntTensor {
@@ -312,15 +313,30 @@ impl IntTensorOps(tensor: IntTensor, dim: usize) -> IntTensor {
- CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap())
+ // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
+ panic!("Not supported by Candle")
}
fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor {
- CandleTensor::new(tensor.tensor.argmax_keepdim(dim).unwrap())
+ CandleTensor::new(
+ tensor
+ .tensor
+ .argmax_keepdim(dim)
+ .unwrap()
+ .to_dtype(I::DTYPE)
+ .unwrap(),
+ )
}
fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor {
- CandleTensor::new(tensor.tensor.argmin_keepdim(dim).unwrap())
+ CandleTensor::new(
+ tensor
+ .tensor
+ .argmin_keepdim(dim)
+ .unwrap()
+ .to_dtype(I::DTYPE)
+ .unwrap(),
+ )
}
fn int_abs(tensor: IntTensor) -> IntTensor {
diff --git a/burn-candle/src/ops/tensor.rs b/burn-candle/src/ops/tensor.rs
index ea0146d67f..095f94226c 100644
--- a/burn-candle/src/ops/tensor.rs
+++ b/burn-candle/src/ops/tensor.rs
@@ -34,7 +34,9 @@ impl TensorOps>
Distribution::Bernoulli(prob) => CandleTensor::new(
candle_core::Tensor::rand(0., 1., shape, device)
.unwrap()
- .gt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))
+ .to_dtype(F::DTYPE)
+ .unwrap()
+ .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))
.unwrap()
.to_dtype(F::DTYPE)
.unwrap(),
diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs
index 8992141ec6..9b0f15c64a 100644
--- a/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -42,6 +42,8 @@ include_models!(
#[cfg(test)]
mod tests {
+ use std::f64::consts;
+
use super::*;
use burn::tensor::{Data, Int, Shape, Tensor};
@@ -154,8 +156,8 @@ mod tests {
// Initialize the model with weights (loaded from the exported file)
let model: conv1d::Model = conv1d::Model::default();
- // Run the model with 3.1416 as input for easier testing
- let input = Tensor::::full([6, 4, 10], 3.1416);
+ // Run the model with pi as input for easier testing
+ let input = Tensor::::full([6, 4, 10], consts::PI);
let output = model.forward(input);
@@ -460,18 +462,24 @@ mod tests {
// Run the model
let input = Tensor::::from_floats([
- 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
+ 0.88226926,
+ 0.91500396,
+ 0.38286376,
+ 0.95930564,
+ 0.390_448_2,
+ 0.60089535,
]);
let (output1, output2, output3) = model.forward(input);
let expected1 = Data::from([
- 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
- ]);
- let expected2 = Data::from([
- 0.69999999, 0.69999999, 0.50000000, 0.69999999, 0.50000000, 0.60089535,
- ]);
- let expected3 = Data::from([
- 0.80000001, 0.80000001, 0.38286376, 0.80000001, 0.39044821, 0.60089535,
+ 0.88226926,
+ 0.91500396,
+ 0.38286376,
+ 0.95930564,
+ 0.390_448_2,
+ 0.60089535,
]);
+ let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]);
+ let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]);
assert_eq!(output1.to_data(), expected1);
assert_eq!(output2.to_data(), expected2);
@@ -485,18 +493,24 @@ mod tests {
// Run the model
let input = Tensor::::from_floats([
- 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
+ 0.88226926,
+ 0.91500396,
+ 0.38286376,
+ 0.95930564,
+ 0.390_448_2,
+ 0.60089535,
]);
let (output1, output2, output3) = model.forward(input);
let expected1 = Data::from([
- 0.88226926, 0.91500396, 0.38286376, 0.95930564, 0.39044821, 0.60089535,
- ]);
- let expected2 = Data::from([
- 0.69999999, 0.69999999, 0.50000000, 0.69999999, 0.50000000, 0.60089535,
- ]);
- let expected3 = Data::from([
- 0.80000001, 0.80000001, 0.38286376, 0.80000001, 0.39044821, 0.60089535,
+ 0.88226926,
+ 0.91500396,
+ 0.38286376,
+ 0.95930564,
+ 0.390_448_2,
+ 0.60089535,
]);
+ let expected2 = Data::from([0.7, 0.7, 0.5, 0.7, 0.5, 0.60089535]);
+ let expected3 = Data::from([0.8, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]);
assert_eq!(output1.to_data(), expected1);
assert_eq!(output2.to_data(), expected2);
diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs
index 505dc19578..a8d2f5a53a 100644
--- a/burn-tch/src/ops/int_tensor.rs
+++ b/burn-tch/src/ops/int_tensor.rs
@@ -195,10 +195,13 @@ impl IntTensorOps> for TchBackend {
}
fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
- lhs.unary_ops(
+ let lhs: TchTensor =
+ TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, true, false));
+ let output: TchTensor = lhs.unary_ops(
|mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
|tensor| tensor.f_div_scalar(rhs).unwrap(),
- )
+ );
+ TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
}
fn int_neg(tensor: TchTensor) -> TchTensor {
@@ -249,11 +252,20 @@ impl IntTensorOps> for TchBackend {
}
fn int_mean(tensor: TchTensor) -> TchTensor {
- TchOps::mean(tensor)
+ let tensor: TchTensor =
+ TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
+ let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);
+
+ TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
}
fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
- TchOps::mean_dim(tensor, dim)
+ let tensor: TchTensor =
+ TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
+
+ let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
+
+ TchTensor::::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
}
fn int_gather(
@@ -298,9 +310,9 @@ impl IntTensorOps> for TchBackend {
TchTensor::binary_ops_tensor(
tensor,
source,
- |tensor, source| tensor.f_masked_scatter_(&mask.tensor, source).unwrap(),
- |tensor, source| tensor.f_masked_scatter(&mask.tensor, source).unwrap(),
- |tensor, source| tensor.f_masked_scatter(&mask.tensor, source).unwrap(),
+ |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
+ |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
+ |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
)
}
diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs
index 282cf33b3f..d137b8a764 100644
--- a/burn-tensor/src/tests/mod.rs
+++ b/burn-tensor/src/tests/mod.rs
@@ -39,12 +39,12 @@ macro_rules! testgen_all {
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_div!();
- burn_tensor::testgen_empty!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_full!();
burn_tensor::testgen_gather_scatter!();
+ burn_tensor::testgen_init!();
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_map_comparison!();
diff --git a/burn-tensor/src/tests/ops/add.rs b/burn-tensor/src/tests/ops/add.rs
index d2b3f60d6a..bd45b4376d 100644
--- a/burn-tensor/src/tests/ops/add.rs
+++ b/burn-tensor/src/tests/ops/add.rs
@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(add)]
mod tests {
use super::*;
- use burn_tensor::{Data, Tensor};
+ use burn_tensor::{Data, Int, Tensor};
#[test]
fn test_add_d2() {
@@ -41,4 +41,43 @@ mod tests {
let data_expected = Data::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]);
assert_eq!(data_expected, data_actual);
}
+
+ #[test]
+ fn test_add_d2_int() {
+ let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let data_actual = (tensor_1 + tensor_2).into_data();
+
+ let data_expected = Data::from([[6, 8, 10], [12, 14, 16]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn test_add_broadcast_int() {
+ let data_1 = Data::from([[0, 1, 2]]);
+ let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let data_actual = (tensor_1 + tensor_2).into_data();
+
+ let data_expected = Data::from([[3, 5, 7], [6, 8, 10]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn should_support_add_scalar_ops_int() {
+ let data = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let scalar = 2;
+ let tensor = Tensor::::from_data(data);
+
+ let output = tensor + scalar;
+
+ let data_actual = output.into_data();
+ let data_expected = Data::from([[2, 3, 4], [5, 6, 7]]);
+ assert_eq!(data_expected, data_actual);
+ }
}
diff --git a/burn-tensor/src/tests/ops/aggregation.rs b/burn-tensor/src/tests/ops/aggregation.rs
index 2bb82a24a5..a260453d47 100644
--- a/burn-tensor/src/tests/ops/aggregation.rs
+++ b/burn-tensor/src/tests/ops/aggregation.rs
@@ -12,6 +12,15 @@ mod tests {
assert_eq!(data_actual, Data::from([15.0 / 6.0]));
}
+ #[test]
+ fn test_should_mean_int() {
+ let tensor = TestTensorInt::from_data([[2, 2, 2], [3, 4, 5]]);
+
+ let data_actual = tensor.mean().to_data();
+
+ assert_eq!(data_actual, Data::from([3]));
+ }
+
#[test]
fn test_should_sum() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
@@ -21,6 +30,15 @@ mod tests {
assert_eq!(data_actual, Data::from([15.0]));
}
+ #[test]
+ fn test_should_sum_int() {
+ let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]);
+
+ let data_actual = tensor.sum().to_data();
+
+ assert_eq!(data_actual, Data::from([15]));
+ }
+
#[test]
fn test_should_mean_last_dim() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
@@ -39,6 +57,24 @@ mod tests {
assert_eq!(data_actual, Data::from([[3.0], [12.0]]));
}
+ #[test]
+ fn test_should_mean_last_dim_int() {
+ let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]);
+
+ let data_actual = tensor.mean_dim(1).to_data();
+
+ assert_eq!(data_actual, Data::from([[1], [4]]));
+ }
+
+ #[test]
+ fn test_should_sum_last_dim_int() {
+ let tensor = TestTensorInt::from_data([[0, 1, 2], [3, 4, 5]]);
+
+ let data_actual = tensor.sum_dim(1).to_data();
+
+ assert_eq!(data_actual, Data::from([[3], [12]]));
+ }
+
#[test]
fn test_should_sum_first_dim() {
let tensor = TestTensor::from_data([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);
diff --git a/burn-tensor/src/tests/ops/arg.rs b/burn-tensor/src/tests/ops/arg.rs
index cfc4006edf..fd6f282b76 100644
--- a/burn-tensor/src/tests/ops/arg.rs
+++ b/burn-tensor/src/tests/ops/arg.rs
@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(arg)]
mod tests {
use super::*;
- use burn_tensor::{Data, Tensor};
+ use burn_tensor::{Data, Int, Tensor};
#[test]
fn test_argmax_2d_dim0() {
@@ -25,6 +25,28 @@ mod tests {
assert_eq!(data_expected, data_actual.to_data());
}
+ #[test]
+ fn test_argmax_2d_dim0_int() {
+ let data = Data::from([[10, 11, 2], [3, 4, 5]]);
+ let tensor = Tensor::::from_data(data);
+
+ let data_actual = tensor.argmax(0);
+
+ let data_expected = Data::from([[0, 0, 1]]);
+ assert_eq!(data_expected, data_actual.to_data());
+ }
+
+ #[test]
+ fn test_argmin_2d_dim0_int() {
+ let data = Data::from([[10, 11, 2], [30, 4, 5]]);
+ let tensor = Tensor::::from_data(data);
+
+ let data_actual = tensor.argmin(0);
+
+ let data_expected = Data::from([[0, 1, 0]]);
+ assert_eq!(data_expected, data_actual.to_data());
+ }
+
#[test]
fn test_argmax_2d_dim1() {
let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
diff --git a/burn-tensor/src/tests/ops/div.rs b/burn-tensor/src/tests/ops/div.rs
index 20992deb69..ab48f74ed9 100644
--- a/burn-tensor/src/tests/ops/div.rs
+++ b/burn-tensor/src/tests/ops/div.rs
@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(div)]
mod tests {
use super::*;
- use burn_tensor::{Data, Tensor};
+ use burn_tensor::{Data, Int, Tensor};
#[test]
fn should_support_div_ops() {
@@ -42,4 +42,44 @@ mod tests {
let data_expected = Data::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]);
assert_eq!(data_expected, data_actual);
}
+
+ #[test]
+ fn should_support_div_ops_int() {
+ let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let data_2 = Data::from([[1, 1, 2], [1, 1, 2]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let output = tensor_1 / tensor_2;
+
+ let data_actual = output.into_data();
+ let data_expected = Data::from([[0, 1, 1], [3, 4, 2]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn test_div_broadcast_int() {
+ let data_1 = Data::from([[0, 1, 2]]);
+ let data_2 = Data::from([[1, 1, 2], [3, 4, 5]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let data_actual = (tensor_1 / tensor_2).into_data();
+
+ let data_expected = Data::from([[0, 1, 1], [0, 0, 0]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn should_support_div_scalar_ops_int() {
+ let data = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let scalar = 2;
+ let tensor = Tensor::::from_data(data);
+
+ let output = tensor / scalar;
+
+ let data_actual = output.into_data();
+ let data_expected = Data::from([[0, 0, 1], [1, 2, 2]]);
+ assert_eq!(data_expected, data_actual);
+ }
}
diff --git a/burn-tensor/src/tests/ops/empty.rs b/burn-tensor/src/tests/ops/empty.rs
deleted file mode 100644
index 1dacbf6c2d..0000000000
--- a/burn-tensor/src/tests/ops/empty.rs
+++ /dev/null
@@ -1,26 +0,0 @@
-#[burn_tensor_testgen::testgen(empty)]
-mod tests {
- use super::*;
- use burn_tensor::{Bool, Int, Tensor};
-
- #[test]
- fn should_support_float_empty() {
- let shape = [2, 2];
- let tensor = Tensor::::empty(shape);
- assert_eq!(tensor.shape(), shape.into())
- }
-
- #[test]
- fn should_support_int_empty() {
- let shape = [2, 2];
- let tensor = Tensor::::empty(shape);
- assert_eq!(tensor.shape(), shape.into())
- }
-
- #[test]
- fn should_support_bool_empty() {
- let shape = [2, 2];
- let tensor = Tensor::::empty(shape);
- assert_eq!(tensor.shape(), shape.into())
- }
-}
diff --git a/burn-tensor/src/tests/ops/gather_scatter.rs b/burn-tensor/src/tests/ops/gather_scatter.rs
index 597981796d..e133588739 100644
--- a/burn-tensor/src/tests/ops/gather_scatter.rs
+++ b/burn-tensor/src/tests/ops/gather_scatter.rs
@@ -13,6 +13,16 @@ mod tests {
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
+ #[test]
+ fn should_gather_1d_dim0_int() {
+ let tensor = TestTensorInt::from_ints([5, 6, 7]);
+ let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]);
+
+ let output = tensor.gather(0, indices);
+
+ assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7]));
+ }
+
#[test]
fn should_gather_2d_dim0() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
@@ -79,6 +89,17 @@ mod tests {
assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
}
+ #[test]
+ fn should_scatter_1d_int() {
+ let tensor = TestTensorInt::from_ints([0, 0, 0]);
+ let values = TestTensorInt::from_ints([5, 4, 3]);
+ let indices = TestTensorInt::from_ints([1, 0, 2]);
+
+ let output = tensor.scatter(0, indices, values);
+
+ assert_eq!(output.into_data(), Data::from([4, 5, 3]));
+ }
+
#[test]
fn should_scatter_2d_dim0() {
let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
diff --git a/burn-tensor/src/tests/ops/init.rs b/burn-tensor/src/tests/ops/init.rs
new file mode 100644
index 0000000000..7a89527cd3
--- /dev/null
+++ b/burn-tensor/src/tests/ops/init.rs
@@ -0,0 +1,58 @@
+#[burn_tensor_testgen::testgen(init)]
+mod tests {
+ use super::*;
+ use burn_tensor::{Bool, Data, Int, Tensor};
+
+ #[test]
+ fn should_support_float_empty() {
+ let shape = [2, 2];
+ let tensor = Tensor::::empty(shape);
+ assert_eq!(tensor.shape(), shape.into())
+ }
+
+ #[test]
+ fn should_support_int_empty() {
+ let shape = [2, 2];
+ let tensor = Tensor::::empty(shape);
+ assert_eq!(tensor.shape(), shape.into())
+ }
+
+ #[test]
+ fn should_support_float_zeros() {
+ let shape = [2, 2];
+ let tensor = Tensor::::zeros(shape);
+ assert_eq!(tensor.shape(), shape.into());
+ assert_eq!(tensor.to_data(), Data::from([[0., 0.], [0., 0.]]))
+ }
+
+ #[test]
+ fn should_support_int_zeros() {
+ let shape = [2, 2];
+ let tensor = Tensor::::zeros(shape);
+ assert_eq!(tensor.shape(), shape.into());
+ assert_eq!(tensor.to_data(), Data::from([[0, 0], [0, 0]]))
+ }
+
+ #[test]
+ fn should_support_float_ones() {
+ let shape = [2, 2];
+ let tensor = Tensor::::ones(shape);
+ assert_eq!(tensor.shape(), shape.into());
+ assert_eq!(tensor.to_data(), Data::from([[1., 1.], [1., 1.]]))
+ }
+
+ #[test]
+ fn should_support_int_ones() {
+ let shape = [2, 2];
+ let tensor = Tensor::::ones(shape);
+ assert_eq!(tensor.shape(), shape.into());
+ assert_eq!(tensor.to_data(), Data::from([[1, 1], [1, 1]]))
+ }
+
+ #[test]
+ fn should_support_bool_empty() {
+ let shape = [2, 2];
+ let tensor = Tensor::::empty(shape);
+ assert_eq!(tensor.shape(), shape.into())
+ }
+}
diff --git a/burn-tensor/src/tests/ops/mask.rs b/burn-tensor/src/tests/ops/mask.rs
index f4e32f8cd4..735c31127e 100644
--- a/burn-tensor/src/tests/ops/mask.rs
+++ b/burn-tensor/src/tests/ops/mask.rs
@@ -1,18 +1,18 @@
#[burn_tensor_testgen::testgen(mask)]
mod tests {
use super::*;
- use burn_tensor::{Bool, Data, Tensor};
+ use burn_tensor::{Bool, Data, Int, Tensor};
#[test]
fn should_support_mask_where_ops() {
let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]]);
let mask =
Tensor::::from_bool(Data::from([[true, false], [false, true]]));
- let value = Tensor::::from_data(Data::from([[8.8, 8.8], [8.8, 8.8]]));
+ let value = Tensor::::from_data(Data::from([[1.8, 2.8], [3.8, 4.8]]));
let data_actual = tensor.mask_where(mask, value).into_data();
- let data_expected = Data::from([[8.8, 7.0], [2.0, 8.8]]);
+ let data_expected = Data::from([[1.8, 7.0], [2.0, 4.8]]);
assert_eq!(data_expected, data_actual);
}
@@ -27,4 +27,29 @@ mod tests {
let data_expected = Data::from([[2.0, 7.0], [2.0, 2.0]]);
assert_eq!(data_expected, data_actual);
}
+
+ #[test]
+ fn should_support_int_mask_where_ops() {
+ let tensor = Tensor::::from_data([[1, 7], [2, 3]]);
+ let mask =
+ Tensor::::from_bool(Data::from([[true, false], [false, true]]));
+ let value = Tensor::::from_data(Data::from([[8, 9], [10, 11]]));
+
+ let data_actual = tensor.mask_where(mask, value).into_data();
+
+ let data_expected = Data::from([[8, 7], [2, 11]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn should_support_int_mask_fill_ops() {
+ let tensor = Tensor::::from_data([[1, 7], [2, 3]]);
+ let mask =
+ Tensor::::from_bool(Data::from([[true, false], [false, true]]));
+
+ let data_actual = tensor.mask_fill(mask, 9).to_data();
+
+ let data_expected = Data::from([[9, 7], [2, 9]]);
+ assert_eq!(data_expected, data_actual);
+ }
}
diff --git a/burn-tensor/src/tests/ops/mod.rs b/burn-tensor/src/tests/ops/mod.rs
index 6fadd5b072..d7f143f00b 100644
--- a/burn-tensor/src/tests/ops/mod.rs
+++ b/burn-tensor/src/tests/ops/mod.rs
@@ -9,12 +9,12 @@ mod cat;
mod clamp;
mod cos;
mod div;
-mod empty;
mod erf;
mod exp;
mod flatten;
mod full;
mod gather_scatter;
+mod init;
mod log;
mod log1p;
mod map_comparison;
diff --git a/burn-tensor/src/tests/ops/mul.rs b/burn-tensor/src/tests/ops/mul.rs
index 11656d8b1f..81337b808f 100644
--- a/burn-tensor/src/tests/ops/mul.rs
+++ b/burn-tensor/src/tests/ops/mul.rs
@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(mul)]
mod tests {
use super::*;
- use burn_tensor::{Data, Tensor};
+ use burn_tensor::{Data, Int, Tensor};
#[test]
fn should_support_mul_ops() {
@@ -42,4 +42,44 @@ mod tests {
let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]);
assert_eq!(data_expected, data_actual);
}
+
+ #[test]
+ fn should_support_mul_ops_int() {
+ let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let data_2 = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let output = tensor_1 * tensor_2;
+
+ let data_actual = output.into_data();
+ let data_expected = Data::from([[0, 1, 4], [9, 16, 25]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn test_mul_broadcast_int() {
+ let data_1 = Data::from([[0, 1, 2]]);
+ let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let data_actual = (tensor_1 * tensor_2).into_data();
+
+ let data_expected = Data::from([[0, 4, 10], [0, 7, 16]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn should_support_mul_scalar_ops_int() {
+ let data = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let scalar = 2;
+ let tensor = Tensor::::from_data(data);
+
+ let output = tensor * scalar;
+
+ let data_actual = output.into_data();
+ let data_expected = Data::from([[0, 2, 4], [6, 8, 10]]);
+ assert_eq!(data_expected, data_actual);
+ }
}
diff --git a/burn-tensor/src/tests/ops/random.rs b/burn-tensor/src/tests/ops/random.rs
index 6bf591a1dc..5aeccfa7a4 100644
--- a/burn-tensor/src/tests/ops/random.rs
+++ b/burn-tensor/src/tests/ops/random.rs
@@ -10,4 +10,18 @@ mod tests {
// check that the tensor is within the range of [0..1) (1 is exclusive)
tensor.into_data().assert_within_range(0.0..1.0);
}
+
+ #[test]
+ fn rand_uniform() {
+ let tensor = Tensor::::random([20], Distribution::Uniform(4., 5.));
+
+ tensor.into_data().assert_within_range(4.0..5.0);
+ }
+
+ #[test]
+ fn rand_bernoulli() {
+ let tensor = Tensor::::random([20], Distribution::Bernoulli(1.));
+
+ assert_eq!(tensor.into_data(), [1.; 20].into());
+ }
}
diff --git a/burn-tensor/src/tests/ops/select.rs b/burn-tensor/src/tests/ops/select.rs
index f2ff024126..fefe8c7601 100644
--- a/burn-tensor/src/tests/ops/select.rs
+++ b/burn-tensor/src/tests/ops/select.rs
@@ -13,6 +13,16 @@ mod tests {
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
+ #[test]
+ fn should_select_1d_int() {
+ let tensor = TestTensorInt::from_data([5, 6, 7]);
+ let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]);
+
+ let output = tensor.select(0, indices);
+
+ assert_eq!(output.into_data(), Data::from([6, 6, 5, 6, 7]));
+ }
+
#[test]
fn should_select_2d_dim0_same_num_dim() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
@@ -68,6 +78,17 @@ mod tests {
assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
}
+ #[test]
+ fn should_select_assign_1d_int() {
+ let tensor = TestTensorInt::from_data([7, 8, 9]);
+ let values = TestTensorInt::from_data([5, 4, 3, 2, 1]);
+ let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
+
+ let output = tensor.select_assign(0, indices, values);
+
+ assert_eq!(output.into_data(), Data::from([10, 19, 10]));
+ }
+
#[test]
fn should_select_assign_2d_dim0() {
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
diff --git a/burn-tensor/src/tests/ops/sub.rs b/burn-tensor/src/tests/ops/sub.rs
index 43da19c6d9..3293379abd 100644
--- a/burn-tensor/src/tests/ops/sub.rs
+++ b/burn-tensor/src/tests/ops/sub.rs
@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(sub)]
mod tests {
use super::*;
- use burn_tensor::{Data, Tensor};
+ use burn_tensor::{Data, Int, Tensor};
#[test]
fn should_support_sub_ops() {
@@ -41,4 +41,43 @@ mod tests {
let data_expected = Data::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]);
assert_eq!(data_expected, data_actual);
}
+
+ #[test]
+ fn should_support_sub_ops_int() {
+ let data_1 = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let data_2 = Data::from([[6, 7, 8], [9, 10, 11]]);
+ let data_expected = Data::from([[-6, -6, -6], [-6, -6, -6]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let data_actual = (tensor_1 - tensor_2).into_data();
+
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn test_sub_broadcast_int() {
+ let data_1 = Data::from([[0, 1, 2]]);
+ let data_2 = Data::from([[3, 4, 5], [6, 7, 8]]);
+ let tensor_1 = Tensor::::from_data(data_1);
+ let tensor_2 = Tensor::::from_data(data_2);
+
+ let data_actual = (tensor_1 - tensor_2).into_data();
+
+ let data_expected = Data::from([[-3, -3, -3], [-6, -6, -6]]);
+ assert_eq!(data_expected, data_actual);
+ }
+
+ #[test]
+ fn should_support_sub_scalar_ops_int() {
+ let data = Data::from([[0, 1, 2], [3, 4, 5]]);
+ let scalar = 2;
+ let tensor = Tensor::::from_data(data);
+
+ let output = tensor - scalar;
+
+ let data_actual = output.into_data();
+ let data_expected = Data::from([[-2, -1, 0], [1, 2, 3]]);
+ assert_eq!(data_expected, data_actual);
+ }
}