Skip to content

Commit

Permalink
fix: tensor prod and prod dim containing nan values (tracel-ai#2515)
Browse files Browse the repository at this point in the history
  • Loading branch information
quinton11 authored Nov 20, 2024
1 parent f64914b commit a0e8e4d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 0 deletions.
51 changes: 51 additions & 0 deletions crates/burn-fusion/src/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,57 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}

fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(ProdOps, B::float_prod, reduce);

let stream = tensor.stream;
let out = tensor
.client
.tensor_uninitialized(vec![1], B::FloatElem::dtype());

let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Prod(desc.clone()),
),
ProdOps::<B>::new(desc),
);

out
}

fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
scalar_float_ops!(ProdDimOps, B::float_prod_dim, usize, noconvert);

let stream = tensor.stream;
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor
.client
.tensor_uninitialized(shape, B::FloatElem::dtype());

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::ProdDim(desc.clone()),
),
ProdDimOps::<B>::new(desc),
);

out
}

fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(MeanOps, B::float_mean, reduce);

Expand Down
8 changes: 8 additions & 0 deletions crates/burn-ndarray/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,14 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorO
NdArrayTensor::new(array)
}

fn float_prod(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
NdArrayMathOps::prod(tensor)
}

fn float_prod_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
NdArrayMathOps::prod_dim(tensor, dim)
}

fn float_log1p(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared();

Expand Down
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ macro_rules! testgen_with_float_param {
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
burn_tensor::testgen_select!();
burn_tensor::testgen_prod!();

// test stats
burn_tensor::testgen_var!();
Expand Down
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ mod padding;
mod permute;
mod powf;
mod powf_scalar;
mod prod;
mod random;
mod recip;
mod remainder;
Expand Down
16 changes: 16 additions & 0 deletions crates/burn-tensor/src/tests/ops/prod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(prod)]
mod tests {
use super::*;
use burn_tensor::{Tensor, TensorData};

#[test]
fn test_prod_float() {
let tensor_1 = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

let output = tensor_1.prod();

output
.into_data()
.assert_eq(&TensorData::from([-600.0]), false);
}
}

0 comments on commit a0e8e4d

Please sign in to comment.