Port all custom calls to FFI #23056

Open
5 of 20 tasks
dfm opened this issue Aug 14, 2024 · 0 comments
Labels: enhancement (New feature or request)

dfm commented Aug 14, 2024

This is an issue for tracking our progress towards porting all "custom calls" within JAX core to the new XLA FFI.

Motivation

This project is worthwhile for a few reasons, including:

  1. It gets us closer to being ready to deprecate the legacy custom call API and the related JAX internals.
  2. When ported to the FFI, custom calls support export, shape polymorphism, and ahead-of-time lowering without access to the relevant device.
  3. The FFI supports useful features like device-specific scratch space allocators and (soon) access to the XLA thread pool for CPU parallelization.
  4. The lowering logic can be significantly simplified, and in most cases the lowering rule is now the same for GPU and CPU.
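Point 4 can be illustrated with a minimal sketch, assuming a recent JAX release. It shows a single FFI lowering rule registered unchanged for both CPU and CUDA. The primitive `my_kernel_p` and the target name `my_kernel_ffi` are hypothetical stand-ins for a jaxlib kernel. The example only lowers to StableHLO and does not compile anything, so no C++ handler is needed. A real port would also register the handler (for example with `jax.ffi.register_ffi_target`) before the computation could be compiled and run.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

# Newer JAX releases expose these under jax.ffi / jax.extend.core; older
# releases used jax.extend.ffi / jax.core. Fall back so the sketch runs on both.
try:
    from jax.ffi import ffi_lowering
except ImportError:
    from jax.extend.ffi import ffi_lowering
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical primitive standing in for a jaxlib kernel; "my_kernel_ffi" is
# an illustrative target name, not a real registered handler.
my_kernel_p = Primitive("my_kernel")

# Shape/dtype rule: the output has the same abstract value as the input.
my_kernel_p.def_abstract_eval(lambda x: x)

# One FFI lowering rule, registered unchanged for both CPU and CUDA.
rule = ffi_lowering("my_kernel_ffi")
for platform in ("cpu", "cuda"):
    mlir.register_lowering(my_kernel_p, rule, platform=platform)


def my_kernel(x):
    return my_kernel_p.bind(x)


# Lowering only emits StableHLO (a custom_call to the target name); nothing
# is compiled or executed, so the target need not be registered with XLA yet.
lowered_text = jax.jit(my_kernel).lower(jnp.ones((4,), jnp.float32)).as_text()
print(lowered_text)
```

With the legacy API, each platform typically needed its own hand-written lowering that packed operands and descriptors by hand. Here the per-platform difference reduces to which handler is registered under the target name.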

Progress

CPU

The main CPU kernels are all in jaxlib/cpu, and @pparuzel has been leading the work to port these kernels:

GPU

The GPU kernels all live in jaxlib/gpu:

BLAS:

Merged with Solver (below) so that batching decisions can be made inside the kernel (details).

linalg:

PRNG:

RNN:

  • RNNForward
  • RNNBackward

Solver:

Sparse:

...

Triton:

...
