
Hot wiring xla'd kernels #4866

Open
proteneer opened this issue Nov 10, 2020 · 4 comments
Labels
question Questions for the JAX team

Comments

@proteneer
Contributor

This might be another generally crazy but par-for-the-course-for-me issue: does anyone know how hard it would be to

  1. Define a computation in jax
  2. Call jit to generate a compiled kernel
  3. Grab a handle to said kernel somehow and call it directly from C++ to avoid Python overhead?

The use case is, of course, plugging it into a hot loop somewhere; CPU-only is fine for now. I'm okay with going through any kind of necessary C++ hell to include/link against the XLA source, etc.

It's totally okay to just tell me that this is not doable any time soon.

@proteneer proteneer changed the title Hot wiring xla'd GPU kernels Hot wiring xla'd kernels Nov 10, 2020
@mattjj
Collaborator

mattjj commented Nov 11, 2020

I think this is doable, and some folks at DeepMind may already do this. It might proceed slightly differently though, like

  1. define a computation in jax
  2. get the XLA HLO program as a protobuf
  3. hand the protobuf to C++, compile and call it using the XLA C++ APIs.

To get the protobuf, one can use jax.xla_computation. This file shows one way to do it (in that case meant to be run at bazel build time, so that a C++ library can bazel-depend on the generated proto). I'm not sure if we have any examples of the C++ side, but it must be using the API in these XLA files.
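A minimal sketch of the Python side of steps 1 and 2 might look like the following. Note that `f` here is just an illustrative function, and that `jax.xla_computation` (the API mentioned above) has since been deprecated in newer JAX releases, where the same serialized proto is reachable through `jax.jit(...).lower(...)`; the fallback branch below is an assumption about which path is available in your JAX version.

```python
import jax
import jax.numpy as jnp

# Step 1: define a computation in jax (illustrative example).
def f(x):
    return jnp.sum(jnp.tanh(x)) ** 2

x = jnp.ones((4,))

# Step 2: get the XLA HLO program. jax.xla_computation was the API at
# the time of this issue; newer JAX versions removed it in favor of the
# jit(...).lower(...) path, so try one and fall back to the other.
try:
    computation = jax.xla_computation(f)(x)
except AttributeError:
    computation = jax.jit(f).lower(x).compiler_ir(dialect="hlo")

# Serialized HloModuleProto bytes, ready to hand to the XLA C++ client
# (step 3), e.g. by writing them to a file a C++ program can read.
proto_bytes = computation.as_serialized_hlo_module_proto()
with open("f.hlo.pb", "wb") as out:
    out.write(proto_bytes)
```

The resulting `f.hlo.pb` is what the C++ side would deserialize and compile with the XLA client APIs.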

cc @jlebar @tomhennigan @hawkinsp in case they can provide more specific tips

@jlebar
Contributor

jlebar commented Nov 11, 2020

Yeah! What @mattjj said.

Write your cold code in C++ and your inner loops in Python. It is glorious.

@skye
Member

skye commented Nov 11, 2020

Just a little more detail on the XLA C++ client side... JAX itself (indirectly) uses the PJRT client (main header file here). It's built on top of the LocalClient, which you can also use directly, but I think PJRT will be handier because it includes extra runtime APIs for e.g. copying host data to devices and back. Feel free to reach out if you want help with the specifics.

@proteneer
Contributor Author

This is awesome - I will take a first pass at this over the break :)

@mattjj mattjj added the question Questions for the JAX team label Nov 24, 2020