JITing loss function causes erroneous results #23226

Open
muchanem opened this issue Aug 24, 2024 · 11 comments

Labels: bug Something isn't working

muchanem commented Aug 24, 2024

Description

The snippet below produces drastically different results depending on whether the loss function is JIT'd or left uncompiled, leading to a bad optimization in which the version with the JIT'd loss function never converges. The error is probably in the gradient computation: JITing any of the individual components on its own does not trigger the issue. Additionally, the bug only appears on systems with GPUs; in a TPU environment it does not show up. It is unrelated to my use of Equinox; it was originally observed in Flax. The flag --xla_disable_hlo_passes=algsimp prevents the bug from appearing. I'm unsure what the JAX equivalent of --tf_xla_parallel_checking is, so I can't compare the XLA-HLO computation graphs to narrow the bug down further. My guess is that it involves the interaction between the top-k gating in the decoder and jit'd grad (a version of this model without a sparsity-aware decoder did not have this bug).
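For reference, a minimal sketch of how that XLA flag can be applied, assuming it is passed through the XLA_FLAGS environment variable before JAX initializes its GPU backend (the repro below does not set it):

import os
# Assumption: XLA flags are only picked up if set before the first JAX backend initialization.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_disable_hlo_passes=algsimp"

import jax  # imported only after the flag is in place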

import jax
import jax.numpy as jnp
import os
import equinox as eqx
import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)

eps = 1e-6

def leaky_offset_relu(x, negative_slope=1e-2, offset=0):
    return jnp.where(x >= offset, x, negative_slope * x)

def gram_matrix_regularizer(weights):
    weights = weights / (jnp.linalg.norm(weights, axis=0, keepdims=True) + 1e-6)
    gram_matrix = jnp.dot(weights.T, weights)
    off_diagonal_elements = gram_matrix - jnp.diag(jnp.diag(gram_matrix))
    dim = off_diagonal_elements.shape[0]
    regularization_penalty = jnp.sum(off_diagonal_elements ** 2) / (dim**2 - dim)
    return weights.shape[0] * regularization_penalty

class Autoencoder(eqx.Module):
    encoder: eqx.nn.Linear
    decoder: eqx.nn.Linear
    bias: jnp.ndarray
    use_bias: bool

    def __init__(self, latent_dim: int, input_dim: int, use_bias: bool = True, key=None):
        encoder_key, decoder_key = jax.random.split(key)
        self.encoder = eqx.nn.Linear(input_dim, latent_dim, key=encoder_key)
        self.decoder = eqx.nn.Linear(latent_dim, input_dim, key=decoder_key)
        self.bias = jnp.zeros(input_dim) if use_bias else None
        self.use_bias = use_bias

    def encode(self, x):
        x = x - self.bias if self.use_bias else x
        codes = self.encoder(x)
        codes = leaky_offset_relu(codes, negative_slope=0., offset=1.96/jnp.sqrt(self.encoder.in_features))
        return codes
      
    def decode(self, z):
        return self.decoder(z) + (self.bias if self.use_bias else 0)

    def top_k_decode(self, top_k_indices, top_k_values):
        decoder_weights = self.get_decoder()
        # top_k_indices is now 1D after vmap, so we don't need [:, :, None]
        selected_decoder_weights = decoder_weights[top_k_indices]
        # Adjust the sum operation to match the new shape
        return jnp.sum(top_k_values[:, None] * selected_decoder_weights, axis=0) + (self.bias if self.use_bias else 0)

    def get_decoder(self):
        return self.decoder.weight.T / jnp.linalg.norm(self.decoder.weight.T, axis=-1, keepdims=True)

    def get_encoder(self):
        return self.encoder.weight

    def __call__(self, x):
        z = self.encode(x)
        return self.decode(z)


def fwd_pass(model: Autoencoder, batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(model.encode)(batch)
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8) # Top-k gating
    x_hat = jax.vmap(model.top_k_decode)(top_k_indices, top_k_values)
    return x_hat, top_level_latent_codes
  
def recon_loss(x_hat: jnp.ndarray, batch: jnp.ndarray):
    return jnp.mean(jnp.sum(jnp.square(batch - x_hat), axis=-1))


def l1(top_level_latent_codes: jnp.ndarray, batch_size: int):
    return jnp.sum(jnp.abs(top_level_latent_codes)) / batch_size
  
@eqx.filter_value_and_grad
def loss_fn(model: Autoencoder, batch: jnp.ndarray, l1_penalty: float, ortho_penalty: float) -> float:

    x_hat, top_level_latent_codes = fwd_pass(model, batch)
    reconstruction_loss = recon_loss(x_hat, batch)

    batch_size = batch.shape[0]
    l1_loss = l1_penalty * l1(top_level_latent_codes, batch_size)

    top_ortho_loss = gram_matrix_regularizer(model.get_encoder())

    ortho_loss = top_ortho_loss

    total_loss = reconstruction_loss + l1_loss + ortho_penalty * ortho_loss

    return total_loss

model = Autoencoder(
    input_dim=3072,
    latent_dim=2**11,
    use_bias=False,
    key=jax.random.key(0)
)
batch = jax.random.normal(jax.random.key(42), (4096, 3072))
batch = batch / jnp.linalg.norm(batch, axis=1, keepdims=True)
l1_penalty = 1e-3
ortho_penalty = 1e-2
loss, grads = loss_fn(model, batch, l1_penalty, ortho_penalty)
jit_loss, jit_grads = eqx.filter_jit(loss_fn)(model, batch, l1_penalty, ortho_penalty)

print(jit_loss) # 1.015691
print(loss) # 1.0234939

diff = jax.tree.map(lambda x, y: x - y, grads, jit_grads) # will show a big difference in the grads

System info (python version, jaxlib version, accelerator, etc.)

This bug has also been observed on the latest jax version with an A100. Below is the colab environment where this isolated example was created:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='2fb653a068c6', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')


$ nvidia-smi
Sat Aug 24 21:33:45 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   67C    P0              32W /  72W |  17247MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

@muchanem muchanem added the bug Something isn't working label Aug 24, 2024
@muchanem muchanem changed the title JIting loss function causes erroneous results JITing loss function causes erroneous results Aug 24, 2024

muchanem commented Aug 26, 2024

Bumping this with some more testing and further narrowing down. Here's a snippet that is much shorter and hopefully easier to understand. Both check_grads calls always fail (I suspect this is just a too-small tolerance). More importantly, though, including any ReLU (leaky or otherwise) makes the JIT'd max absolute difference a factor of 2 or more larger than the non-JIT'd one; in my last test, JIT had a max absolute difference of 0.10245637 vs 0.04284701 without JIT (notably, without an activation the two remain the same, though still out of tolerance).

import jax
import jax.numpy as jnp

latent_dim = 2**11
input_dim = 3072
initializer = jax.nn.initializers.he_uniform()
encoder = initializer(jax.random.key(0), (latent_dim, input_dim), jnp.float32)
decoder = encoder.T

def encode(x):
    codes = encoder @ x
    return jax.nn.relu(codes)
    #return leaky_offset_relu(codes, negative_slope=0., offset=1.96/jnp.sqrt(encoder.shape[0]))

def top_k_decode(top_k_indices, top_k_values):
    decoder_weights = (decoder / jnp.linalg.norm(decoder, axis=-1, keepdims=True)).T
    # top_k_indices is now 1D after vmap, so we don't need [:, :, None]
    selected_decoder_weights = decoder_weights[top_k_indices]
    # Adjust the sum operation to match the new shape
    return jnp.sum(top_k_values[:, None] * selected_decoder_weights, axis=0)
    
def fwd_pass(batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(encode)(batch)
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)
    x_hat = jax.vmap(top_k_decode)(top_k_indices, top_k_values)
    return x_hat
    
def recon_loss(batch: jnp.ndarray):
    x_hat = fwd_pass(batch)
    return jnp.mean(jnp.sum(jnp.square(batch - x_hat), axis=-1))
example_batch = jax.random.normal(jax.random.key(42), (4096,3072))
example_batch = example_batch/jnp.linalg.norm(example_batch, axis=-1, keepdims=True)
from jax.test_util import check_grads
check_grads(jax.jit(recon_loss), args=(example_batch,), order=1)
check_grads(recon_loss, args=(example_batch,), order=1)

P.S. tested this on a TPU colab and jitting/adding ReLU doesn't make a difference--the grads are "out of tolerance" in all cases by the same amount
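For what it's worth, check_grads accepts atol/rtol/eps keyword arguments, so something like the following sketch (tolerances picked arbitrarily here) would help separate a genuine gradient bug from a too-tight default tolerance:

from jax.test_util import check_grads

# Reverse-mode only, with loosened (arbitrary) tolerances.
check_grads(jax.jit(recon_loss), args=(example_batch,), order=1,
            modes=("rev",), atol=1e-2, rtol=1e-2)
check_grads(recon_loss, args=(example_batch,), order=1,
            modes=("rev",), atol=1e-2, rtol=1e-2)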


jakevdp commented Aug 26, 2024

I think the issue here is that your gradients are very close to zero, so very small absolute deviations become large relative deviations:

import numpy as np

grad_jit_result = jax.grad(jax.jit(recon_loss))(example_batch)
grad_result = jax.grad(recon_loss)(example_batch)
np.testing.assert_allclose(grad_jit_result, grad_result)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 2554069 / 12582912 (20.3%)
Max absolute difference: 7.275958e-12
Max relative difference: 0.3
 x: array([[-6.417103e-06,  1.638319e-05, -2.042344e-06, ...,  2.303527e-05,
         2.817694e-06, -6.996353e-06],
       [-3.996819e-06, -8.903306e-06,  7.238759e-08, ..., -2.694251e-06,...
 y: array([[-6.417103e-06,  1.638319e-05, -2.042344e-06, ...,  2.303527e-05,
         2.817694e-06, -6.996353e-06],
       [-3.996818e-06, -8.903306e-06,  7.238714e-08, ..., -2.694251e-06,...

The gradients match in an absolute sense to 1 part in 10^12, but in a relative sense only match to 1 part in 3, meaning the mismatched gradients are something like 3E-13 vs. 1E-12, so both are basically zero. This large relative difference is likely what causes check_grads to report an error, but it's not an error that will be numerically significant in the overall computation.
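In other words, the failure above comes from assert_allclose's default atol=0; with a small absolute tolerance (values here are arbitrary), the same comparison should pass, since the deviations are on the order of 1e-12:

import numpy as np

# The ~1e-12 deviations fall well within atol=1e-8.
np.testing.assert_allclose(grad_jit_result, grad_result, rtol=1e-5, atol=1e-8)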


muchanem commented Aug 26, 2024

That makes sense, but I still encounter the bug where optimization totally fails with the JIT'd loss and works without it (and this behavior doesn't happen on a TPU platform, which makes me think it's an XLA bug rather than a code bug). I removed the leaky ReLU, since with top-k gating it does nothing except add a threshold, and after that change JIT'd training on GPU still fails while removing the JIT (or turning off algsimp) makes it succeed (e.g. at step 400 a JIT'd trainer still has loss above 1 while an eager optimization is at ~0.8).


jakevdp commented Aug 26, 2024

Hmm, that's strange indeed. In general we don't expect JIT-compiled versions of functions to have bitwise-identical outputs to the non-compiled versions: any time you rearrange or fuse floating point operations, you'll change the details of the numerics, but it should be to within normal floating-point precision.

Can you try running it again while setting jax_default_matmul_precision=highest?
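For reference, a couple of ways to set that (the context-manager form scopes it to a block):

import jax

# Globally:
jax.config.update("jax_default_matmul_precision", "highest")

# Or scoped:
with jax.default_matmul_precision("highest"):
    loss, grads = loss_fn(model, batch, l1_penalty, ortho_penalty)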

@muchanem

That didn't help. Here's the code as it is being run right now (and the first 1000 steps of training); sorry that it's no longer fully self-contained. Contrary to what I said earlier, substituting the leaky_offset_relu for either a normal relu or no activation fixes the issue (my other comment was testing a more complicated version of this model 🥲). There's no actual negative slope in how the leaky offset ReLU is called, just a thresholded ReLU; I suspect that a decent number of the activations end up below the threshold at the beginning of training and that causes issues.

import jax
import jax.numpy as jnp
import equinox as eqx
import treescope
from jax.experimental import sparse
treescope.basic_interactive_setup()
from typing import Tuple
import optax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
jax.config.update('jax_default_matmul_precision', "highest")

def leaky_offset_relu(x, negative_slope=1e-2, offset=0):
    return jnp.where(x >= offset, x, negative_slope * x)

class Autoencoder(eqx.Module):
    encoder: jnp.ndarray
    decoder: jnp.ndarray
    bias: jnp.ndarray
    use_bias: bool

    def __init__(self, latent_dim: int, input_dim: int, use_bias: bool = True, key=None):
        initializer = jax.nn.initializers.he_uniform()
        self.encoder = initializer(key, (latent_dim, input_dim), jnp.float32)
        self.decoder = self.encoder.T
        self.bias = jnp.zeros(input_dim) if use_bias else None
        self.use_bias = use_bias

    def encode(self, x):
        x = x - self.bias if self.use_bias else x
        codes = self.encoder @ x
        #return codes
        return leaky_offset_relu(codes, negative_slope=0., offset=1.96/jnp.sqrt(self.encoder.shape[0]))

    def top_k_decode(self, top_k_indices, top_k_values):
        decoder_weights = self.get_decoder()
        # top_k_indices is now 1D after vmap, so we don't need [:, :, None]
        selected_decoder_weights = decoder_weights[:, top_k_indices]
        return selected_decoder_weights @ top_k_values + (self.bias if self.use_bias else 0)

    def get_decoder(self):
        return self.decoder / jnp.linalg.norm(self.decoder, axis=0, keepdims=True)

    def get_encoder(self):
        return self.encoder

    def __call__(self, x):
        z = self.encode(x)
        return self.decode(z)

def fwd_pass(model: Autoencoder, batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(model.encode)(batch)
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)
    x_hat = jax.vmap(model.top_k_decode)(top_k_indices, top_k_values)
    return x_hat, top_level_latent_codes

def recon_loss(x_hat: jnp.ndarray, batch: jnp.ndarray):
    return jnp.mean(jnp.sum(jnp.square(batch - x_hat), axis=-1))


@eqx.filter_value_and_grad
def loss_fn(model: Autoencoder, batch: jnp.ndarray, l1_penalty: float, ortho_penalty: float) -> float:
    x_hat, top_level_latent_codes = fwd_pass(model, batch)
    return recon_loss(x_hat, batch)

def update_model(model, grads, opt_state, optimizer):
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state

@eqx.filter_jit
def train_step(model: Autoencoder, batch: jnp.ndarray, opt_state, l1_penalty: float, ortho_penalty: float, optimizer) -> Tuple[Autoencoder, optax.OptState, float]:
    loss, grads = loss_fn(model, batch, l1_penalty, ortho_penalty)
    model, opt_state = update_model(model, grads, opt_state, optimizer)
    return model, opt_state, loss
# omitted dataloader stuff
model = Autoencoder(
    input_dim=3072,
    latent_dim=2**12,
    use_bias=False,
    key=jax.random.key(0)
)

learning_rate = 1e-4
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

l1_penalty = 1e-3
ortho_penalty = 1e-2

num_steps = 10000

for step in range(num_steps):
    batch = next(train_loader)
    model, opt_state, loss = train_step(model, batch, opt_state, l1_penalty, ortho_penalty, optimizer)
    
    if step % 50 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")

No JIT

Step 0, Loss: 0.9761
Step 50, Loss: 0.9737
Step 100, Loss: 0.9722
Step 150, Loss: 0.9705
Step 200, Loss: 0.9687
Step 250, Loss: 0.9664
Step 300, Loss: 0.9630
Step 350, Loss: 0.9590
Step 400, Loss: 0.9531
Step 450, Loss: 0.9460
Step 500, Loss: 0.9373
Step 550, Loss: 0.9309
Step 600, Loss: 0.9174
Step 650, Loss: 0.9082
Step 700, Loss: 0.8965
Step 750, Loss: 0.8895
Step 800, Loss: 0.8821
Step 850, Loss: 0.8753
Step 900, Loss: 0.8647
Step 950, Loss: 0.8614
Step 1000, Loss: 0.8555

JIT on

Step 0, Loss: 1.0394
Step 50, Loss: 1.0389
Step 100, Loss: 1.0383
Step 150, Loss: 1.0381
Step 200, Loss: 1.0375
Step 250, Loss: 1.0372
Step 300, Loss: 1.0372
Step 350, Loss: 1.0367
Step 400, Loss: 1.0362
Step 450, Loss: 1.0357
Step 500, Loss: 1.0356
Step 550, Loss: 1.0353
Step 600, Loss: 1.0350
Step 650, Loss: 1.0346
Step 700, Loss: 1.0342
Step 750, Loss: 1.0341
Step 800, Loss: 1.0339
Step 850, Loss: 1.0333
Step 900, Loss: 1.0333
Step 950, Loss: 1.0331
Step 1000, Loss: 1.0325


pschuh commented Aug 27, 2024

Hi, what is happening is that the layout of top_level_latent_codes changes which top_k algorithm gets used, and it looks to me like there are a lot of near-zero values in top_level_latent_codes, so the top_k algorithm is struggling to choose the top k consistently. If I add a layout constraint to your code (or return top_level_latent_codes as an aux output), everything becomes consistent between jit and no-jit.

def fwd_pass(model: Autoencoder, batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(model.encode)(batch)
    custom_dll = jax._src.layout.DeviceLocalLayout(major_to_minor=(0, 1)) # switch to (1, 0) to get 'jit_loss'
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    top_level_latent_codes = jax.lax.with_sharding_constraint(top_level_latent_codes, jax._src.layout.Layout(custom_dll, s)) 
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8) # Top-k gating
    x_hat = jax.vmap(model.top_k_decode)(top_k_indices, top_k_values)
    return x_hat, top_level_latent_codes, top_k_indices

While the layout constraint does make them match, I suspect this implies that the choice of top-k in the presence of ties or near-zeros really matters, and you should do something smarter here rather than force a particular top_k algorithm by constraining the layouts; see the sketch below.
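For example, one possible (untested, purely illustrative) way to make the selection deterministic regardless of which top_k algorithm XLA picks would be to break ties explicitly with a tiny index-dependent bias; the stable_top_k helper below is not from your code:

def stable_top_k(codes, k, tie_eps=1e-7):
    # Prefer lower indices when values are (nearly) tied, so the selection no
    # longer depends on the layout-dependent choice of top_k algorithm.
    idx_bias = jnp.arange(codes.shape[-1], dtype=codes.dtype) * tie_eps
    _, indices = jax.lax.top_k(codes - idx_bias, k)
    values = jnp.take_along_axis(codes, indices, axis=-1)
    return values, indices

In fwd_pass this would replace the direct jax.lax.top_k call on top_level_latent_codes.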

@yashk2810

Slight update on the usage of the layout API: use the experimental endpoint, from jax.experimental.layout import Layout, DeviceLocalLayout.
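Applied to the snippet above, that would look something like this (same logic as pschuh's version, only using the experimental import; untested):

from jax.experimental.layout import Layout, DeviceLocalLayout

def fwd_pass(model: Autoencoder, batch: jnp.ndarray):
    top_level_latent_codes = jax.vmap(model.encode)(batch)
    custom_dll = DeviceLocalLayout(major_to_minor=(0, 1))  # switch to (1, 0) to get 'jit_loss'
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    top_level_latent_codes = jax.lax.with_sharding_constraint(
        top_level_latent_codes, Layout(custom_dll, s))
    top_k_values, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)  # Top-k gating
    x_hat = jax.vmap(model.top_k_decode)(top_k_indices, top_k_values)
    return x_hat, top_level_latent_codes, top_k_indices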

@muchanem

This makes sense, but I'm still curious as to why this happens with leaky_offset_relu but not otherwise. The offset means that with leaky_offset_relu turned on, there are no values in top_level_latent_codes below ~0.03. With no offset (i.e. regular ReLU or no activation), values go as low as 1e-5, yet the bad training doesn't happen (sort of the opposite of what you'd expect if near-zeros are the issue). There aren't any ties, nor any near-ties (the mean distance between activation values among the nonzero elements is ~0.01). The only other difference between the activations is that there are only ~340 nonzero elements in the offset case versus ~2048 without an offset. All these numbers are from step 10 of optimization, but the same is true at initialization.


pschuh commented Aug 27, 2024

You can debug these things by adding an aux output and printing the intermediate results, e.g. with @functools.partial(eqx.filter_value_and_grad, has_aux=True).
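A sketch against the loss_fn from your latest snippet (untested; the indices are recomputed just for inspection):

import functools

@functools.partial(eqx.filter_value_and_grad, has_aux=True)
def loss_fn(model: Autoencoder, batch: jnp.ndarray, l1_penalty: float, ortho_penalty: float):
    x_hat, top_level_latent_codes = fwd_pass(model, batch)
    loss = recon_loss(x_hat, batch)
    # Recompute the top-k indices purely so they can be returned as aux output.
    _, top_k_indices = jax.lax.top_k(top_level_latent_codes, 8)
    return loss, top_k_indices

(loss, top_k_indices), grads = loss_fn(model, batch, l1_penalty, ortho_penalty)
print(top_k_indices)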

I think the simple explanation is that top_k_indices with relu + jit are ordered like so:

Array([[1793, 1338,    0, ...,    3,    4,    5],
       [1074,    0,    1, ...,    4,    5,    6],
       [   0,    1,    2, ...,    5,    6,    7],
       ...,
       [   0,    1,    2, ...,    5,    6,    7],
       [   0,    1,    2, ...,    5,    6,    7],
       [ 591, 1829,    0, ...,    3,    4,    5]], dtype=int32)

And everything else (nojit relu, jit norelu, nojit norelu) is more random like so:

  Array([[ 985, 1273,  579, ...,  386, 1756,  390],
       [1213,  402, 1652, ...,   59, 1665,  597],
       [1275, 1131, 1073, ..., 1615, 1034, 2030],
       ...,
       [ 570, 1460,  101, ...,    0,    1,    2],
       [1546, 1519, 1748, ..., 1644,  902, 1482],
       [1059,  430, 1778, ...,  732, 1034,  344]], dtype=int32)

Printing the intermediate values for the leaky ReLU case, there are a lot of zeros in the results. Anyway, looking more closely, it might actually be a bug in how XLA does top-k. I'll look into that.

@muchanem

Thanks! Another thing I don't understand is why the major-to-minor ordering helps regardless of batch size. Even when the batch dimension is larger than the latent dimension, your layout fix works--but isn't the batch dimension major and the latent dimension minor in that case? Is it just that the vmap means the batch dimension will always be treated as the minor dimension, since the ops are vectorized across it? (These runs were before I added the aux output.)


muchanem commented Sep 4, 2024

While making modifications to the code today, I noticed that adding the line codes += 0 re-introduces the buggy top-k optimization.
