Adding host_callbacks.id_tap reorders JVP evaluation #3198

Open
shoyer opened this issue May 24, 2020 · 1 comment

shoyer commented May 24, 2020

(Forked from #3127)

Consider the following example:

import jax
from jax.experimental import host_callback

def f1(x):
  y = x ** 2
  return y

def f2(x):
  y = x ** 2
  y = host_callback.id_print(y)
  return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f1), (x,), (y,)))(0.0, 0.0))

print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f2), (x,), (y,)))(0.0, 0.0))

The function f2 is exactly the same as f1, except for the addition of id_print. Naively, I would expect these functions to be evaluated in exactly the same order, except with some extra calls to id_tap injected. But as we can see from the jaxprs, that isn't what happens:

jvp without id_print:
{ lambda  ; a b.
  let c = integer_pow[ y=2 ] a
      d = mul 2.0 a
      e = mul b d
  in (c, e) }

jvp with id_print:
{ lambda  ; a b.
  let c = mul 2.0 a
      d = mul b c
      e = integer_pow[ y=2 ] a
      f = id_tap[ arg_treedef=*
                  func=<function _print_consumer at 0x7f3ee7d7f620>
                  nr_untapped=0
                  output_stream=None
                  threshold=None ] e
      g h = id_tap[ arg_treedef=*
                    func=<function _print_consumer at 0x7f3ee7d7f620>
                    nr_untapped=1
                    output_stream=None
                    threshold=None
                    transforms=(('jvp',),) ] d f
  in (f, g) }

Without id_print, primals are evaluated before tangents. But with id_print, tangents are evaluated first!

This is a perfectly valid way to calculate JVPs, of course, but it's a little worrisome for a debugging utility to change how the computation happens. It's all the more worrisome because JVPs are implemented with tracers, which I would not expect to change the order of function evaluation. I can imagine this resulting in some very frustrating debugging sessions, e.g., if code crashes only during the tangent calculation.
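For anyone who wants to compare the ordering without reading the jaxprs by eye, here is a minimal sketch that lists the primitives in the order they were traced. The helper name primitive_order is hypothetical (it is not part of JAX); it only relies on jax.make_jaxpr and the eqns list of the resulting jaxpr. For f(x) = x ** 2, the power primitive computes the primal and the mul equations compute the tangent, so their relative positions in the list show which is evaluated first. The exact primitive names may differ between JAX versions.

```python
import jax


def f(x):
    return x ** 2


def primitive_order(fn, *args):
    """Hypothetical helper: names of fn's jaxpr primitives, in trace order."""
    closed = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


# Trace the JVP of f at x=0.0 with tangent 0.0, as in the example above.
names = primitive_order(lambda x, t: jax.jvp(f, (x,), (t,)), 0.0, 0.0)
print(names)
```

Running the same helper on a version of f that contains a debugging call, and comparing the two lists with the debugging primitive filtered out, would show whether the debugging call changed the order.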

@gnecula gnecula self-assigned this May 26, 2020
@hawkinsp hawkinsp added the bug Something isn't working label May 26, 2020

rajasekharporeddy commented Mar 24, 2024

Hi @shoyer

It looks like this issue has been resolved in later versions of JAX. I ran the code above on Colab with JAX version 0.4.23. Both f1 and f2 are now evaluated in exactly the same order, with an extra call injected for id_print in f2.

import jax
from jax.experimental import host_callback

def f1(x):
  y = x ** 2
  return y

def f2(x):
  y = x ** 2
  y = host_callback.id_print(y)
  return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f1), (x,), (y,)))(0.0, 0.0))

print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f2), (x,), (y,)))(0.0, 0.0))

Output:

jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
  in (c, f) }

jvp with id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
    g:f32[] = outside_call[
      arg_treedef=PyTreeDef(*)
      callback=<jax.experimental.host_callback._CallbackWrapper object at 0x7d4237b05a80>
      device_index=0
      identity=True
    ] c
  in (g, f) }

Since jax.experimental.host_callback is deprecated (#20385), I also tested with jax.debug.print; with it, f1 and f2 are likewise evaluated in the same order. Because jax.debug.print returns None and f2 assigns that result to y, f2 also returns None here.

import jax


def f1(x):
  y = x ** 2
  return y


def f2(x):
  y = x ** 2
  y = jax.debug.print("{}", y)
  return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f1), (x,), (y,)))(0.0, 0.0))

print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f2), (x,), (y,)))(0.0, 0.0))

Output:

jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
  in (c, f) }

jvp with id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    _:f32[] = mul b e
    debug_callback[
      callback=<function debug_callback.<locals>._flat_callback at 0x7d4235c7f5b0>
      effect=Debug
    ] c
  in () }
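If f2 should keep returning its result while still printing, one option is to call jax.debug.print as a statement instead of assigning its return value. This is only a sketch of that variant, not the code that produced the output above; the input values 3.0 and 1.0 are chosen purely for illustration.

```python
import jax


def f2(x):
    y = x ** 2
    # Called for its side effect only; jax.debug.print returns None,
    # so its result is not assigned back to y.
    jax.debug.print("y = {}", y)
    return y


# Evaluate the JVP at x=3.0 with tangent 1.0.
primal_out, tangent_out = jax.jvp(f2, (3.0,), (1.0,))
```

With this version, jax.jvp returns both the primal output and the tangent output rather than an empty result.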

Please find the gist for reference.

Thank you.
