-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added erf op to math.py #908
Changes from 1 commit
6485433
61c9690
7960106
d9f99f2
ef05f70
8d4b8eb
e829832
686f339
6b52b8a
7f76459
5efabf5
ebe06c3
3f199c3
0307b74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -248,3 +248,7 @@ def istft( | |
|
||
def rsqrt(x): | ||
return jax.lax.rsqrt(x) | ||
|
||
|
||
def erf(x): | ||
return jnp.erf(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -302,3 +302,7 @@ def istft( | |
|
||
def rsqrt(x): | ||
return 1.0 / np.sqrt(x) | ||
|
||
|
||
def erf(x): | ||
return scipy.special.erf(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -242,4 +242,4 @@ def rsqrt(x): | |
|
||
|
||
def erf(x): | ||
return tf.math.erf(x) | ||
return tf.math.erf(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -936,7 +936,7 @@ class Erf(Operation): | |
|
||
Args: | ||
input_tensor: A tensor of type `float32` or `float64`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can have more types, no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edited the comments based on that |
||
|
||
Returns: | ||
A tensor of the same shape and type as `input_tensor`, containing the error function values. | ||
|
||
|
@@ -952,6 +952,7 @@ class Erf(Operation): | |
>>> x_large = np.array([1e10, -1e10]) | ||
>>> y_large = Erf()(x_large) | ||
""" | ||
|
||
def __init__(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not needed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed it |
||
super().__init__() | ||
|
||
|
@@ -961,6 +962,7 @@ def compute_output_spec(self, input_tensor): | |
def call(self, input_tensor): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just x There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replaced with x |
||
return backend.erf(input_tensor) | ||
|
||
|
||
def erf(x): | ||
"""Functional interface to the `Erf` operation.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is where the docstring should be, not the op above, since this is the public symbol. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Relocated it! |
||
return Erf()(x) | ||
return Erf()(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -838,13 +838,14 @@ def test_rsqrt(self): | |
|
||
|
||
class MathOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): | ||
|
||
def test_erf_operation_basic(self): | ||
# Sample values for testing | ||
sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) | ||
|
||
# Expected output using numpy's approximation of the error function | ||
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) | ||
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( | ||
sample_values | ||
) | ||
|
||
# Output from the erf operation in keras_core | ||
output_from_erf_op = kmath.erf(sample_values).numpy() | ||
|
@@ -855,14 +856,22 @@ def test_erf_operation_basic(self): | |
def test_erf_operation_dtype(self): | ||
# Test for float32 and float64 data types | ||
for dtype in [np.float32, np.float64]: | ||
sqali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype) | ||
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) | ||
sample_values = np.array( | ||
[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype | ||
) | ||
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( | ||
sample_values | ||
) | ||
output_from_erf_op = kmath.erf(sample_values).numpy() | ||
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) | ||
|
||
def test_erf_operation_edge_cases(self): | ||
# Test for edge cases | ||
edge_values = np.array([1e10, -1e10, 1e-10, -1e-10], dtype=np.float64) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Your test values are too large. Try 1e5. This the source of the large discrepancy IMO. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have implemented the changes, but I can see from the tests that it is failing for the below array examples. I wonder if there is anything wrong in the implementation function itself.
|
||
expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(edge_values) | ||
expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)( | ||
edge_values | ||
) | ||
output_from_edge_erf_op = kmath.erf(edge_values).numpy() | ||
self.assertAllClose(expected_edge_output, output_from_edge_erf_op, atol=1e-5) | ||
self.assertAllClose( | ||
expected_edge_output, output_from_edge_erf_op, atol=1e-5 | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just do
x = convert_to_tensor(x)
unconditionallyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done