Skip to content

Commit

Permalink
fix approximately equal precision issue in test code (tracel-ai#954)
Browse files Browse the repository at this point in the history
  • Loading branch information
bytesnail authored Nov 13, 2023
1 parent 4d63a24 commit 2614944
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
11 changes: 6 additions & 5 deletions burn-autodiff/src/tests/cos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ mod tests {
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]), 3);
grad_2.to_data().assert_approx_eq(
grad_1.to_data().assert_approx_eq_diff(
&Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]),
2.0e-3,
);
grad_2.to_data().assert_approx_eq_diff(
&Data::from([[9.222064, -39.123375], [-28.721354, 49.748356]]),
3,
2.0e-3,
);
}
}
6 changes: 3 additions & 3 deletions burn-autodiff/src/tests/sin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ mod tests {

grad_1
.to_data()
.assert_approx_eq(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 3);
grad_2.to_data().assert_approx_eq(
.assert_approx_eq_diff(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 2.6e-3);
grad_2.to_data().assert_approx_eq_diff(
&Data::from([[38.668987, 44.194775], [-59.97261, -80.46094]]),
3,
2.6e-3,
);
}
}
18 changes: 17 additions & 1 deletion burn-tensor/src/tensor/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,23 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
/// Panics if the data is not approximately equal.
#[track_caller]
pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
let tolerance = libm::pow(0.1, precision as f64);

self.assert_approx_eq_diff(other, tolerance)
}

/// Asserts the data is approximately equal to another data.
///
/// # Arguments
///
/// * `other` - The other data.
/// * `tolerance` - The tolerance of the comparison.
///
/// # Panics
///
/// Panics if the data is not approximately equal.
#[track_caller]
pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) {
let mut message = String::new();
if self.shape != other.shape {
message += format!(
Expand All @@ -320,7 +337,6 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
let b: f64 = b.into();

let err = libm::sqrt(libm::pow(a - b, 2.0));
let tolerance = libm::pow(0.1, precision as f64);

if err > tolerance {
// Only print the first 5 different values.
Expand Down

0 comments on commit 2614944

Please sign in to comment.