Improve support for using JAX's custom LAPACK calls externally #23172
Comments
Thanks for this request! This is something we'd love to support. We don't have a specific timeline, but I wanted to just confirm here that this feature request is acknowledged.
Adding one clarification: jaxlib doesn't actually come with its own LAPACK library. It actually links to the one from scipy. The code used to populate our API with those symbols is here: It's probably worth noting that as part of this conversation!
Hi @dfm, Could you help clarify something? We were testing a few of the

import jax
import jax.numpy as jnp
import qjit

A = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = jax.scipy.linalg.lu(A)

@qjit
def qjit_lu(X):
    return jax.scipy.linalg.lu(X)

P_qjit, L_qjit, U_qjit = qjit_lu(A)
assert jnp.allclose(P @ L @ U, A)  # Passes
assert jnp.allclose(P, P_qjit)  # Fails
assert jnp.allclose(L, L_qjit)  # Fails
assert jnp.allclose(U, U_qjit)  # Fails

We noticed, for example, that in this case, some of the pivot matrix rows have been reordered:

>>> print(P)
[[0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]]
>>> print(P_qjit)
[[0. 1. 0.]
 [1. 0. 0.]
 [0. 0. 1.]]

To get to this point we've used the

template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;
template <typename T>
void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int m = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  const T* a_in = reinterpret_cast<T*>(data[3]);
  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  int* ipiv = reinterpret_cast<int*>(out[1]);
  int* info = reinterpret_cast<int*>(out[2]);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }
  for (int i = 0; i < b; ++i) {
    fn(&m, &n, a_out, &m, ipiv, info);
    a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
    ipiv += std::min(m, n);
    ++info;
  }
}
template struct Getrf<float>;
template struct Getrf<double>;
template struct Getrf<std::complex<float>>;
template struct Getrf<std::complex<double>>;

We noticed that a number of FFI kernels were added recently, e.g.
We figured out the issue with the JIT-compiled block giving a different answer: our JIT compiler sends the input to the LAPACK call in C-ordered (row-major) format, but if I've understood correctly, the scipy LAPACK calls expect FORTRAN-ordered (column-major) format. Sorry for the noise! I am still curious whether we should be using the old kernel function or their FFI variants, though.
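To make the layout mismatch concrete, here is a minimal pure-Python sketch (the helper names are made up for illustration and are not part of JAX, scipy, or any JIT compiler). It shows that a row-major buffer, when read by a routine that assumes column-major storage, is seen as the transpose of the intended matrix, so the routine ends up factorizing a different matrix.

```python
def to_row_major(matrix):
    """Flatten a list-of-rows matrix into a row-major (C-ordered) buffer."""
    return [x for row in matrix for x in row]


def read_column_major(buffer, m, n):
    """Interpret a flat buffer as an m x n column-major (Fortran-ordered) matrix."""
    return [[buffer[i + j * m] for j in range(n)] for i in range(m)]


def transpose(matrix):
    """Return the transpose of a list-of-rows matrix."""
    return [list(col) for col in zip(*matrix)]


A = [[1.0, 2.0, 3.0],
     [5.0, 4.0, 2.0],
     [3.0, 2.0, 1.0]]

# What the caller writes into memory (row-major) ...
buffer = to_row_major(A)
# ... versus what a column-major routine reads back out of that same memory.
seen_by_routine = read_column_major(buffer, 3, 3)

assert seen_by_routine == transpose(A)
assert seen_by_routine != A
```

Because the routine effectively works on the transpose, it picks its pivots from different columns, which is consistent with the reordered rows of the pivot matrix reported above.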
Glad you got that figured out! Yeah, our plan is to migrate all the custom calls to the FFI in the near future (see #23056), so in the long run, that's what you'll need to target. Unfortunately we're currently in that awkward compatibility period where both exist in parallel, and the FFI kernels don't all exist yet!
Great, thanks for the clarification! While we're on the subject, I'm curious how jax handles the row-/column-major issue. Is there a transformation that occurs somewhere before the call to the LAPACK routine that ensures the array is in column-major format? If so could you point me to where in the code that happens?
Sure! The place where this is specified on the JAX side is via the I haven't gone spelunking to find out where exactly this is used in XLA, but I'm sure it would be possible to track!
… a `qjit` block (#1097)

**Context:** Certain JAX linear algebra functions from [`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg), such as `expm`, occasionally give incorrect results when invoked from within a `qjit` block. Furthermore, many of these linear algebra functions are not supported due to missing LAPACK symbols.

**Description of the Change:** This change fixes the issue where the linear algebra functions invoked from within a `qjit` block occasionally yield incorrect results and adds full support for the major linear algebra functions included in [`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg), including, but not limited to:

* [`jax.scipy.linalg.cholesky`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cholesky.html)
* [`jax.scipy.linalg.expm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html)
* [`jax.scipy.linalg.funm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.funm.html)
* [`jax.scipy.linalg.hessenberg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.hessenberg.html)
* [`jax.scipy.linalg.lu`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu.html)
* [`jax.scipy.linalg.lu_solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html)
* [`jax.scipy.linalg.polar`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.polar.html)
* [`jax.scipy.linalg.qr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.qr.html)
* [`jax.scipy.linalg.schur`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.schur.html)
* [`jax.scipy.linalg.solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.solve.html)
* [`jax.scipy.linalg.sqrtm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html)
* [`jax.scipy.linalg.svd`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.svd.html)

The root cause of the occasionally incorrect results was that the JAX BLAS and LAPACK wrapper functions defined in [jaxlib/cpu/lapack_kernels.cc](https://github.com/google/jax/blob/main/jaxlib/cpu/lapack_kernels.cc) and [jaxlib/cpu/lapack_kernels_using_lapack.cc](https://github.com/google/jax/blob/main/jaxlib/cpu/lapack_kernels_using_lapack.cc) are written assuming FORTRAN-style (column-major) matrix layout, and invoke the FORTRAN BLAS and LAPACK routines. Catalyst, on the other hand, currently assumes C-style (row-major) matrix layout. This change includes modified versions of the JAX wrapper functions that invoke their respective BLAS and LAPACK routines using the C interfaces.

**Benefits:** Add support for JAX linear algebra functions to Catalyst.

**Possible Drawbacks:** This fix required extensive modifications to the JAX wrappers to the BLAS and LAPACK routines. If there are upstream modifications in JAX to these wrappers, we would need to manually update them in Catalyst for compatibility with future JAX releases. We would eventually like to remove the need to manually supply these LAPACK symbols altogether. This is not currently possible, but is being discussed in JAX Issue #[23172](jax-ml/jax#23172). In the meantime, a slight improvement on this approach may be to preserve the column-major matrix layout in Catalyst so that we may use the FORTRAN BLAS and LAPACK interfaces shipped with JAX more directly.

**Related GitHub Issues:** #753, closes #1071. [[sc-72974](https://app.shortcut.com/xanaduai/story/72974/manually-add-support-for-the-important-jax-linear-algebra-functions-in-catalyst)]
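As a rough illustration of what handing a FORTRAN-style routine the layout it expects could involve, the sketch below reorders a flat row-major buffer into column-major order and back. This is only an assumption about one possible approach: the helper names are hypothetical and are not taken from Catalyst, JAX, or LAPACK, and a real implementation would operate on typed memory buffers rather than Python lists.

```python
def row_major_to_column_major(buffer, m, n):
    """Reorder a flat row-major m x n buffer into column-major order."""
    return [buffer[i * n + j] for j in range(n) for i in range(m)]


def column_major_to_row_major(buffer, m, n):
    """Reorder a flat column-major m x n buffer back into row-major order."""
    return [buffer[i + j * m] for i in range(m) for j in range(n)]


# A 2 x 3 matrix [[1, 2, 3], [5, 4, 2]] stored row by row.
row_major = [1.0, 2.0, 3.0, 5.0, 4.0, 2.0]

# The same matrix stored column by column, as a FORTRAN routine expects it.
col_major = row_major_to_column_major(row_major, 2, 3)
assert col_major == [1.0, 5.0, 2.0, 4.0, 3.0, 2.0]

# Results written by the routine in column-major order convert back losslessly.
assert column_major_to_row_major(col_major, 2, 3) == row_major
```

Converting around every call costs an extra copy per array, which is one reason keeping a single layout end to end can be preferable when it is feasible.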
Hi there,

This request is a follow-up to the discussion here: #18065.

What we're trying to accomplish

Suppose we have our own JIT compiler, called `qjit`. We would like to be able to do the following, for example:

and similarly for the other functions in `jax.scipy.linalg`. However, when we do so, we get `undefined symbol` errors, in this case:

Current workaround

What we currently do to get around this is what was suggested in #18065: to manually compile and link in the required custom JAX LAPACK modules under `jaxlib/cpu/`, to define the required symbols such as `lapack_zgees`.

However, this is cumbersome and difficult to maintain (suppose these modules change in a future JAX release).

What else we've tried

We noticed that the `jaxlib` package comes shipped with the shared object file `jaxlib/cpu/_lapack.so`, which contains the symbols for the kernel functions that the custom JAX wrappers use. For instance, using the `nm` tool, we can find the corresponding kernel function for `lapack_zgees`, `ComplexGees<std::complex<double>>::Kernel`:

or, in its mangled form:

We tried loading this symbol using the dynamic linking loader as follows (simplified for brevity):

However, `dlsym` fails to find the symbol. We believe this is because this function is not exported, as denoted by the lowercase `t` in the `nm` output (where exported functions are conventionally denoted by uppercase letters, e.g. `T`).

So, we believe that we've hit a dead-end with this approach.
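The same check can be sketched from Python on a POSIX system, since `ctypes.CDLL` loads a library via `dlopen` and attribute access on it performs the symbol lookup. This is not the code from the original report: `find_symbol` and `is_exported` are hypothetical helpers, the `nm` lines in the comments use a made-up symbol name, and `is_exported` ignores the few special-case type letters that `nm` documents.

```python
import ctypes


def is_exported(nm_line):
    """Guess from one line of `nm` output whether a defined symbol is global."""
    fields = nm_line.split()
    # Undefined symbols have no address column, so the type letter comes first.
    symbol_type = fields[0] if len(fields[0]) == 1 else fields[1]
    # Uppercase usually means global; "U" means undefined in this object.
    return symbol_type.isupper() and symbol_type != "U"


def find_symbol(library_path, symbol_name):
    """Return the symbol from the library, or None if the lookup fails."""
    lib = ctypes.CDLL(library_path)  # library_path=None means the running process
    try:
        return getattr(lib, symbol_name)
    except AttributeError:
        # The lookup only sees exported symbols; local ones are not found.
        return None


# Hypothetical nm lines: a global text symbol and a local one.
assert is_exported("0000000000001130 T example_kernel")
assert not is_exported("0000000000001130 t example_kernel")

# An exported C library function is found; an unknown name is not.
assert find_symbol(None, "getpid") is not None
assert find_symbol(None, "no_such_symbol_for_this_sketch") is None
```

Pointing `find_symbol` at `jaxlib/cpu/_lapack.so` with the mangled kernel name would be expected to return `None` for the same reason described above, as long as the symbol remains local.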

Possible solutions

Given the fact that these symbols are already shipped with `jaxlib` in `jaxlib/cpu/_lapack.so`, would it be possible to make these functions globally available in a future JAX release to make it possible to dynamically load them using `dlopen` and `dlsym`? If that is not possible, is there another approach that is more amenable to using these custom LAPACK calls than manually building and linking them in ourselves?