host_callback.call fails on multi-gpu machine (#5577)

Comments
I think that this is not specific to multi-GPU, but can happen even with one GPU (randomly). I think it is related to #4374. There are two fixes possible: fix the implementation of outfeed for XLA:GPU, or replace the implementation mechanism for GPUs to use CustomCall (this is in progress).
Is this any closer to being fixed (or, any ideas for a workaround)?
There are two updates. It turns out that the infeed/outfeed in XLA:GPU is not easy to fix for multi-GPU, so that hope has dimmed. The second update is more positive: we have a new implementation in the works for GPU, using XLA CustomCall. This means that the host callback will be synchronous. This implementation was blocked on GPU by another XLA bug that has since been fixed. The plan is to enable this second implementation mechanism, selectable with an environment variable and a command-line flag. The change involves both Python and C++ and will take at least a couple of weeks to land. Sorry for the delay!
Sorry for the bump, but what's the current status of the second update?
The custom call on GPU has now landed but is not used in host callback quite yet. You can try out the new callback mechanism on GPU with …
@C-J-Cundy @AllanChain can you say more about your intended use case? For example, is it to have a callback for a debugging side effect (like printing), or to perform some functionally pure numerical computation (on the host), or something else? I ask because if it's one of those two applications, we can recommend a replacement API (without having to wait for the HCB API to be ported to the new implementation).
What would be the replacement API in that case? |
It is …
It was the former, but I have figured out my problem and am not waiting for this anymore.
For reference, the "callback for a debugging side-effect (like printing)" is …
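The specific replacement API names are truncated in the thread above, so the following is an illustration only, not necessarily the maintainers' exact recommendation: for the "functionally pure numerical computation on the host" use case, current JAX provides `jax.pure_callback`, which calls a host-side Python function from inside a jitted computation. A minimal sketch:

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_fn(x):
    # Runs on the host and receives a NumPy array.
    return np.sin(x)

@jax.jit
def f(x):
    # The second argument declares the shape/dtype the host function returns,
    # so JAX can trace through the call without executing it.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4.0)))
```

Unlike `host_callback.call`, `pure_callback` assumes the host function is pure (no side effects), which lets JAX freely transform, batch, or elide the call.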
Not sure if this is still an issue. |
If I run the following code:

on a 2-GPU machine, then it crashes with the error message

If I run with one GPU (by setting CUDA_VISIBLE_DEVICES=0) it finishes with no errors. Is there something I've missed in the documentation for host_callback about how it should be used on multi-device setups?

I ran both with the full debug information:
CUDA_VISIBLE_DEVICES=0 TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,xfeed_manager=3 python test_2.py --verbosity=2 2> test_output_one_gpu.txt
if that's helpful.

test_output_one_gpu.txt
test_output_two_gpu.txt
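The original code snippet and error message were not preserved in this copy of the report. As a hedged sketch only, the kind of program the report describes — a host-side function invoked from a jitted computation via `jax.experimental.host_callback.call` — looks roughly like this; the host function `host_double` is a hypothetical stand-in, not the reporter's actual code:

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def host_double(x):
    # Runs on the host; receives a NumPy array from the device.
    return 2 * x

@jax.jit
def f(x):
    # result_shape declares the shape/dtype the host function returns.
    return hcb.call(host_double, x, result_shape=x)

print(f(jnp.arange(3.0)))
```

On a single device this completes normally; the crash reported above occurs when the same pattern runs on a machine with multiple GPUs, because the outfeed-based implementation of host_callback was unreliable on XLA:GPU.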