As a side note, I am wondering why jax.lax.cond(pred, true_fun, false_fun, *operands) lowers to stablehlo.case rather than stablehlo.if, which is semantically closer. I found openxla/stablehlo#599, so perhaps there is nothing to be gained from stablehlo.if, and stablehlo.case is used simply because it is more general? I would love someone from the JAX team to confirm :)
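One way to see which op a given lax.cond lowers to is to print the lowered module. This is a minimal sketch, assuming a recent JAX release in which `jax.jit(...).lower(...).as_text()` emits StableHLO by default. The function `f` and its two branches are hypothetical and illustrative only; they are not the failing example from this report.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(pred, x):
    # lax.cond(pred, true_fun, false_fun, *operands); the branches here are
    # hypothetical and chosen only so that the result is easy to check.
    return lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)


x = jnp.float32(2.0)

# Lower under jit so that `pred` is traced; a traced predicate forces
# lax.cond to emit a real conditional op instead of picking a branch eagerly.
lowered = jax.jit(f).lower(True, x)
hlo_text = lowered.as_text()  # MLIR module text (StableHLO in recent JAX)

# Print only the lines that mention the conditional op.
for line in hlo_text.splitlines():
    if "case" in line or ".if" in line:
        print(line.strip())
```

In the JAX versions I am aware of, the boolean predicate is first converted to an int32 branch index, which is then fed to a case op. That would be consistent with the "more general op" explanation above, but it is an observation about the lowering, not an official statement of intent.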
Description
HLO
The above fails with
However, it works if we change the branches slightly (just adding a constant):
Another example that fails is this:
but works when changed to
From the above, I gather that this operation is already supposed to work, but there is clearly some inconsistency.
System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7