Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added erf op to math.py #908

Closed
wants to merge 14 commits into from
Closed
Prev Previous commit
Next Next commit
corrected jax expression, numpy function
  • Loading branch information
sqali committed Sep 21, 2023
commit 5efabf539a4eac8288e563c1a3c00854b8d5c8fb
2 changes: 1 addition & 1 deletion keras_core/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,4 @@ def rsqrt(x):


def erf(x):
return jnp.erf(x)
return jax.lax.erf(x)
2 changes: 1 addition & 1 deletion keras_core/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,4 @@ def rsqrt(x):


def erf(x):
return scipy.special.erf(x)
return np.array(scipy.special.erf(x))
6 changes: 3 additions & 3 deletions keras_core/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def test_erf_operation_basic(self):
)

# Output from the erf operation in keras_core
output_from_erf_op = kmath.erf(sample_values).numpy()
output_from_erf_op = kmath.erf(sample_values)

# Assert that the outputs are close
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5)
Expand All @@ -860,7 +860,7 @@ def test_erf_operation_dtype(self):
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(
sample_values
)
output_from_erf_op = kmath.erf(sample_values).numpy()
output_from_erf_op = kmath.erf(sample_values)
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5)

def test_erf_operation_edge_cases(self):
Expand All @@ -869,7 +869,7 @@ def test_erf_operation_edge_cases(self):
expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(
edge_values
)
output_from_edge_erf_op = kmath.erf(edge_values).numpy()
output_from_edge_erf_op = kmath.erf(edge_values)
self.assertAllClose(
expected_edge_output, output_from_edge_erf_op, atol=1e-5
)