Replies: 1 comment
-
Hi @mkskeller, since SF/SPU uses JAX for frontend Python programs, you can use JAX's vectorization capabilities (`jax.vmap`) to optimize this program as shown below, which may help XLA do more hardware-independent optimization and produce much better code.

```python
import jax
import jax.numpy as jnp

n = 100  # example input size; use the real size of your input data
# Three id/value columns plus a category column with 5 categories (example data).
data = jnp.array([list(range(n))] * 3 + [[i % 5 for i in range(n)]])

def xtabs(data):
    xid, yid, vals, cats = data
    n = len(xid)
    num_k = 5

    def map_i(i):
        def map_j(j):
            def map_k(k):
                # Contribution of the pair (i, j) to category bucket k.
                return (xid[i] == yid[j]) * (cats[j] == k) * vals[i]
            return jax.vmap(map_k)(jnp.arange(num_k))
        return jnp.sum(jax.vmap(map_j)(jnp.arange(n)), axis=0)

    res = jnp.sum(jax.vmap(map_i)(jnp.arange(n)), axis=0)
    return res
```
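As a quick sanity check outside SPU, the vectorized `xtabs` above can be run under plain JAX with `jax.jit` and compared against a naive triple loop. The `xtabs_reference` helper below is not part of the answer above; it is only a hypothetical reference implementation for this comparison:

```python
import jax
import jax.numpy as jnp
import numpy as np

def xtabs_reference(data):
    # Naive O(n^2 * k) implementation, used only to verify the vectorized result.
    xid, yid, vals, cats = np.asarray(data)
    res = np.zeros(5, dtype=np.int64)
    for i in range(len(xid)):
        for j in range(len(yid)):
            if xid[i] == yid[j]:
                res[cats[j]] += vals[i]
    return res

n = 20
data = jnp.array([list(range(n))] * 3 + [[i % 5 for i in range(n)]])
print(jax.jit(xtabs)(data))    # vectorized version from above, jitted
print(xtabs_reference(data))   # should print the same per-category sums
```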
-
I've tried to port the cross-tabulation example from https://github.com/MPC-SoK/frameworks as follows:
It gives me the correct results, but it seems slower than most frameworks in the above repository. Is there a way to improve it, for example by using vectorized operations or by reducing the number of bits used in the comparisons? I cannot find any hints in the documentation. I'm using the latest Docker container.
Full code:
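(The snippet and full code referenced above are not reproduced in this excerpt. Purely for illustration, and not the poster's actual code, a straightforward loop-based cross-tabulation of the kind described might look roughly like the sketch below; names such as `xtabs_loops` are hypothetical.)

```python
# Hypothetical loop-based cross-tabulation (illustration only, not the code
# referenced above): for every pair (i, j) with matching ids, add vals[i]
# to the bucket of cats[j].
import jax.numpy as jnp

def xtabs_loops(xid, yid, vals, cats, num_k=5):
    res = jnp.zeros(num_k, dtype=vals.dtype)
    for i in range(len(xid)):
        for j in range(len(yid)):
            for k in range(num_k):
                res = res.at[k].add((xid[i] == yid[j]) * (cats[j] == k) * vals[i])
    return res
```

Element-by-element code like this unrolls at trace time into a very large graph of scalar comparisons and additions, which is typically why it runs much slower than the batched `vmap` formulation shown in the reply above.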