
host_callback.call fails on multi-gpu machine #5577

Closed
C-J-Cundy opened this issue Feb 1, 2021 · 11 comments
Labels: bug (Something isn't working), needs info (More information is required to diagnose & prioritize the issue), NVIDIA GPU (Issues specific to NVIDIA GPUs), P1 (soon) (Assignee is working on this now, among other tasks)

Comments

@C-J-Cundy

If I run the following code:

from jax.experimental import host_callback
import numpy as np
from jax import pmap, jit, partial, ShapeDtypeStruct


def host_fn(x):
    return x


x = np.ones(4, dtype=np.float32)
host_callback.call(host_fn, x, result_shape=x)

on a 2-GPU machine, then it crashes with the error message:

2021-01-31 17:53:40.778121: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:56] Check failed: ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()) XLA program outfeed request of shape (f32[4]) did not match the runtime's outfeed buffer of shape u32[2]

If I run with one GPU (by setting CUDA_VISIBLE_DEVICES=0), it finishes with no errors. Is there something I've missed in the documentation for host_callback about how it should be used on multi-device setups?

I ran both with full debug information, e.g.
CUDA_VISIBLE_DEVICES=0 TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,xfeed_manager=3 python test_2.py --verbosity=2 2> test_output_one_gpu.txt
in case that's helpful. The outputs are attached:
test_output_one_gpu.txt
test_output_two_gpu.txt

@froystig froystig added the bug Something isn't working label Feb 1, 2021
@gnecula
Collaborator

gnecula commented Feb 2, 2021

I think this is not specific to multi-GPU; it can happen (randomly) even with one GPU. I believe it is related to #4374.

There are two possible fixes: fix the implementation of outfeed for XLA:GPU, or replace the GPU implementation mechanism with one based on CustomCall (this is in progress).

@C-J-Cundy
Author

Is this any closer to being fixed (or are there any ideas for a workaround)?
host_callback is a really great addition to JAX. It's a bit frustrating that I currently can't use it with multiple GPUs.

@gnecula
Collaborator

gnecula commented Mar 18, 2021

There are two updates. It turns out that the infeed/outfeed in XLA:GPU is not so easy to fix for multi-GPU. So that hope has gotten dimmer.

The second update is more positive: we have a new implementation in the works for GPU, using XLA CustomCall. This means that host_callback will be synchronous. This implementation was blocked on GPU by another XLA bug, which has since been fixed. So the plan is to enable this second implementation mechanism, selectable via an environment variable and a command-line flag. The change involves both Python and C++ and will take at least a couple of weeks to land. Sorry for the delay!

@AllanChain

Sorry for the bump, but what's the current status of the second update?

@sudhakarsingh27 sudhakarsingh27 added NVIDIA GPU Issues specific to NVIDIA GPUs P0 (urgent) An issue of the highest priority. We are addressing this urgently. (Assignee required) labels Aug 10, 2022
@sharadmv
Collaborator

Sorry for the bump, but what's the current status of the second update?

The custom call on GPU has now landed but is not used in host_callback quite yet. You can try out the new callback mechanism on GPU with jax.debug.print, and we should be porting HCB to use the new custom call very soon.
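A minimal sketch of what that looks like (illustrative, not from this thread; the function and values are just placeholders) — jax.debug.print works inside jit-compiled functions:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Prints from inside a compiled function via the new callback mechanism.
    jax.debug.print("x = {x}", x=x)
    return x * 2

f(jnp.ones(4, dtype=jnp.float32))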

@sudhakarsingh27 sudhakarsingh27 added P1 (soon) Assignee is working on this now, among other tasks. (Assignee required) and removed P0 (urgent) An issue of the highest priority. We are addressing this urgently. (Assignee required) labels Aug 15, 2022
@mattjj
Collaborator

mattjj commented Aug 24, 2022

@C-J-Cundy @AllanChain can you say more about your intended use case? For example, is it to have a callback for a debugging side-effect (like printing), or to perform some functionally pure numerical computation (on the host?), or something else? I ask because if it's one of those two applications we can recommend a replacement API (without having to wait for porting the HCB API to use a new implementation).

@mattjj mattjj added the needs info More information is required to diagnose & prioritize the issue. label Aug 24, 2022
@PhilipVinc
Contributor

functionally pure numerical computation (on the host?),

What would be the replacement API in that case?

@sharadmv
Collaborator

It is jax.pure_callback.
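For reference, a minimal sketch of that replacement applied to a reproducer like the one above (illustrative, not from this thread; host_fn and f are just placeholder names):

import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    # Runs on the host with NumPy inputs; must return arrays matching result_shape_dtypes.
    return np.asarray(x)

@jax.jit
def f(x):
    # The second argument describes the callback's output shape and dtype.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

f(jnp.ones(4, dtype=jnp.float32))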

@AllanChain

is it to have a callback for a debugging side-effect (like printing), or to perform some functionally pure numerical computation (on the host?)

It was the former. But I have since figured out my problem and am not waiting for this anymore.

@sharadmv
Collaborator

For reference, the replacement for a "callback for a debugging side-effect (like printing)" is jax.debug.callback.
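A minimal sketch of that (illustrative, not from this thread; log_host and f are placeholder names):

import jax
import jax.numpy as jnp

def log_host(x):
    # Side-effecting host function; its return value is ignored.
    print("value on host:", x)

@jax.jit
def f(x):
    jax.debug.callback(log_host, x)
    return x + 1

f(jnp.arange(4.0))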

@gnecula
Collaborator

gnecula commented Sep 13, 2024

Not sure if this is still an issue.

@gnecula gnecula closed this as completed Sep 13, 2024