
io_callback does not work with custom_vjp #23614

Closed

GeophyAI opened this issue Sep 13, 2024 · 23 comments

@GeophyAI

I want to save some data to disk in the forward pass and reload it in the backward pass, but I found that only a jvp example is provided in the docs. Can io_callback be used with custom_vjp, and if so, how? The following is my implementation, but it does not work. Could anyone help me with this?

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback

def f(x):
    return x*2+1

def fwd(x, cstep):
    y = f(x)
    io_callback(save2file, x, x, cstep)
    return y, None

def bwd(res, g):
    x, cstep = res
    io_callback(loadfromfile, x, x, cstep)
    return g, None

f = jax.custom_vjp(f)
f.defvjp(fwd, bwd)

def save2file(data, cstep):
    jnp.save(f'tmp/x{cstep}.npy', data)
    return data

def loadfromfile(cstep):
    return jnp.load(f'tmp/x{cstep}.npy')

def step(x, cstep):
    # _f.defvjp(fwd, bwd)
    y = f(x)
    return y, None

x = jnp.array([1., 2.])
nt = 5

# loop over using scan
initial_carry = x
csteps = jnp.arange(10)
final_carry, _ = lax.scan(step, initial_carry, jnp.arange(nt))

def loss(x):
    return jnp.sum(x)
# compute grad
@jax.jit
def cal_grad(x):
    return jax.value_and_grad(loss)(x)

loss, grad = cal_grad(final_carry)
GeophyAI added the enhancement (New feature or request) label Sep 13, 2024
@jakevdp
Collaborator

jakevdp commented Sep 13, 2024

io_callback has no autodiff rule, and so any time you use it in autodiff, you need to wrap it in custom_vjp or custom_jvp. In your case, you used io_callback in the fwd and bwd rules, not in f, and so when the autodiff transformation hits it, it doesn't know what to do.

If this is what you're hoping to do (use io_callback in the fwd and bwd rule of another function), then both the fwd function and bwd function would also have to be wrapped in custom_jvp or custom_vjp, each with their own custom autodiff rules describing the gradient of the callback functions you're using.

Does that make sense?
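For illustration, a minimal sketch of that pattern (hypothetical log_value and logged_double names; the io_callback sits inside a function that carries its own VJP rule, so autodiff never has to differentiate through the callback itself):

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log_value(x):
    print('forward pass saw', x)  # host-side side effect
    return x

def _logged_double(x):
    # io_callback needs the result shape/dtype declared up front
    x = io_callback(log_value, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return 2.0 * x

logged_double = jax.custom_vjp(_logged_double)

def logged_double_fwd(x):
    return _logged_double(x), None  # no residuals needed for this rule

def logged_double_bwd(res, g):
    return (2.0 * g,)  # hand-written gradient of 2*x

logged_double.defvjp(logged_double_fwd, logged_double_bwd)

print(jax.grad(lambda x: logged_double(x).sum())(jnp.ones(3)))  # [2. 2. 2.]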

jakevdp self-assigned this Sep 13, 2024
@mattjj
Member

mattjj commented Sep 13, 2024

In the pasted code, f isn't being run under autodiff. The function we're differentiating, loss, just calls jnp.sum. So there's no way we would call the autodiff rules associated with f.

@mattjj
Member

mattjj commented Sep 13, 2024

> and I found that only jvp example is provided in the doc

Did you find this tutorial? There are several custom_vjp examples in there, like this one on gradient clipping.

def f(x):
    return x*2+1

def fwd(x, cstep):
    y = f(x)
    io_callback(save2file, x, x, cstep)
    return y, None

fwd and f need to have the same argument signature, but here fwd takes an extra argument cstep.

def bwd(res, g):
    x, cstep = res
    io_callback(loadfromfile, cstep)
    return g, None

The first input of bwd comes from the second output of fwd, but here fwd is just returning a None as its second output.
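For reference, the gradient-clipping example from that tutorial illustrates the contract: fwd shares f's signature and returns (output, residuals), and bwd consumes those residuals plus the output cotangent, returning one cotangent per primal input. Roughly:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity in the forward pass

def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # same signature as clip_gradient; residuals second

def clip_gradient_bwd(res, g):
    lo, hi = res  # bwd's first input is exactly fwd's second output
    return (None, None, jnp.clip(g, lo, hi))  # one cotangent per input

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)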

@GeophyAI
Author

GeophyAI commented Sep 13, 2024


Hi, thank you for your reply. I had read the doc you mentioned before; what I meant is that I cannot find information about how to combine custom_vjp with io_callback. The doc describes jax.experimental.io_callback() as "appropriate for impure functions: e.g. functions which read or write data to disk", but I cannot find an example.

@GeophyAI
Author


Thank you for your reply. Following your comments, I modified my code, but it still cannot write to disk as expected. The following is my new implementation; is there any misunderstanding on my side?

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback
import os

def f(x, cstep):
    return x * 2 + 1

# Save data to file
def save2file(data, cstep):
    jnp.save(f'tmp/x{cstep}.npy', data)
    return data

# Load data from file
def loadfromfile(cstep):
    return jnp.load(f'tmp/x{cstep}.npy')

@jax.custom_vjp
def fwd(x, cstep):
    y = f(x)
    # io_callback(save2file, x, x, cstep)
    return y

def fwd_fwd(x, cstep):
    y = f(x)
    io_callback(save2file, x, x, cstep)
    return y, (x, cstep)

def fwd_bwd(res, g):
    x, cstep = res
    # x_loaded = loadfromfile(cstep)
    return (g * 2,), None

fwd.defvjp(fwd_fwd, fwd_bwd)

@jax.custom_vjp
def bwd(res, g):
    x, cstep = res
    # x_loaded = loadfromfile(cstep)
    return (g * 2,)

def bwd_fwd(res, g):
    x, cstep = res
    # x_loaded = loadfromfile(cstep)
    return g, (x, cstep)

def bwd_bwd(res, g):
    return None

bwd.defvjp(bwd_fwd, bwd_bwd)

f_vjp = jax.custom_vjp(f)
f_vjp.defvjp(fwd, bwd)

def step(x, cstep):
    y = f_vjp(x, cstep)
    return y, None

x = jnp.array([1., 2.])
nt = 5

initial_carry = x
csteps = jnp.arange(10)
final_carry, _ = lax.scan(step, initial_carry, jnp.arange(nt))

def loss(x):
    return jnp.sum(x)

@jax.jit
def cal_grad(x):
    return jax.value_and_grad(loss)(x)

loss_value, grad = cal_grad(final_carry)
print(f"Loss: {loss_value}, Gradient: {grad}")

assert os.path.exists('tmp/x0.npy')

@jakevdp
Collaborator

jakevdp commented Sep 13, 2024

Hi - I think I may have been unclear in my suggestion. The problem previously was that you were calling io_callback in fwd and bwd, which do not have custom JVP/VJP rules, and so the autodiff machinery doesn't know how to trace through your callback.

In your update, you created custom autodiff rules for fwd and bwd, but now you are calling io_callback within fwd_fwd, which does not have a custom autodiff rule, and so the autodiff machinery still doesn't know how to trace through your callback.

This may be an XY problem: perhaps we should step back, and you can describe what larger problem you're trying to solve here?

Also, since you mentioned docs: although we don't have documentation of io_callback with custom_jvp, we do have an example for pure_callback: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp The io_callback case will be conceptually more-or-less identical.
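In the spirit of that doc example, a minimal sketch with np.sin standing in for an external host-side function (hypothetical host_sin/host_cos names):

import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_jvp
def host_sin(x):
    # run np.sin on the host; shape/dtype must be declared up front
    return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.custom_jvp
def host_cos(x):
    return jax.pure_callback(np.cos, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@host_sin.defjvp
def host_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return host_sin(x), host_cos(x) * t  # d/dx sin = cos

@host_cos.defjvp
def host_cos_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return host_cos(x), -host_sin(x) * t  # d/dx cos = -sin

print(jax.grad(host_sin)(1.0))  # ≈ cos(1.0) ≈ 0.5403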

@GeophyAI
Author

Hi, I misunderstood your suggestions at that time. So I need to wrap my callback functions with custom_vjp, i.e. both save2file and loadfromfile need their own fwd and bwd, right?

Actually, my function f needs several inputs to calculate the final result z = f(x, y, z) iteratively, so I use lax.scan here. Instead of returning (x, y, z) from fwd at each step, I want to save (x, y, z) to disk in the forward pass to save GPU memory, then reload them from disk at every step and use jax.vjp to calculate the vjp function and gradients.
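For what it's worth, a sketch of that save/reload pattern with the signature fixes described above (hypothetical save2file/loadfromfile helpers; bwd assumes x and g share a shape because f is elementwise, and note the loss here actually differentiates through the scan):

import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback

os.makedirs('tmp', exist_ok=True)

def save2file(x, cstep):
    np.save(f'tmp/x{int(cstep)}.npy', x)
    return x

def loadfromfile(cstep):
    return np.load(f'tmp/x{int(cstep)}.npy')

def f(x, cstep):
    return x * 2 + 1

f_vjp = jax.custom_vjp(f)

def f_fwd(x, cstep):
    # write the residual to disk; keep only the step index on device
    io_callback(save2file, jax.ShapeDtypeStruct(x.shape, x.dtype), x, cstep)
    return f(x, cstep), cstep

def f_bwd(cstep, g):
    # reload the residual (x and g share shape/dtype since f is elementwise)
    x = io_callback(loadfromfile, jax.ShapeDtypeStruct(g.shape, g.dtype), cstep)
    _, vjp_fun = jax.vjp(lambda x: f(x, cstep), x)
    (dx,) = vjp_fun(g)
    return dx, None

f_vjp.defvjp(f_fwd, f_bwd)

def step(x, cstep):
    return f_vjp(x, cstep), None

def loss(x0):
    y, _ = lax.scan(step, x0, jnp.arange(5))
    return jnp.sum(y)

print(jax.value_and_grad(loss)(jnp.array([1., 2.])))  # gradient is 2**5 = 32 per entry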

@yashk2810
Member

Out of curiosity, why are you saving the residuals to disk? What problem are you trying to solve? Maybe we can suggest something else?

@GeophyAI
Author

GeophyAI commented Sep 13, 2024 via email

@yashk2810
Member

Maybe offloading to host is what you want here; you can achieve that via a different API that is more efficient than io_callback.

Take a look at this test, which shows how to do this via remat:

jax/tests/memories_test.py

Lines 1447 to 1488 in 0daca46

def test_remat_jaxpr_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))

    def policy(prim, *avals, **params):
      return Offloadable(src="device", dst="pinned_host")

    @functools.partial(remat, policy=policy)
    def f(x):
      x = jnp.sin(x)
      x = jnp.sin(x)
      x = jnp.sin(x)
      return jnp.sum(x)

    fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)
    self.assertLen(fwd_jaxpr.out_avals, 4)  # 1 output, 3 offloaded residuals
    fwd_mem_kind_count = str(fwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='pinned_host')")
    self.assertEqual(fwd_mem_kind_count, 3)
    self.assertLen(bwd_jaxpr.in_avals, 4)  # 3 offloaded residuals, 1 input
    bwd_mem_kind_count = str(bwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='device')")
    self.assertEqual(bwd_mem_kind_count, 3)

    # Execution test.
    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash
    compiled_f = f.lower(inp).compile()
    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")
    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

You can also run a manual device_put on the residuals in your custom_vjp fwd function to offload them to the pinned_host memory space if you want, and do the opposite in your bwd function to reload them.

@GeophyAI
Author

Thank you, @yashk2810. Yes, transferring to host is another choice. Does the following code make sense? And how can I know the device of a DynamicJaxprTracer?


import jax
import jax.numpy as jnp

def f(x, y):
    z = x**2 + y**2
    return jnp.sum(z)

def fwd(x, y):
    return f(x, y), (jax.device_get(x), jax.device_get(y))

def bwd(res, g):
    _, vjp_fun = jax.vjp(f, *res)
    grads = vjp_fun(g)
    return grads

f_vjp = jax.custom_vjp(f)
f_vjp.defvjp(fwd, bwd)

f_vjp = jax.jit(f_vjp)

x = jnp.array([1.,2.])
y = jnp.array([2.,3.])

def loss1(x, y):
    return (f_vjp(x, y)**2).sum()
def loss2(x, y):
    return (f(x, y)**2).sum()

print('f=', f(x, y))
print('f_vjp=', f_vjp(x, y))
print(jax.grad(loss1)(x, y))
print(jax.grad(loss2)(x, y))

@yashk2810
Member

No, like this:

def test_offload(self):
    def f(x, y):
      z = x**2 + y**2
      return jnp.sum(z)

    def fwd(x, y):
      return f(x, y), jax.device_put(
          (x, y),
          SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host'))

    def bwd(res, g):
      reloaded_res = jax.device_put(
          res, SingleDeviceSharding(jax.devices()[0], memory_kind='device'))
      _, vjp_fun = jax.vjp(f, *reloaded_res)
      grads = vjp_fun(g)
      return grads

    f_vjp = jax.custom_vjp(f)
    f_vjp.defvjp(fwd, bwd)

    f_vjp = jax.jit(f_vjp)

    x = jnp.array([1.,2.])
    y = jnp.array([2.,3.])

    def loss1(x, y):
      return (f_vjp(x, y)**2).sum()
    def loss2(x, y):
      return (f(x, y)**2).sum()

    print('f=', f(x, y))
    print('f_vjp=', f_vjp(x, y))
    print(jax.grad(loss1)(x, y))
    print(jax.grad(loss2)(x, y))

This is the output I got:

f= 18.0
f_vjp= 18.0
[ 72. 144.]
[ 72. 144.]

Note that I ran this on TPU. Which hardware are you planning to run on?

@GeophyAI
Author

@yashk2810, I'm coding on a GPU. Really, thank you; I have implemented this in my original code, and it does work for reducing memory, at the cost of some efficiency (I have to do that due to the OOM problem). I'm still new to JAX and will try to find a more efficient strategy to solve my problem.

@yashk2810
Member

Yeah, this should work on a GPU too. The efficiency might be affected because you incur transfer costs.

Can you tell me how much of a slowdown you see? If you have enough compute, we should be able to hide the transfers behind compute.
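For reference, a minimal way to measure this (reusing loss1/loss2 from the example above; block_until_ready is needed so JAX's async dispatch doesn't hide the real cost):

import timeit
import jax

g1 = jax.jit(jax.grad(loss1))  # offloaded variant
g2 = jax.jit(jax.grad(loss2))  # baseline variant
g1(x, y).block_until_ready()  # warm up to exclude compile time
g2(x, y).block_until_ready()

t1 = timeit.timeit(lambda: g1(x, y).block_until_ready(), number=100) / 100
t2 = timeit.timeit(lambda: g2(x, y).block_until_ready(), number=100) / 100
print(f'offloaded: {t1*1e3:.3f} ms, baseline: {t2*1e3:.3f} ms')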

@GeophyAI
Author

GeophyAI commented Sep 14, 2024 via email

@yashk2810
Member

cc @jaro-sevcik @nouiz can you improve the speed of transfers to pinned_host and back here? I wonder how fast this is on the latest GPUs.

@yashk2810
Member

I ran your test internally on an H100 GPU and got these timings:

f= 18.0
f_vjp= 18.0
[ 72. 144.]
0.17064547538757324
[ 72. 144.]
0.17866754531860352

@GeophyAI
Author

GeophyAI commented Sep 14, 2024 via email

@yashk2810
Member

> the values I provided before are based on my original implementation of an inverse problem.

Yeah, I got that. Can you try running the test here to see the timing difference? That way we know whether this is a GPU generation problem or not. In other words, if you run your original code on an H100, what speed difference do you see?

@GeophyAI
Author

For jax.grad(loss1)(x,y), it's 3.02 ms ± 7.83 μs per loop (mean ± std. dev. of 7 runs, 100 loops each),
For jax.grad(loss2)(x, y), it's 4.39 ms ± 34 μs per loop (mean ± std. dev. of 7 runs, 100 loops each).

@GeophyAI
Author

@yashk2810, sorry, I do not have an H100; that's why I am doing it this way.

@yashk2810
Member

Yeah, I was just wondering how the GPU generation affected the speed of transfers :)

But I am glad that your problem is resolved. Feel free to close this issue if so :)

@GeophyAI
Author

@yashk2810 Thank you very much. And thanks to everyone here.

This issue was closed.