The function f2 is exactly the same as f1, except with the addition of id_print. Naively, I would expect these functions to be evaluated in exactly the same order, except with some extra calls to id_tap injected. But as we can see from the JAXprs, that isn't what happens:
jvp without id_print:
{ lambda ; a b.
let c = integer_pow[ y=2 ] a
d = mul 2.0 a
e = mul b d
in (c, e) }
jvp with id_print:
{ lambda ; a b.
let c = mul 2.0 a
d = mul b c
e = integer_pow[ y=2 ] a
f = id_tap[ arg_treedef=*
func=<function _print_consumer at 0x7f3ee7d7f620>
nr_untapped=0
output_stream=None
threshold=None ] e
g h = id_tap[ arg_treedef=*
func=<function _print_consumer at 0x7f3ee7d7f620>
nr_untapped=1
output_stream=None
threshold=None
transforms=(('jvp',),) ] d f
in (f, g) }
Without id_print, primals are evaluated before tangents. But with id_print, tangents are evaluated first!
This is a perfectly valid way to calculate JVPs, of course, but it's a little worrisome for a debugging utility to change how the computation happens. It's all the more worrisome because JVPs are implemented with tracers, which I would not expect to change the order of function evaluation. I can imagine this resulting in some very frustrating debugging sessions, e.g., if code crashes only during the tangent calculation.
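The bodies of f1 and f2 are not reproduced in this thread. A minimal sketch for looking at both JVP JAXprs, assuming f1(x) = x ** 2 (consistent with the integer_pow[y=2] above) and using jax.debug.print as a stand-in for id_print, since jax.experimental.host_callback is deprecated. The helper jvp_jaxpr is hypothetical; it only wraps jax.jvp in jax.make_jaxpr.

```python
import jax

def f1(x):
    # Assumed body of f1: consistent with integer_pow[y=2] in the JAXprs.
    return x ** 2

def f2(x):
    # Assumed f2: f1 plus a debug print. jax.debug.print stands in for id_print
    # and returns None, so the squared value is returned explicitly.
    y = x ** 2
    jax.debug.print("y = {y}", y=y)
    return y

def jvp_jaxpr(f):
    # Hypothetical helper: stage out jax.jvp(f, (x,), (t,)) for a scalar
    # primal x and tangent t, without executing the computation.
    return jax.make_jaxpr(lambda x, t: jax.jvp(f, (x,), (t,)))(1.0, 1.0)

print("jvp without print:")
print(jvp_jaxpr(f1))
print("jvp with print:")
print(jvp_jaxpr(f2))
```

Comparing the two printouts shows whether adding the print moves the tangent equations relative to the primal ones; the exact primitives printed may differ between JAX versions.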
Looks like this issue has been resolved in later versions of JAX. I ran the code from the issue on Colab with JAX version 0.4.23. Now f1 and f2 are evaluated in exactly the same order, with the extra id_print call simply injected into f2.
jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
c:f32[] = integer_pow[y=2] a
d:f32[] = integer_pow[y=1] a
e:f32[] = mul 2.0 d
f:f32[] = mul b e
in (c, f) }
jvp with id_print:
{ lambda ; a:f32[] b:f32[]. let
c:f32[] = integer_pow[y=2] a
d:f32[] = integer_pow[y=1] a
e:f32[] = mul 2.0 d
f:f32[] = mul b e
g:f32[] = outside_call[
arg_treedef=PyTreeDef(*)
callback=<jax.experimental.host_callback._CallbackWrapper object at 0x7d4237b05a80>
device_index=0
identity=True
] c
in (g, f) }
Since jax.experimental.host_callback is deprecated (#20385), I also tested with jax.debug.print, and with it f1 and f2 are likewise evaluated in the same order. Because jax.debug.print returns None, f2 also returns None here.
jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
c:f32[] = integer_pow[y=2] a
d:f32[] = integer_pow[y=1] a
e:f32[] = mul 2.0 d
f:f32[] = mul b e
in (c, f) }
jvp with jax.debug.print:
{ lambda ; a:f32[] b:f32[]. let
c:f32[] = integer_pow[y=2] a
d:f32[] = integer_pow[y=1] a
e:f32[] = mul 2.0 d
_:f32[] = mul b e
debug_callback[
callback=<function debug_callback.<locals>._flat_callback at 0x7d4235c7f5b0>
effect=Debug
] c
in () }
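The ordering can also be checked programmatically rather than by eye. A minimal sketch, assuming the same hypothetical f1 and f2 as before (f2 returns its value explicitly instead of returning the result of jax.debug.print): it lists the primitive names of each JVP JAXpr in program order, drops the debug_callback equation, and compares the two lists. The helper primitive_order is illustrative, not part of JAX.

```python
import jax

def f1(x):
    # Assumed body of f1.
    return x ** 2

def f2(x):
    # Assumed f2: same computation plus a debug print; the value is
    # returned explicitly because jax.debug.print returns None.
    y = x ** 2
    jax.debug.print("y = {y}", y=y)
    return y

def primitive_order(f, skip=("debug_callback",)):
    """Hypothetical helper: primitive names in the JVP JAXpr of f,
    in program order, ignoring primitives listed in `skip`."""
    closed = jax.make_jaxpr(lambda x, t: jax.jvp(f, (x,), (t,)))(1.0, 1.0)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns
            if eqn.primitive.name not in skip]

order_f1 = primitive_order(f1)
order_f2 = primitive_order(f2)
print(order_f1)
print(order_f2)
print("same order:", order_f1 == order_f2)
```

Only top-level equations are compared; if a JAX version nests work inside a sub-JAXpr, the comparison is coarser but still applies to both functions in the same way.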
(Forked from #3127)