JITing loss function causes erroneous results #23226
Comments
Bumping this with some more testing and further narrowing down. Here's a snippet that is much shorter and hopefully easier to understand. Both `check_grads` calls always fail (I suspect this is just a too-small tolerance), but importantly, including any ReLU (leaky or otherwise) makes the JIT'd max absolute difference a factor of 2 or more larger than the non-JIT'd one, as in my last test with the snippet below:

```python
import jax
import jax.numpy as jnp

latent_dim = 2**11
input_dim = 3072
initializer = jax.nn.initializers.he_uniform()
encoder = initializer(jax.random.key(0), (latent_dim, input_dim), jnp.float32)
decoder = encoder.T

def encode(x):
    codes = encoder @ x
    return jax.nn.relu(codes)
    #return leaky_offset_relu(codes, negative_slope=0., offset=1.96/jnp.sqrt(encoder.shape[0]))

def top_k_decode(top_k_indices, top_k_values):
    decoder_weights = (decoder / jnp.linalg.norm(decoder, axis=-1, keepdims=True)).T
    # top_k_indices is now 1D after vmap, so we don't need [:, :, None]
    selected_decoder_weights = decoder_weights[top_k_indices]
    # Adjust the sum operation to match the new shape
    return jnp.sum(top_k_values[:, None] * selected_decoder_weights, axis=0)

def fwd_pass(batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(encode)(batch)
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)
    x_hat = jax.vmap(top_k_decode)(top_k_indices, top_k_values)
    return x_hat

def recon_loss(batch: jnp.ndarray):
    x_hat = fwd_pass(batch)
    return jnp.mean(jnp.sum(jnp.square(batch - x_hat), axis=-1))

example_batch = jax.random.normal(jax.random.key(42), (4096, 3072))
example_batch = example_batch / jnp.linalg.norm(example_batch, axis=-1, keepdims=True)

from jax.test_util import check_grads
check_grads(jax.jit(recon_loss), args=(example_batch,), order=1)
check_grads(recon_loss, args=(example_batch,), order=1)
```

P.S. I tested this on a TPU colab, and jitting/adding ReLU doesn't make a difference there; the grads are "out of tolerance" in all cases by the same amount.
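If the failures really are just the default tolerance, one way to check is to pass explicit tolerances; `check_grads` accepts `atol` and `rtol` keywords. A minimal sketch, with illustrative (not recommended) values:

```python
# Re-run the gradient checks with looser, explicitly chosen tolerances
# (tolerance values here are illustrative only, not from the original report).
check_grads(jax.jit(recon_loss), args=(example_batch,), order=1, atol=1e-2, rtol=1e-2)
check_grads(recon_loss, args=(example_batch,), order=1, atol=1e-2, rtol=1e-2)
```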
I think the issue here is that your gradients are very close to zero, so very small absolute deviations become relatively large variations:

```python
import numpy as np

# Using recon_loss and example_batch from the snippet above.
grad_jit_result = jax.grad(jax.jit(recon_loss))(example_batch)
grad_result = jax.grad(recon_loss)(example_batch)
np.testing.assert_allclose(grad_jit_result, grad_result)
```
The gradients match in an absolute sense to 1 part in 10^12, but in a relative sense they only match to 1 part in 3, meaning the mismatched gradient entries are themselves essentially zero.
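To make the absolute-versus-relative distinction concrete, here is a minimal sketch using the arrays from the snippet above (the tolerance values are illustrative):

```python
# Absolute comparison only: passes, since the deviations themselves are tiny.
np.testing.assert_allclose(grad_jit_result, grad_result, rtol=0, atol=1e-6)

# Relative comparison only: fails, because near-zero entries make even a tiny
# absolute deviation a large fraction of the reference value.
np.testing.assert_allclose(grad_jit_result, grad_result, rtol=1e-3, atol=0)
```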
That makes sense, but I still encounter the bug where JITing causes the optimization to work in one case and totally fail in the other (and this behavior doesn't happen on the TPU platform, which makes me think there's some XLA bug rather than a bug in my code). I removed the leaky ReLU, since it doesn't do anything except add a threshold when there's a top-k activation, and after that change JIT'd training on GPU still fails, while removing the JIT (or turning off algsimp) makes it succeed (e.g. at step 400 the JIT'd trainer still has loss above 1 while the eager optimization is at ~0.8).
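For reference, a minimal sketch of how the algsimp pass can be turned off from Python, assuming the usual route of passing XLA flags through the `XLA_FLAGS` environment variable before JAX initializes its backends:

```python
import os

# Disable XLA's algebraic-simplifier pass for the whole process; the environment
# variable must be set before JAX creates its backends.
os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=algsimp"

import jax  # imported after the flag is set
```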
Hmm, that's strange indeed. In general we don't expect JIT-compiled versions of functions to have bitwise-identical outputs to the non-compiled versions: any time you rearrange or fuse floating-point operations you'll change the details of the numerics, but the results should agree to within normal floating-point precision. Can you try running it again while setting
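(The specific setting is cut off above; judging by the config line in the follow-up code, it presumably refers to the default matmul precision. A minimal sketch under that assumption:)

```python
import jax

# Ask XLA to use the highest-precision matmul algorithm for float32 inputs,
# ruling out low-precision matmul accumulation as the source of the mismatch.
jax.config.update("jax_default_matmul_precision", "highest")
```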
That didn't help. Here's the code as it is being run right now (and the first 1000 steps of training); sorry that it's no longer fully self-contained. Contrary to what I said earlier, substituting the `leaky_offset_relu` for either a normal ReLU or no activation fixes the issue (my other comment was testing a more complicated version of this model 🥲). There's no actual negative slope in how the leaky offset ReLU is called, just a thresholded ReLU; I suspect that a decent number of the activations end up below the threshold at the beginning of training and that causes issues.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import treescope
from jax.experimental import sparse
treescope.basic_interactive_setup()
from typing import Tuple
import optax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

jax.config.update('jax_default_matmul_precision', "highest")

def leaky_offset_relu(x, negative_slope=1e-2, offset=0):
    return jnp.where(x >= offset, x, negative_slope * x)

class Autoencoder(eqx.Module):
    encoder: jnp.ndarray
    decoder: jnp.ndarray
    bias: jnp.ndarray
    use_bias: bool

    def __init__(self, latent_dim: int, input_dim: int, use_bias: bool = True, key=None):
        initializer = jax.nn.initializers.he_uniform()
        self.encoder = initializer(key, (latent_dim, input_dim), jnp.float32)
        self.decoder = self.encoder.T
        self.bias = jnp.zeros(input_dim) if use_bias else None
        self.use_bias = use_bias

    def encode(self, x):
        x = x - self.bias if self.use_bias else x
        codes = self.encoder @ x
        #return codes
        return leaky_offset_relu(codes, negative_slope=0., offset=1.96/jnp.sqrt(self.encoder.shape[0]))

    def top_k_decode(self, top_k_indices, top_k_values):
        decoder_weights = self.get_decoder()
        # top_k_indices is now 1D after vmap, so we don't need [:, :, None]
        selected_decoder_weights = decoder_weights[:, top_k_indices]
        return selected_decoder_weights @ top_k_values + (self.bias if self.use_bias else 0)

    def get_decoder(self):
        return self.decoder / jnp.linalg.norm(self.decoder, axis=0, keepdims=True)

    def get_encoder(self):
        return self.encoder

    def __call__(self, x):
        z = self.encode(x)
        return self.decode(z)

def fwd_pass(model: Autoencoder, batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(model.encode)(batch)
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)
    x_hat = jax.vmap(model.top_k_decode)(top_k_indices, top_k_values)
    return x_hat, top_level_latent_codes

def recon_loss(x_hat: jnp.ndarray, batch: jnp.ndarray):
    return jnp.mean(jnp.sum(jnp.square(batch - x_hat), axis=-1))

@eqx.filter_value_and_grad
def loss_fn(model: Autoencoder, batch: jnp.ndarray, l1_penalty: float, ortho_penalty: float) -> float:
    x_hat, top_level_latent_codes = fwd_pass(model, batch)
    return recon_loss(x_hat, batch)

def update_model(model, grads, opt_state, optimizer):
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state

@eqx.filter_jit
def train_step(model: Autoencoder, batch: jnp.ndarray, opt_state, l1_penalty: float, ortho_penalty: float, optimizer) -> Tuple[Autoencoder, optax.OptState, float]:
    loss, grads = loss_fn(model, batch, l1_penalty, ortho_penalty)
    model, opt_state = update_model(model, grads, opt_state, optimizer)
    return model, opt_state, loss

# omitted dataloader stuff
model = Autoencoder(
    input_dim=3072,
    latent_dim=2**12,
    use_bias=False,
    key=jax.random.key(0)
)
learning_rate = 1e-4
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
l1_penalty = 1e-3
ortho_penalty = 1e-2
num_steps = 10000

for step in range(num_steps):
    batch = next(train_loader)
    model, opt_state, loss = train_step(model, batch, opt_state, l1_penalty, ortho_penalty, optimizer)
    if step % 50 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")
```

[Training loss for the first 1000 steps: "No JIT" vs. "JIT on"]
Hi, what is happening is that the layout of `top_level_latent_codes` changes which top_k algorithm is used, and it looks to me like there are a lot of near-zero values in `top_level_latent_codes`, so the top_k algorithm struggles to choose the top k consistently. If I add a layout constraint to your code (or return `top_level_latent_codes` as an aux output), everything becomes consistent between jit and non-jit.
While the layout does make them match, I suspect this implies that the choice of the top k in the case of ties or near-zero values really matters, and you should do something smarter here rather than just forcing a particular top_k algorithm by constraining the layouts.
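One possible "something smarter", sketched under the assumption that deterministic tie-breaking is what matters: replace `lax.top_k` with a stable argsort so that tied or near-zero activations always resolve to the same indices, independent of layout. (The name `stable_top_k` is mine, and this may be slower than `lax.top_k`.)

```python
import jax.numpy as jnp

def stable_top_k(values, k):
    # jnp.argsort is stable, so equal (e.g. zero) activations are always broken
    # in favour of the lower index, regardless of memory layout or backend.
    order = jnp.argsort(-values)
    top_indices = order[:k]
    return values[top_indices], top_indices

# Possible drop-in use inside fwd_pass (per-row via vmap, mirroring the original code):
# top_k_values, top_k_indices = jax.vmap(lambda v: stable_top_k(v, 8))(top_level_latent_codes)
```

The aux-output workaround mentioned above also makes jit and non-jit agree, presumably by pinning the layout of the intermediate, but it would not address the underlying sensitivity to ties.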
Slight update on the usage of the layout API: use the experimental endpoint:
This makes sense, but I'm still curious as to why this happens with
You can debug these things by adding an aux output and printing the intermediate results. I think the simple explanation is that the top_k_indices with relu + jit are ordered like so:
And everything else (no-jit relu, jit no-relu, no-jit no-relu) is more random, like so:
Printing the intermediate values for leaky_relu, there are a lot of zeros in the results. Anyway, looking more closely, it might actually be a bug in how XLA does top-k. I'll look into that.
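A sketch of the aux-output debugging approach in terms of the Equinox snippet above (assuming `has_aux` is the mechanism meant; the names come from that snippet):

```python
@eqx.filter_value_and_grad(has_aux=True)
def loss_fn_debug(model, batch, l1_penalty, ortho_penalty):
    x_hat, top_level_latent_codes = fwd_pass(model, batch)
    # Return the intermediate activations alongside the loss so they can be inspected.
    return recon_loss(x_hat, batch), top_level_latent_codes

(loss, codes), grads = loss_fn_debug(model, batch, l1_penalty, ortho_penalty)
print(jnp.sum(codes == 0), codes.size)  # how many activations are exactly zero
```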
Thanks! Another thing I don't understand is why the major-to-minor ordering helps regardless of batch size. Even if the batch dimension is larger than the latent dimension, your layout fix works; but isn't the batch dimension major and the latent dimension minor in that case? Or is it just that the vmap means the batch dimension is always treated as the minor dimension, since the ops are vectorized across it? (These runs were done before I added the aux output.)
I was making modifications to the code today and noticed that adding the line
Description
The below snippet causes drastically different results depending on whether the loss function is JIT'd or not compiled. This results in a bad optimization where the version with the JIT'd loss function never converges. The error is probably in the calculation of the gradients; JITing any of the individual components doesn't cause the issue. Additionally, this bug only appears on systems with GPUs; in a TPU environment the issue doesn't appear. The bug is unrelated to my use of Equinox; it was originally observed in Flax. The flag
--xla_disable_hlo_passes=algsimp
will prevent the bug from appearing. I'm unsure what the jax equivalent of --tf_xla_parallel_checking is, so that I can compare the XLA-HLO computational graphs and narrow the bug down further. My guess would be something to do with the interaction between the top-k gating in the decoder and jit'd grad (a version of this without a sparsity-aware decoder didn't have this bug).
System info (python version, jaxlib version, accelerator, etc.)
This bug has been observed on the latest jax version with an A100 as well. This is the colab environment where this isolated example was created