Guard against unintentional transfers to CPU #23215

Open
kxs-dhenshall opened this issue Aug 23, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@kxs-dhenshall

In order to work around long compile times and hidden recompilations, a large block of JIT-compiled logic can be broken into a sequence of smaller functions.

One implication of this is that the control flow has to live on the host: using jax.lax control flow functions triggers compilation of the condition/body functions, which effectively amounts to compiling the original function again. For example, if you were to break a function f into two functions fn1 and fn2, your code might look something like this, with the loop running on the host:

```python
state = compute_initial_state()

for i in range(n):
    state = fn1(state)
    state = fn2(state)

print(state.result)
```
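To make the contrast concrete, here is a minimal runnable sketch with hypothetical stand-ins for fn1 and fn2 (the real halves would be far larger). `run_on_host` is the host-side loop from above, which keeps fn1 and fn2 as two separately compiled programs; `run_with_lax` moves the loop into `jax.lax.fori_loop` under `jax.jit`, which traces both halves into a single program, i.e. the large compilation that splitting f was meant to avoid.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins for the two halves of the original function f.
@jax.jit
def fn1(state):
    return state * 2.0


@jax.jit
def fn2(state):
    return state + 1.0


def run_on_host(state, n):
    # Host-side loop: fn1 and fn2 stay two separately compiled programs,
    # and each iteration just dispatches them back to back.
    for _ in range(n):
        state = fn1(state)
        state = fn2(state)
    return state


@jax.jit
def run_with_lax(state, n):
    # Device-side loop: jit traces the whole body, so fn1 and fn2 are
    # compiled together into one program.
    return jax.lax.fori_loop(0, n, lambda i, s: fn2(fn1(s)), state)


x = jnp.ones((4,), dtype=jnp.float32)
print(run_on_host(x, 3))
print(run_with_lax(x, 3))
```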

Making this work efficiently requires a few extra considerations:

  1. Avoid GPU-to-CPU transfers between calls to the JAX-compiled functions
  2. Avoid re-allocations as much as possible through aggressive donation of arguments
  3. Avoid hidden recompilation
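A minimal sketch of how the three considerations might map onto JAX APIs, assuming a dict-of-arrays state and hypothetical fn1/fn2/compute_initial_state: `donate_argnums=0` asks XLA to reuse the input buffers for the outputs (2), `.lower(...).compile()` fixes the argument types ahead of time (3), and the loop body never reads `state` on the host (1). Whether donation actually takes effect depends on the backend; JAX typically emits a warning when a donated buffer cannot be used.

```python
import jax
import jax.numpy as jnp


# Hypothetical halves of f operating on a small dict-of-arrays state.
def fn1(state):
    return {"x": state["x"] * 2.0, "step": state["step"] + 1}


def fn2(state):
    return {"x": state["x"] + 1.0, "step": state["step"]}


def compute_initial_state():
    return {"x": jnp.ones((4,), jnp.float32), "step": jnp.zeros((), jnp.int32)}


state = compute_initial_state()

# (2) donate_argnums=0 lets XLA reuse the input buffers for the outputs.
# (3) lower(...).compile() pins the argument shapes/dtypes ahead of time;
#     calling with different types raises instead of recompiling.
compiled_fn1 = jax.jit(fn1, donate_argnums=0).lower(state).compile()
compiled_fn2 = jax.jit(fn2, donate_argnums=0).lower(state).compile()

for _ in range(3):
    # (1) Nothing in the loop reads `state` on the host. After each call the
    # donated input may be deleted, so only the returned state is used.
    state = compiled_fn1(state)
    state = compiled_fn2(state)

print(state["x"], state["step"])
```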

This works, but it is fragile, and it can be difficult to protect against regressions without burdensome performance testing. One thing that works well is that (3) is easy to handle with ahead-of-time (AOT) compilation: calling an AOT-compiled function with arguments whose types differ from the ones it was compiled for throws an exception, so a regression that mistakenly changes the type of a return value is caught immediately (e.g. an array in state is initialized as [1] with dtype int32 when it should have been initialized as [1.] with dtype float32).
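A small sketch of that AOT check, assuming a hypothetical fn2 compiled for float32[1]: the compiled executable accepts only arguments with the exact shape and dtype it was lowered for, so the int32 regression from the example raises a TypeError instead of silently triggering a recompile. `accepts` is an illustrative helper, not a JAX API.

```python
import jax
import jax.numpy as jnp


def fn2(state):
    return state + 1.0


# Compile ahead of time for the intended argument type: float32[1].
compiled_fn2 = jax.jit(fn2).lower(jnp.zeros((1,), jnp.float32)).compile()


def accepts(compiled, arg):
    """Return True if the AOT-compiled function runs on `arg`, False if it
    rejects the argument's type."""
    try:
        compiled(arg)
        return True
    except TypeError:
        return False


print(accepts(compiled_fn2, jnp.array([1.0], dtype=jnp.float32)))
# The regression from the example: state initialized as [1] (int32).
print(accepts(compiled_fn2, jnp.array([1])))
```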

For use cases like this, I think it would be useful to have mechanisms that protect against the following two cases:

  1. A GPU-to-CPU transfer mistakenly takes place within the critical loop because host-side logic accessed a GPU array.
  2. An argument that should have been donated was not donated (e.g. because a function signature changed without a matching update to donate_argnames).

One approach would be to introduce new context managers that prevent these from happening inside a critical section of code. For example, for (1) it could look something like the following (please ignore the names, I did not put much thought into them):

```python
state = compute_initial_state()
state = jax.device_put(state)

# Proposed (hypothetical) context manager, not an existing JAX API
with jax.disable_transfer_to_cpu():
    # Within this block, any GPU-to-CPU transfer throws an exception
    compiled_fn1 = compile(fn1, state)
    compiled_fn2 = compile(fn2, state)

    # This would throw an exception, since printing state.result would trigger
    # a transfer from GPU to CPU
    # print(state.result)

# This would not trigger an exception, since it is outside the protected block above
print(state.result)
```

Is this something that has been thought about or considered? I am not sure whether other people have run into this issue, but I figured it was worth bringing up.

@kxs-dhenshall kxs-dhenshall added the enhancement New feature or request label Aug 23, 2024
@yashk2810
Member

For (1), you can use the transfer guard: https://jax.readthedocs.io/en/latest/transfer_guard.html
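A minimal sketch of the critical loop under the transfer guard, with hypothetical fn1/fn2. `jax.transfer_guard_device_to_host("disallow")` rejects implicit device-to-host transfers (printing an array, `float(x)`, `np.asarray(x)`) while still allowing explicit ones such as `jax.device_get`; the linked guide lists the other levels (`"log"`, `"disallow_explicit"`, ...). The guard matters most on accelerator backends, where an accidental transfer forces a synchronization.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fn1(state):
    return state * 2.0


@jax.jit
def fn2(state):
    return state + 1.0


def critical_loop(state, n):
    # Per the transfer guard docs, implicit device-to-host transfers inside
    # this block (print(state), float(state[0]), np.asarray(state), ...) are
    # rejected with an error; explicit jax.device_get calls are still allowed.
    with jax.transfer_guard_device_to_host("disallow"):
        for _ in range(n):
            state = fn1(state)
            state = fn2(state)
            # print(state)  # implicit transfer: the guard would reject it
    return state


state = jax.device_put(jnp.ones((4,), jnp.float32))
state = critical_loop(state, 3)
print(state)  # outside the guarded block, so this transfer is allowed
```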

For (2), you could add an arr.is_deleted() check to make sure donation was successful. But note that in the future, we will delete the input array regardless of whether donation was successful or not (which will help you?)
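A minimal sketch of that check, assuming a hypothetical `step` function: `call_and_find_undonated` (an illustrative helper, not part of JAX) calls the function and reports which array arguments are still alive afterwards. Forgetting `donate_argnums` leaves the argument alive, so its index shows up in the list; when donation succeeds the list should be empty, although per the note above this signal may stop being meaningful if inputs are always deleted in the future.

```python
import jax
import jax.numpy as jnp


def step(x):
    return x * 2.0


step_donating = jax.jit(step, donate_argnums=0)
# Hypothetical regression: donation was dropped, e.g. after a signature change.
step_plain = jax.jit(step)


def call_and_find_undonated(fn, *args):
    """Call fn(*args) and return (output, indices of array args still alive).

    When donation succeeds the argument's buffer is consumed and
    is_deleted() returns True, so any index returned here points at an
    argument that was not donated.
    """
    out = fn(*args)
    alive = [i for i, a in enumerate(args)
             if isinstance(a, jax.Array) and not a.is_deleted()]
    return out, alive


_, alive_plain = call_and_find_undonated(step_plain, jnp.ones((3,)))
print("not donated (plain):", alive_plain)

_, alive_donating = call_and_find_undonated(step_donating, jnp.ones((3,)))
print("not donated (donating):", alive_donating)
```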

@kxs-dhenshall
Author

Thanks for the quick response! For (1), transfer_guard looks perfect.

For (2), I may do that, but I don't really want to keep a bunch of checks around. I already have my own wrapper around the compiled function, so I can add the is_deleted check there automatically, and I may just do that. In general, though, that isn't quite the tool I was hoping for. I am going to dig around to see what allocation stats are available; if there are good ones, I can use them to make sure that the memory allocated at the end of each iteration of the loop is not more than the memory allocated at the start.
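One way to sketch the allocation-stats idea, swapping in `jax.live_arrays()` as a backend-agnostic proxy (on GPU, `Device.memory_stats()` exposes allocator counters such as `bytes_in_use`, but it is not implemented on every backend): record the total bytes of live arrays after each iteration and flag any steady growth. The helpers and the leaky step below are hypothetical. Note that this catches buffers that stay alive across iterations; a missed donation that only raises the peak within an iteration would not show up here, because the old state is freed once it is rebound.

```python
import jax
import jax.numpy as jnp


def live_bytes():
    """Total bytes held by jax.Arrays still alive on the default backend."""
    return sum(a.nbytes for a in jax.live_arrays())


def live_bytes_per_iteration(step, state, n):
    """Run `state = step(state)` n times, recording live bytes after each one."""
    history = []
    for _ in range(n):
        state = step(state)
        history.append(live_bytes())
    return state, history


def grows(history):
    """True if live bytes increase between iterations, ignoring the first
    (warm-up) iteration, which may include one-off allocations."""
    steady = history[1:]
    return any(later > earlier for earlier, later in zip(steady, steady[1:]))


@jax.jit
def step(x):
    return x * 2.0


kept = []  # hypothetical bug: something holds on to every intermediate state


def leaky_step(x):
    y = step(x)
    kept.append(y)
    return y


_, ok_history = live_bytes_per_iteration(step, jnp.ones((1024,), jnp.float32), 5)
_, leaky_history = live_bytes_per_iteration(
    leaky_step, jnp.ones((1024,), jnp.float32), 5)
print(grows(ok_history), grows(leaky_history))
```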

@kxs-dhenshall
Author

For future context, I worked around the issue by following the suggestion and asserting arr.is_deleted() after every call to the JIT-compiled function whenever a do-extra-validation flag was enabled.

This approach took a while, but in the end it gave me confidence that the buffers are being properly donated. If a future change causes arrays to always be flagged as deleted even when the buffer is not donated, these checks will no longer work.

A word of warning to anyone going down this path: the buffer donation logic is a lot less obvious when pytrees are involved, and it took a while to get all my buffers properly donated. The donation warning message does not always show up when pytrees are involved.
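A sketch of such a validation wrapper, assuming a pytree state and a hypothetical `DO_EXTRA_VALIDATION` flag: it flattens the donated pytree with `jax.tree_util.tree_flatten_with_path` before the call and afterwards reports the path of every leaf that is still alive, which helps pinpoint the pytree leaves that silently escaped donation. `call_donating`, `undonated_paths`, and `update` are illustrative names, not JAX APIs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import keystr, tree_flatten_with_path

# Hypothetical validation switch, e.g. turned on via an env var in debug runs.
DO_EXTRA_VALIDATION = True


def undonated_paths(leaves_with_paths):
    """Return the paths of array leaves that are still alive (not donated)."""
    return [keystr(path) for path, leaf in leaves_with_paths
            if isinstance(leaf, jax.Array) and not leaf.is_deleted()]


def call_donating(fn, state):
    """Call fn(state); when validation is enabled, raise if any array leaf of
    the donated `state` pytree survived the call."""
    if not DO_EXTRA_VALIDATION:
        return fn(state)
    leaves, _ = tree_flatten_with_path(state)
    out = fn(state)
    missing = undonated_paths(leaves)
    if missing:
        raise RuntimeError(f"state leaves were not donated: {missing}")
    return out


def update(state):
    return {"x": state["x"] * 2.0, "step": state["step"] + 1}


# In the host loop: state = call_donating(update_jit, state)
update_jit = jax.jit(update, donate_argnums=0)
```

With the flag off, `call_donating` is just a plain call, so the checks cost nothing in production runs.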

I am closing this issue since it is effectively resolved.

@kxs-dhenshall
Author

Closing per previous comment.

Projects
None yet
Development

No branches or pull requests

2 participants