Guard against unintentional transfers to CPU #23215
Comments
For 1) you can use https://jax.readthedocs.io/en/latest/transfer_guard.html For 2) You can add
Thanks for the quick response. For (1), transfer_guard looks perfect. For (2), I may do that, but I don't really want to keep a bunch of checks around. I have my own wrapper around the compiled function anyway and can add the is_deleted check automatically, so I may just do that. In general, though, that isn't the tool I was hoping for. I am going to dig around to see what kind of allocation stats I can find; if good ones are available, I can use them to make sure that the memory allocated at the end of each loop iteration is not more than the memory allocated at the start.
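For readers following along, the transfer guard suggested for (1) can wrap the host-side loop as a context manager. This is only a minimal sketch: the `step` function and the state shape are assumptions made for illustration, not code from this thread.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(state):
    # Hypothetical loop body standing in for the real computation.
    return state + 1.0


state = jnp.zeros(4, dtype=jnp.float32)
state = step(state)  # compile once, outside the guarded section

# Inside the guard, implicit host<->device transfers are reported as
# errors instead of happening silently.
with jax.transfer_guard("disallow"):
    for _ in range(3):
        state = step(state)

# Explicit read-back, done after leaving the guarded section.
final_value = float(state[0])
```

Only device-resident arrays flow through the guarded loop here, so nothing should trip the guard. The linked documentation describes the other guard levels (for example "log") and the direction-specific variants.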
For future context, I worked around the issue by using the suggestion and asserting that This approach took a while, but in the end it gave me confidence that the buffers are being properly donated. If a future change causes arrays to always be flagged as deleted even when the buffer is not donated, these checks will no longer work.
A word of warning to anyone going down this path: the buffer donation logic is a lot less obvious when pytrees are involved, and it took a while to get all my buffers properly donated. The warning message does not always work when pytrees are involved.
I am closing this issue since it is effectively resolved.
Closing per previous comment.
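A check along these lines might look like the sketch below. The helper names `undonated_leaves` and `assert_donated` are hypothetical and are not part of JAX; the sketch relies only on `jax.Array.is_deleted()` and `jax.tree_util.tree_leaves`, and it walks the whole pytree because donation is decided per leaf.

```python
import jax
import jax.numpy as jnp


def undonated_leaves(tree):
    """Return the array leaves of `tree` whose buffers are still alive."""
    return [
        leaf
        for leaf in jax.tree_util.tree_leaves(tree)
        if isinstance(leaf, jax.Array) and not leaf.is_deleted()
    ]


def assert_donated(tree):
    """Fail if any array leaf of `tree` was not consumed by donation."""
    alive = undonated_leaves(tree)
    assert not alive, f"{len(alive)} leaf buffer(s) were not donated"
```

A wrapper around a function jitted with `donate_argnums` could call `assert_donated(old_state)` right after each step. Whether the check passes depends on the backend actually honouring donation, and, as noted above, it stops being meaningful if arrays are ever flagged as deleted without being donated.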
In order to work around long compile times and hidden re-compilations, a large body of JIT-compiled logic can be broken into a sequence of smaller functions.
One implication of this is that the control flow logic has to live on the host, since using jax.lax control flow functions triggers compilation of the target condition/body functions, which effectively just compiles the original function. For example, if you were to break a function f into two functions f1 and f2, your code may look something like this with the loop running on the host:
Making this work efficiently requires a few extra considerations:
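The host-side loop over f1 and f2 described above can be sketched roughly as follows. The bodies of `f1` and `f2`, the `run` helper, and the state shape are assumptions chosen only to make the example runnable.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f1(state):
    # Hypothetical first half of the original f.
    return state * 2.0


@jax.jit
def f2(state):
    # Hypothetical second half of the original f.
    return state + 1.0


def run(state, num_steps):
    # The loop is plain Python running on the host, so f1 and f2 are
    # compiled separately instead of being traced into one large program.
    for _ in range(num_steps):
        state = f1(state)
        state = f2(state)
    return state


final_state = run(jnp.ones(4, dtype=jnp.float32), num_steps=3)
```

Each iteration hands the output of one compiled function straight to the next, so the state can stay on the device as long as nothing in the loop reads it back on the host.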
This works, but it is fragile, and it can be difficult to protect against regressions without burdensome performance testing. One thing that works well is that (3.) can be dealt with easily through AOT compilation, which will throw an exception whenever a regression is introduced that mistakenly changes the type of a return value (e.g. an array in state is initialized to [1] with a data type of int32 when it should have been initialized as [1.] with a data type of float32).
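As a small illustration of that AOT behaviour, the sketch below compiles a step function ahead of time for a float32 state and then calls it with an int32 state. The `step` body and the shape are assumptions; the dtype mix-up mirrors the [1] versus [1.] example.

```python
import jax
import jax.numpy as jnp


def step(state):
    # Hypothetical loop body; the real one is whatever the split functions compute.
    return state * 0.5


# Compile ahead of time for a float32 vector of length 1.
spec = jax.ShapeDtypeStruct((1,), jnp.float32)
compiled_step = jax.jit(step).lower(spec).compile()

good_state = jnp.array([1.0], dtype=jnp.float32)
bad_state = jnp.array([1], dtype=jnp.int32)  # [1] instead of [1.]

result = compiled_step(good_state)

try:
    compiled_step(bad_state)
    mismatch_detected = False
except TypeError:
    # The compiled executable rejects arguments whose dtype differs from
    # the one it was lowered for, rather than silently recompiling.
    mismatch_detected = True
```

With a plain `jax.jit` call the int32 input would instead trigger a second compilation, which is exactly the kind of hidden recompilation the split was meant to avoid.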
For use cases like this, I think that it would be useful to have mechanisms that would protect against the following two cases:
One approach would be to introduce new contexts that prevent these from happening inside a critical section of code. For example, with (2.) it could look something like the following (please ignore variable names, I did not put much thought into them):
Is this something that has been thought about or considered? I am not sure whether this is an issue that other people have faced, and figured it was worth bringing up.