-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added erf op to math.py #908
Changes from 1 commit
6485433
61c9690
7960106
d9f99f2
ef05f70
8d4b8eb
e829832
686f339
6b52b8a
7f76459
5efabf5
ebe06c3
3f199c3
0307b74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -239,3 +239,7 @@ def istft( | |
|
||
def rsqrt(x): | ||
return tf.math.rsqrt(x) | ||
|
||
|
||
def erf(x): | ||
return tf.math.erf(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -929,3 +929,38 @@ def rsqrt(x): | |
return Rsqrt().symbolic_call(x) | ||
x = backend.convert_to_tensor(x) | ||
return backend.math.rsqrt(x) | ||
|
||
|
||
class Erf(Operation): | ||
"""Computes the error function of x element-wise. | ||
|
||
Args: | ||
input_tensor: A tensor of type `float32` or `float64`. | ||
|
||
Returns: | ||
A tensor of the same shape and type as `input_tensor`, containing the error function values. | ||
|
||
Examples: | ||
|
||
# Basic usage | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you're not printing any outputs, just use a fenced code block for the code example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
>>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) | ||
>>> y = Erf()(x) | ||
# Using `float32` data type | ||
>>> x_float32 = np.array([-3.0, -2.0], dtype=np.float32) | ||
>>> y_float32 = Erf()(x_float32) | ||
# Using large values | ||
>>> x_large = np.array([1e10, -1e10]) | ||
>>> y_large = Erf()(x_large) | ||
""" | ||
def __init__(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not needed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed it |
||
super().__init__() | ||
|
||
def compute_output_spec(self, input_tensor): | ||
return KerasTensor(shape=input_tensor.shape, dtype=input_tensor.dtype) | ||
|
||
def call(self, input_tensor): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just x There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replaced with x |
||
return backend.erf(input_tensor) | ||
|
||
def erf(x): | ||
"""Functional interface to the `Erf` operation.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is where the docstring should be, not the op above, since this is the public symbol. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Relocated it! |
||
return Erf()(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -835,3 +835,34 @@ def test_rsqrt(self): | |
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") | ||
self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x)) | ||
self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x)) | ||
|
||
|
||
class MathOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): | ||
|
||
def test_erf_operation_basic(self): | ||
# Sample values for testing | ||
sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) | ||
|
||
# Expected output using numpy's approximation of the error function | ||
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) | ||
|
||
# Output from the erf operation in keras_core | ||
output_from_erf_op = kmath.erf(sample_values).numpy() | ||
|
||
# Assert that the outputs are close | ||
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) | ||
|
||
def test_erf_operation_dtype(self): | ||
# Test for float32 and float64 data types | ||
for dtype in [np.float32, np.float64]: | ||
sqali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype) | ||
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(sample_values) | ||
output_from_erf_op = kmath.erf(sample_values).numpy() | ||
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5) | ||
|
||
def test_erf_operation_edge_cases(self): | ||
# Test for edge cases | ||
edge_values = np.array([1e10, -1e10, 1e-10, -1e-10], dtype=np.float64) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Your test values are too large. Try 1e5. This the source of the large discrepancy IMO. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have implemented the changes, but I can see from the tests that it is failing for the below array examples. I wonder if there is anything wrong in the implementation function itself.
|
||
expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(edge_values) | ||
output_from_edge_erf_op = kmath.erf(edge_values).numpy() | ||
self.assertAllClose(expected_edge_output, output_from_edge_erf_op, atol=1e-5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It can have more types, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Edited the comments based on that