io_callback does not work with custom_vjp #23614
Comments
If this is what you're hoping to do (use Does that make sense? |
In the pasted code, |
Did you find this tutorial? There are several
def f(x):
    return x*2+1
def fwd(x, cstep):
    y = f(x)
    io_callback(save2file, x, x, cstep)
    return y, None
def bwd(res, g):
    x, cstep = res
    io_callback(loadfromfile, cstep)
    return g, None
The first input of |
Hi, thank you for your reply. I had read the doc you mentioned before; what I mean is that I cannot find information about how to combine the usage of |
Thank you for your reply. Following your comments, I modified my code, but it still cannot write to disk as expected. The following is my new implementation. Is there any misunderstanding on my side?
|
Hi - I think I may have been unclear in my suggestion. The problem previously was that you were calling In your update, you created custom autodiff rules for This may be an XY problem: perhaps we should step back, and you can describe what larger problem you're trying to solve here? Also, since you mentioned docs: I should mention that although we don't have documentation of |
Hi, I misunderstood your suggestions at that time. So I need to wrap my callback functions with custom_vjp, i.e. both Actually, my functions |
OOC, why are you saving the residuals to disk? What problem are you trying to solve? Maybe we can suggest something else? |
I am working on full waveform inversion, an inverse problem based on the wave equation. I need to run a time-stepping forward scheme to calculate the final results and compare them with the target. I don’t want to return the inputs of fwd as residuals for bwd; I want to save them to disk and reload them. I think it’s something like checkpoint, but with a save_to_disk version.
|
Maybe offloading to host is what you want here. You can achieve that via a different API, which is more efficient than using io_callback. Take a look at this: Lines 1447 to 1488 in 0daca46
remat .
You can also run manual |
Thank you, @yashk2810. Yes, transferring to host is another choice. Does the following code make sense? How can I know the device of a
|
No, like this:
this is the output I got:
Note that I ran this on TPU. Which hardware are you planning to run on? |
@yashk2810, I'm coding on a GPU. Thank you very much. I have implemented this in my original code, and it does reduce memory use, at the cost of slower execution (I have to do that because of the OOM problem). I'm still new to JAX, and I will try to find a more efficient strategy that can solve my problem. |
Yeah, this should work on a GPU too. The efficiency might be affected because you incur transfer costs. Can you tell me how much of a slowdown you see? If you have enough compute, we should be able to hide the transfers behind compute. |
It's 3.7 s/it vs. 0.7 s/it on an RTX 3090, around 5x slower. I think that's acceptable for me, since without this the code cannot even run.
Sincerely,
Shaowen
|
cc @jaro-sevcik @nouiz can you improve the speed of transfers to pinned_host and back here? I wonder how fast this is on the latest GPUs. |
I ran your test internally on H100 GPU but got these timings:
f= 18.0
f_vjp= 18.0
[ 72. 144.]
0.17064547538757324
[ 72. 144.]
0.17866754531860352
|
Oh sorry, I thought you meant my original code; the values I provided before are based on my original implementation of the inverse problem.
Sincerely,
Shaowen
|
Yeah I got that. Can you try running the test here to see the timing difference? So we know if this is a GPU generation problem or not. In other words, if you run your original code on H100, what speed difference do you see? |
For |
@yashk2810, sorry, I do not have an H100, that's why I am doing so here. |
Yeah, I was just wondering how the GPU generation affected the speed of transfers :) But I am glad that your problem is resolved. Feel free to close this issue if so :) |
@yashk2810 Thank you very much. And thanks for everyone here. |
I want to save some data to disk in the forward pass and reload it in the backward pass, but I found that the docs only provide a jvp example. The question is: can we use io_callback with custom_vjp, and if so, how? The following is my implementation, but it does not work. Could anyone help me with this?