Hot wiring xla'd kernels #4866
Comments
I think this is doable, and some folks at DeepMind may already do this. It might proceed slightly differently though, like: (1) get the compiled computation out of JAX as an HLO protobuf, then (2) compile and run that protobuf with XLA's C++ client APIs. To get the protobuf, one can use ...; cc @jlebar @tomhennigan @hawkinsp in case they can provide more specific tips.
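A minimal sketch of that protobuf-dumping step, assuming the jax.xla_computation tracing API and the XlaComputation accessors (as_hlo_text, as_serialized_hlo_module_proto) of that era; an illustration, not the exact snippet from the original comment:

```python
# Sketch: trace a function with JAX and dump its HLO module as a serialized proto.
# Assumes jax.xla_computation, which was the standard API at the time of this thread
# (newer JAX releases expose similar functionality via jax.jit(f).lower(...)).
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * 2.0

# Trace f at example arguments to get an XlaComputation object.
comp = jax.xla_computation(f)(jnp.ones((8,), jnp.float32))

# Human-readable HLO text, handy for sanity checking.
print(comp.as_hlo_text())

# Serialized HloModuleProto bytes that a C++ XLA client could deserialize,
# compile, and execute.
with open("f_hlo.pb", "wb") as fp:
    fp.write(comp.as_serialized_hlo_module_proto())
```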
Yeah! What @mattjj said. Write your cold code in C++ and your inner loops in Python. It is glorious.
Just a little more detail on the XLA C++ client side... JAX itself (indirectly) uses the PJRT client (main header file here). It's built on top of the LocalClient, which you can also use directly, but I think PJRT will be handier because it includes extra runtime APIs, e.g. for copying host data to devices and back. Feel free to reach out if you want help with the specifics.
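As a rough Python-level analogue of what the C++ client workflow involves (compile the HLO, copy host data to the device, execute, copy the result back), here is a sketch assuming the xla_bridge/xla_client bindings shipped with jaxlib around that time; exact method names and signatures have shifted across releases, so treat it as orientation only:

```python
# Sketch of the compile / host-to-device / execute / device-to-host cycle driven
# through the Python bindings JAX ships. backend.compile, buffer_from_pyval and
# execute assume the jaxlib xla_client API of that era; the C++ PJRT client
# exposes analogous operations.
import numpy as np
import jax
import jax.numpy as jnp
from jax.lib import xla_bridge

def f(x):
    return jnp.tanh(x) * 2.0

comp = jax.xla_computation(f)(jnp.ones((8,), jnp.float32))

backend = xla_bridge.get_backend("cpu")                    # client object
compiled = backend.compile(comp)                           # compile the traced HLO module
x = backend.buffer_from_pyval(np.ones((8,), np.float32))   # host -> device copy
outs = compiled.execute([x])                               # run the compiled executable
print(np.asarray(outs[0]))                                 # device -> host copy
```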
This is awesome - I will take a first pass at this over the break :)
This might be another generally crazy but par-for-the-course-for-me issue: does anyone know how hard it would be to pull a jitted (xla'd) kernel out of JAX and call it directly from C++?
The use case is, of course, hot-plugging it into a hot loop somewhere. CPU-only is fine for now, and I'm okay with going through any kind of necessary C++ hell to include/link against the XLA source, etc.
It's totally okay to just tell me that this is not doable any time soon.