
[question] Call overhead / disable parallelism #1169

Open
aseyboldt opened this issue Aug 12, 2019 · 4 comments

Comments


aseyboldt commented Aug 12, 2019

I'm looking into using jax to write callbacks for a C library, but the call overhead seems to be prohibitively large for smaller inputs.

import jax
import numpy as np
from jax.lib import xla_client

def foo(x):
    return 2 * x**2

foo_jit = jax.jit(foo)

x = jax.device_put(np.zeros(10))
%timeit foo_jit(x).block_until_ready()
%timeit np.arange(10)
# 82.5 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# 448 ns ± 19.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

So roughly 180x slower than a small numpy function, which itself already has very high overhead compared to a plain C function.

I managed to get rid of most of the Python overhead by compiling the XLA computation myself:

comp = jax.xla_computation(foo)(x)  # trace foo into an XLA computation
exe = comp.Compile()  # compile directly with the XLA client, bypassing jit dispatch
buffer = xla_client.LocalBuffer.from_pyval(x)  # copy x into an XLA device buffer
%timeit exe.Execute([buffer,]).block_host_until_ready()
# 19.6 µs ± 863 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

So that's quite a bit faster, but still not exactly fast. Part (or possibly most) of the remaining overhead seems to be that XLA runs the computation in a different thread. Is it possible to pass options that would prevent that? Looking at the output of perf, it also seems to spend a lot of time allocating things.
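For what it's worth, one recipe that is commonly suggested for forcing XLA's CPU backend to run single-threaded, assuming these flag names still apply to the version in use, is to set XLA_FLAGS before jax is first imported:

import os

# Hedged sketch: flag names are version-dependent; this is the commonly
# cited form for disabling XLA's multi-threaded CPU execution.
os.environ['XLA_FLAGS'] = ('--xla_cpu_multi_thread_eigen=false '
                           'intra_op_parallelism_threads=1')

import jax  # must come after XLA_FLAGS is set, or the backend won't see the flags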

In an ideal world, an interface similar to that of numba.cfunc would be fantastic, or a numba-like way to access the individual compiled functions; see the sketch below for the kind of interface I mean. A way to create a LocalBuffer without copying data would also help (once the other problems are solved), e.g. xla_client.LocalBuffer.from_buffer(y, copy=False).
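Roughly this (a sketch of numba's existing API, with a toy signature chosen just for illustration):

from numba import cfunc, types

# numba.cfunc compiles the Python function to a bare C callback and
# exposes the raw function pointer, ready to hand to a C library.
@cfunc(types.float64(types.float64))
def foo_c(x):
    return 2 * x**2

print(foo_c.address)  # integer address of the compiled C function
print(foo_c.ctypes)   # ctypes wrapper around the same pointer, callable from Python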

If there were any way to get our hands on the C function pointer that XLA creates somewhere in its depths, that would help enormously.

@junpenglao

Very interesting question. I wonder if there is indeed a way to do that? @mattjj


maedoc commented Dec 13, 2019

It looks like XLA AOT would support it. I'm also interested, for the same reasons: I have a driving algorithm in C++ and want to feed it a model defined in Python, without the GIL acquisition and numpy array allocation overhead.

skye (Member) commented Dec 16, 2019

Check out https://github.com/google/jax/blob/master/jax/tools/jax_to_hlo.py. Combined with the XLA C++ client, this should give you something very similar to XLA AOT. @aseyboldt @maedoc would this help with your use cases?
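For reference, usage of that tool at the time looked roughly like the following. This is a hedged sketch based on its docstring; the exact API, and the Shape construction in particular, may differ across versions:

from jax.tools.jax_to_hlo import jax_to_hlo
from jax.lib import xla_client

def fn(x):
    return 2 * x**2

# Trace fn at fixed input shapes into an HLO module; the resulting proto
# is what the XLA C++ client consumes for ahead-of-time compilation.
hlo_proto, hlo_text = jax_to_hlo(fn, [('x', xla_client.Shape('f32[10]'))])
print(hlo_text)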


maedoc commented Dec 20, 2019

@skye Looks promising; I'll try to put together a working prototype when possible.
