Skip to content

Commit

Permalink
[JAX] Include pre-transformed stack traces as additional context to J…
Browse files Browse the repository at this point in the history
…AX exceptions, where present.

PiperOrigin-RevId: 371695248
  • Loading branch information
hawkinsp authored and jax authors committed May 3, 2021
1 parent 75b00a1 commit 2c92bc9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
34 changes: 34 additions & 0 deletions jax/_src/source_info_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,45 @@ def __init__(self):
def current() -> Optional[Traceback]:
return _source_info_context.context or xla_client.Traceback.get_traceback()

class JaxStackTraceBeforeTransformation(Exception): pass

_message = (
'The preceding stack trace is the source of the JAX operation that, once '
'transformed by JAX, triggered the following exception.\n'
'\n--------------------')

def has_user_context(e):
while e is not None:
if isinstance(e, JaxStackTraceBeforeTransformation):
return True
e = e.__cause__
return False

@contextlib.contextmanager
def user_context(c):
prev = _source_info_context.context
_source_info_context.context = c or _source_info_context.context
filtered_tb = None
try:
yield
except Exception as e:
if c is None or has_user_context(e):
raise
# TODO(phawkins): remove the following condition after Jaxlib 0.1.66 is the
# minimum.
if not hasattr(c, 'as_python_traceback'):
raise
filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
if filtered_tb:
msg = traceback_util.format_exception_only(e)
msg = f'{msg}\n\n{_message}'
c = JaxStackTraceBeforeTransformation(msg).with_traceback(filtered_tb)
c.__context__ = e.__context__
c.__cause__ = e.__cause__
c.__suppress_context__ = e.__suppress_context__
e.__context__ = None
e.__cause__ = c
raise
finally:
_source_info_context.context = prev
del filtered_tb
12 changes: 8 additions & 4 deletions jax/_src/traceback_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,16 @@ def include_frame(f):
def ignore_known_hidden_frame(f):
return 'importlib._bootstrap' in f.f_code.co_filename

def filter_traceback_and_stack(tb):
def filter_traceback(tb):
out = None
# Scan the traceback and collect relevant frames.
for f, lineno in reversed(list(traceback.walk_tb(tb))):
if include_frame(f) or out is None:
frames = list(traceback.walk_tb(tb))
for f, lineno in reversed(frames):
if include_frame(f):
out = make_traceback(out, f, f.f_lasti, lineno) # pytype: disable=wrong-arg-count
if out is None and len(frames) > 0:
f, lineno = frames[-1]
out = make_traceback(out, f, f.f_lasti, lineno)
return out

def add_call_stack_frames(tb):
Expand Down Expand Up @@ -141,7 +145,7 @@ def reraise_with_filtered_traceback(*args, **kwargs):
if not is_under_reraiser(e):
filtered_tb, unfiltered = None, None
try:
filtered_tb = filter_traceback_and_stack(e.__traceback__)
filtered_tb = filter_traceback(e.__traceback__)
if filtered_tb is None:
raise
msg = format_exception_only(e)
Expand Down
25 changes: 22 additions & 3 deletions tests/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax import core, grad, jit, vmap, lax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import source_info_util
from jax._src import traceback_util
from jax.lib import xla_extension

Expand Down Expand Up @@ -51,8 +52,8 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[]):
if frame_patterns:
for (fname_pat, line_pat), frame_fmt in zip(
reversed(frame_patterns), reversed(c_tb)):
file = '.*' if fname_pat is None else re.escape(__file__)
fname_pat = '.*' if fname_pat is None else re.escape(fname_pat)
file = re.escape(__file__)
fname_pat = re.escape(fname_pat)
line_pat = re.escape(line_pat)
full_pat = (
f' File "{file}", line ' r'[0-9]+'
Expand Down Expand Up @@ -126,7 +127,6 @@ def outermost(x):
('<lambda>', 'f = lambda: outermost'),
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + grad(innermost)(x)'),
(None, 'raise TypeError'),
])

def test_lax_cond(self):
Expand Down Expand Up @@ -314,6 +314,25 @@ def outer(x):
self.assertIsInstance(e.__cause__.__cause__, ValueError)


class UserContextTracebackTest(jtu.JaxTestCase):

def test_grad_norm(self):
e = None
try:
with jax.debug_nans(True):
jax.grad(jnp.linalg.norm)(jnp.zeros((3, 3), jnp.float32))
except FloatingPointError as exc:
e = exc
self.assertIsNot(e, None)
self.assertIn("invalid value", str(e))
# TODO(phawkins): make this test unconditional after jaxlib 0.1.66 is the
# minimum.
if jax.lib._xla_extension_version >= 19:
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)


class CustomErrorsTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(errorclass), "errorclass": errorclass}
Expand Down

0 comments on commit 2c92bc9

Please sign in to comment.