
Improve support for using JAX's custom LAPACK calls externally #23172

Open
joeycarter opened this issue Aug 21, 2024 · 7 comments
Labels
enhancement New feature or request

Comments

@joeycarter

Hi there,

This request is a follow-up to the discussion here: #18065.

What we're trying to accomplish

Suppose we have our own JIT compiler, called qjit. We would like to be able to do the following, for example:

import jax
import jax.numpy as jnp
import qjit

@qjit
def matrix_sqrt(A):
    return jax.scipy.linalg.sqrtm(A)

X = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
matrix_sqrt(X)

and similarly for the other functions in jax.scipy.linalg. However, when we do so, we get undefined symbol errors, in this case:

OSError: /tmp/matrix_sqrtonhryf12/matrix_sqrt.so: undefined symbol: lapack_zgees

Current workaround

Our current workaround is the one suggested in #18065: manually compile and link in the custom JAX LAPACK modules under jaxlib/cpu/ so that the required symbols, such as lapack_zgees, are defined.

However, this is cumbersome and difficult to maintain (for example, if these modules change in a future JAX release).

What else we've tried

We noticed that the jaxlib package ships with the shared object file jaxlib/cpu/_lapack.so, which contains the symbols for the kernel functions that the custom JAX wrappers use. For instance, using the nm tool, we can find the kernel function corresponding to lapack_zgees, ComplexGees<std::complex<double>>::Kernel:

$ nm -C _lapack.so | grep "ComplexGees<std::complex<double> >::Kernel"
000000000000d840 t jax::ComplexGees<std::complex<double> >::Kernel(void*, void**, XlaCustomCallStatus_*)

or, in its mangled form:

$ nm  _lapack.so | grep "000000000000d840"                          
000000000000d840 t _ZN3jax11ComplexGeesISt7complexIdEE6KernelEPvPS4_P20XlaCustomCallStatus_

We tried loading this symbol using the dynamic linking loader as follows (simplified for brevity):

#include <dlfcn.h>
#include <string>

struct XlaCustomCallStatus_ {};

void* handle = dlopen(".../jaxlib/cpu/_lapack.so", RTLD_LAZY);

// Mangled name of jax::ComplexGees<std::complex<double>>::Kernel
std::string symbol = "_ZN3jax11ComplexGeesISt7complexIdEE6KernelEPvPS4_P20XlaCustomCallStatus_";
typedef void (*Kernel_t)(void* out_tuple, void** data, XlaCustomCallStatus_*);
Kernel_t kernel = (Kernel_t) dlsym(handle, symbol.c_str());  // returns nullptr

However, dlsym fails to find the symbol. We believe this is because the function is not exported: it appears with a lowercase t (a local symbol) in the nm output, whereas exported functions are conventionally marked with uppercase letters such as T.

So, we believe that we've hit a dead-end with this approach.

Possible solutions

Given that these symbols already ship with jaxlib in jaxlib/cpu/_lapack.so, would it be possible to export these functions in a future JAX release so that they can be loaded dynamically with dlopen and dlsym? If that is not possible, is there another approach to using these custom LAPACK calls that is more maintainable than building and linking them in ourselves?

@joeycarter joeycarter added the enhancement New feature or request label Aug 21, 2024
@dfm
Collaborator

dfm commented Aug 22, 2024

Thanks for this request! This is something we'd love to support. We don't have a specific timeline, but I wanted to just confirm here that this feature request is acknowledged.

@dfm
Collaborator

dfm commented Aug 22, 2024

Adding one clarification: jaxlib doesn't actually come with its own LAPACK library; it links to the one from scipy. The code used to populate our API with those symbols is here:

https://github.com/google/jax/blob/b56ed8eeddc5794f3981832a38b6bcc195eb20f8/jaxlib/cpu/lapack.cc#L40-L152

It's probably worth noting that as part of this conversation!
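
As a minimal illustration of that mechanism (purely for reference, not something jaxlib itself runs): scipy publishes its raw LAPACK routines as PyCapsules in scipy.linalg.cython_lapack.__pyx_capi__, and as far as I can tell the lapack.cc code linked above pulls its function pointers from those capsules. In Python you can inspect one of them like this:

import ctypes
from scipy.linalg import cython_lapack

# scipy exposes each LAPACK routine as a PyCapsule keyed by routine name.
capsule = cython_lapack.__pyx_capi__["zgees"]

# Unwrap the capsule to get the raw function pointer. This ctypes plumbing is
# only illustrative; jaxlib does the equivalent in C++ when registering its
# custom calls.
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetName.argtypes = [ctypes.py_object]
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

name = ctypes.pythonapi.PyCapsule_GetName(capsule)
ptr = ctypes.pythonapi.PyCapsule_GetPointer(capsule, name)
print(hex(ptr))  # address of scipy's zgees, the routine jaxlib ends up calling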

@joeycarter
Author

Hi @dfm,

Could you help clarify something? We were testing a few of the jax.scipy.linalg functions with the LAPACK wrappers mentioned above and noticed that in certain cases we get incorrect results when calling a JAX linear algebra function from within our qjit-compiled block. For example, with jax.scipy.linalg.lu:

import jax
import jax.numpy as jnp
import qjit

A = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])

P, L, U = jax.scipy.linalg.lu(A)

@qjit
def qjit_lu(X):
    return jax.scipy.linalg.lu(X)

P_qjit, L_qjit, U_qjit = qjit_lu(A)

assert jnp.allclose(P @ L @ U, A)  # Passes
assert jnp.allclose(P, P_qjit)     # Fails
assert jnp.allclose(L, L_qjit)     # Fails
assert jnp.allclose(U, U_qjit)     # Fails

We noticed, for example, that in this case some rows of the permutation matrix have been reordered:

>>> print(P)
[[0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]]
>>> print(P_qjit)
[[0. 1. 0.]
 [1. 0. 0.]
 [0. 0. 1.]]

To get to this point we've used the jax::Getrf::Kernel function:

template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;

template <typename T>
void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int m = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  const T* a_in = reinterpret_cast<T*>(data[3]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  int* ipiv = reinterpret_cast<int*>(out[1]);
  int* info = reinterpret_cast<int*>(out[2]);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }
  for (int i = 0; i < b; ++i) {
    fn(&m, &n, a_out, &m, ipiv, info);
    a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
    ipiv += std::min(m, n);
    ++info;
  }
}

template struct Getrf<float>;
template struct Getrf<double>;
template struct Getrf<std::complex<float>>;
template struct Getrf<std::complex<double>>;

We noticed that a number of FFI kernels were added recently, e.g. jax::LuDecomposition::Kernel. Should we be using these kernels instead? We used the older functions because they don't depend on the XLA FFI libraries.

@joeycarter
Author

We figured out the issue with the JIT-compiled block giving a different answer: our JIT compiler sends the input to the LAPACK call in C-ordered (row-major) format, but if I've understood correctly, the scipy LAPACK calls expect FORTRAN-ordered (column-major) format.
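
For the record, here is a small NumPy/SciPy illustration of the effect (nothing qjit-specific): the row-major buffer of a matrix A, reinterpreted as column-major with the same shape, holds A's transpose, so an unconverted call effectively factors A.T and the pivots come out differently:

import numpy as np
import scipy.linalg

A = np.array([[1., 2., 3.],
              [5., 4., 2.],
              [3., 2., 1.]])

# Handing the row-major buffer of A to a Fortran-order (column-major) LAPACK
# routine without conversion amounts to factoring A.T instead of A.
P, _, _ = scipy.linalg.lu(A)
P_t, _, _ = scipy.linalg.lu(A.T)

print(P)    # permutation for A
print(P_t)  # permutation for A.T -- a different permutation, hence the mismatch above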

Sorry for the noise! I am still curious whether we should be using the old kernel functions or their FFI variants, though.

@dfm
Collaborator

dfm commented Aug 27, 2024

Glad you got that figured out! Yeah, our plan is to migrate all the custom calls to the FFI in the near future (see #23056), so in the long run, that's what you'll need to target. Unfortunately we're currently in that awkward compatibility period where both exist in parallel, and the FFI kernels don't all exist yet!

@joeycarter
Author

Great, thanks for the clarification!

While we're on the subject, I'm curious how JAX handles the row-/column-major issue. Is there a transformation somewhere before the call to the LAPACK routine that ensures the array is in column-major format? If so, could you point me to where in the code that happens?

@dfm
Collaborator

dfm commented Sep 3, 2024

Sure! The place where this is specified on the JAX side is via the operand_layouts and result_layouts parameters in the custom call lowering. For an n-dimensional input, we pass (n - 2, n - 1, n - 3, n - 4, ..., 0) as the layout to specify column-major (instead of (n - 1, n - 2, ...) for row-major). For example:

https://github.com/google/jax/blob/530ed026b8926cba3cb3d06c855b516fd4c9fb38/jaxlib/gpu_solver.py#L112

I haven't gone spelunking to find out where exactly this is used in XLA, but I'm sure it would be possible to track!
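
If it helps, here is a small sketch of how that layout tuple can be built (an illustrative helper, not a jaxlib function):

# Illustrative helper (not part of jaxlib): the layout tuple that marks the
# last two dimensions of an ndim-dimensional operand as column-major, keeping
# any leading batch dimensions in their usual order.
def column_major_matrix_layout(ndim):
    return (ndim - 2, ndim - 1) + tuple(range(ndim - 3, -1, -1))

print(column_major_matrix_layout(2))  # (0, 1)
print(column_major_matrix_layout(3))  # (1, 2, 0)
print(column_major_matrix_layout(4))  # (2, 3, 1, 0)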

joeycarter added a commit to PennyLaneAI/catalyst that referenced this issue Sep 10, 2024
… a `qjit` block (#1097)

**Context:** Certain JAX linear algebra functions from
[`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg),
such as `expm`, occasionally give incorrect results when invoked from within
a `qjit` block. Furthermore, many of these linear algebra functions are not
supported at all due to missing LAPACK symbols.

**Description of the Change:** This change fixes the issue where the
linear algebra functions invoked from within a `qjit` block occasionally
yield incorrect results and adds full support for the major linear
algebra functions included in
[`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg),
including, but not limited to:

* [`jax.scipy.linalg.cholesky`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cholesky.html)
* [`jax.scipy.linalg.expm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html)
* [`jax.scipy.linalg.funm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.funm.html)
* [`jax.scipy.linalg.hessenberg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.hessenberg.html)
* [`jax.scipy.linalg.lu`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu.html)
* [`jax.scipy.linalg.lu_solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html)
* [`jax.scipy.linalg.polar`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.polar.html)
* [`jax.scipy.linalg.qr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.qr.html)
* [`jax.scipy.linalg.schur`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.schur.html)
* [`jax.scipy.linalg.solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.solve.html)
* [`jax.scipy.linalg.sqrtm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html)
* [`jax.scipy.linalg.svd`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.svd.html)

The root cause of the occasionally incorrect results was that the JAX
BLAS and LAPACK wrapper functions defined in
[jaxlib/cpu/lapack_kernels.cc](https://github.com/google/jax/blob/main/jaxlib/cpu/lapack_kernels.cc)
and
[jaxlib/cpu/lapack_kernels_using_lapack.cc](https://github.com/google/jax/blob/main/jaxlib/cpu/lapack_kernels_using_lapack.cc)
are written assuming FORTRAN-style (column-major) matrix layout, and
invoke the FORTRAN BLAS and LAPACK routines. Catalyst on the other hand
currently assumes C-style (row-major) matrix layout. This change
includes modified versions of the JAX wrapper functions that invoke
their respective BLAS and LAPACK routines using the C interfaces.

**Benefits:** Add support for JAX linear algebra functions to Catalyst.

**Possible Drawbacks:** This fix required extensive modifications to the JAX
wrappers for the BLAS and LAPACK routines. If these wrappers change upstream
in JAX, we will need to update them manually in Catalyst to remain compatible
with future JAX releases. We would eventually like to remove the need to
supply these LAPACK symbols ourselves altogether. This is not currently
possible, but is being discussed upstream in jax-ml/jax#23172. In the
meantime, a slight improvement on this approach may be to preserve the
column-major matrix layout in Catalyst so that we can use the FORTRAN BLAS
and LAPACK interfaces shipped with JAX more directly.

**Related GitHub Issues:** #753, closes #1071.


[[sc-72974](https://app.shortcut.com/xanaduai/story/72974/manually-add-support-for-the-important-jax-linear-algebra-functions-in-catalyst)]
mehrdad2m pushed a commit to PennyLaneAI/catalyst that referenced this issue Sep 17, 2024
… a `qjit` block (#1097)
