It looks like XLA AOT would support it. I'm also interested, for the same reasons: I have a driving algorithm in C++ and want to feed it a model defined in Python, without the overhead of GIL acquisition and numpy array allocation.
I'm looking into using jax to write callbacks for a C library, but the call overhead seems to be prohibitively large for smaller inputs.
So about 140x slower than a small numpy function, which itself already has a very high overhead compared to a small C function.
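A comparison of this kind can be approximated with a sketch like the one below. The function body, the input size, and the helper name per_call_seconds are assumptions made for illustration; they are not the code that produced the numbers quoted above. The jax part runs only if jax is installed.

```python
import timeit

import numpy as np


def per_call_seconds(fn, *args, number=10_000):
    """Return the average wall-clock time of one call to fn(*args)."""
    fn(*args)  # warm-up call, so one-time costs such as JIT compilation are excluded
    return timeit.timeit(lambda: fn(*args), number=number) / number


def numpy_fn(x):
    # Hypothetical small function standing in for the callback body.
    return np.sin(x) + 1.0


x = np.ones(10)  # deliberately small input, so call overhead dominates
numpy_time = per_call_seconds(numpy_fn, x)
print(f"numpy: {numpy_time * 1e6:.2f} us per call")

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None  # jax is optional for this sketch

if jax is not None:
    jax_fn = jax.jit(lambda v: jnp.sin(v) + 1.0)
    # jax dispatches work asynchronously; block_until_ready() makes the
    # timing cover the computation itself and not just the dispatch.
    jax_time = per_call_seconds(lambda v: jax_fn(v).block_until_ready(), x)
    print(f"jax:   {jax_time * 1e6:.2f} us per call")
    print(f"ratio: {jax_time / numpy_time:.1f}x")
```

The ratio will vary with the machine, the jax version, and the input size, so it should not be expected to match the figure above.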
I managed to get rid of most of the Python overhead by compiling the XLA function myself:
So quite a bit faster, but still not exactly fast. Part (or possibly most) of the problem seems to be that XLA is running the computation in a different thread. Is it possible to pass it some options that would prevent that? Looking at the output of perf, it also seems to be spending a lot of time allocating things.
In an ideal world, an interface similar to that of numba.cfunc (see here) would be fantastic. Or a way, similar to numba, to access the different compiled functions. A way to create a LocalBuffer without copying data would also help (if the other problems are solved), e.g. xla_client.LocalBuffer.from_buffer(y, copy=False).
If there was any way to get our hands on the C function pointer that XLA creates somewhere hidden in its depths, that would help enormously.
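To make the request concrete, here is a minimal sketch of what a consumer does with a bare C function pointer, using only ctypes from the standard library. The signature double f(double) and the function square are assumptions for illustration, and the pointer here wraps a plain Python function as a stand-in. It is not something XLA provides, which is the point of this issue. As far as I know, numba.cfunc exposes the equivalent value through its address attribute.

```python
import ctypes

# Assumed C signature for illustration: double f(double)
SIGNATURE = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)


def square(x):
    # Stand-in for a compiled kernel.
    return x * x


# ctypes builds a C-callable thunk around the Python function.
# This object must stay alive for as long as the address is in use.
c_square = SIGNATURE(square)

# The raw integer address is what a C library would receive as its callback.
address = ctypes.cast(c_square, ctypes.c_void_p).value

# Rebuilding a callable from the bare address shows that the pointer
# alone is enough to invoke the function.
from_address = SIGNATURE(address)
result = from_address(3.0)
print(result)
```

A ctypes callback like this still re-enters the interpreter on every call, so it does not remove any overhead. It only shows the shape of the interface: a single address that C code can call directly, with no Python objects crossing the boundary.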