
segfault when using ensure_compile_time_eval #18831

Open
benjaminvatterj opened this issue Dec 5, 2023 · 1 comment
Labels
bug Something isn't working

Comments

@benjaminvatterj

Description

Hi! I have a model that requires precomputing a large matrix that is a model constant. To avoid recomputing it on every function call, I thought I'd have it evaluated at compile time. However, this results in a segmentation fault, which does not happen if I don't request compile-time evaluation. I'm not sure whether this is a bug or a deeper issue I don't fully understand. This is related to the question I posted in Q&A #18830.
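For context, here is a minimal sketch of the behaviour being relied on: inside a `jax.jit`-traced function, operations that depend only on concrete (non-traced) values run eagerly under `jax.ensure_compile_time_eval()`, so the result is a concrete array that ends up in the compiled program as a constant. The toy function `add_constant` and the `trace_log` dict are hypothetical illustrations, not part of the original model.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical record of what was computed during tracing (illustration only).
trace_log = {}


@jax.jit
def add_constant(x):
    with jax.ensure_compile_time_eval():
        # Depends only on concrete values, so it is evaluated eagerly while tracing.
        const = jnp.arange(4) * 2
        # Converting to NumPy only works for concrete arrays, not for tracers.
        trace_log["const"] = np.asarray(const)
    # x is a tracer here; const is a concrete array baked into the program.
    return x + const


print(add_constant(jnp.ones(4)))  # [1. 3. 5. 7.]
```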

Here's a minimal (admittedly somewhat absurdly minimal) reproducible example:

import jax.numpy as jnp
import jax

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}


@jax.jit
def segfault_version(x):
    with jax.ensure_compile_time_eval():
        segment_id = model_data['segment_id']
        val_id = model_data['val_id']
        num_segments = model_data['num_segments']
        # agg_mat[s, j] is True when segment_id[j] == s; shape (11544, 346320)
        agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
    x = x[val_id]
    return agg_mat @ x

x = jnp.ones(10)  # val_id takes values 0..9, so x needs 10 entries
print(segfault_version(x))

The output is simply a segmentation fault. I'm on a Mac Studio M2 (not using jax-metal, because too many things are broken with it).
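For scale, the size of `agg_mat` can be worked out without allocating anything: it compares an `(11544, 1)` column against a `(1, 346320)` row, so broadcasting yields an `(11544, 346320)` boolean matrix. The sketch below only does that arithmetic; the byte counts assume 1 byte per bool and 4 bytes per float32 (the dtype that `bool @ float32` promotes to), and whether XLA actually materializes a float32 copy is an assumption, not something the report shows.

```python
num_segments = 11544
repeats_per_segment = 30  # from jnp.repeat(jnp.arange(11544), 30)
num_entries = num_segments * repeats_per_segment  # length of segment_id (and of val_id: 10 * 3 * 11544)

# (num_segments, 1) == (1, num_entries) broadcasts to (num_segments, num_entries)
agg_shape = (num_segments, num_entries)
num_elements = agg_shape[0] * agg_shape[1]

bool_gb = num_elements * 1 / 1e9     # 1 byte per bool
float32_gb = num_elements * 4 / 1e9  # if promoted to float32 for the matmul

print(agg_shape)             # (11544, 346320)
print(num_elements)          # 3997918080
print(round(bool_gb, 2))     # 4.0
print(round(float32_gb, 2))  # 15.99
```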

What jax/jaxlib version are you using?

jax 0.4.20, jaxlib 0.4.20

Which accelerator(s) are you using?

CPU

Additional system info?

numpy 1.26.2, Python 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ], uname_result(system='Darwin', node='Bvatter-Studio-23', release='23.1.0', version='Darwin Kernel Version 23.1.0: Mon Oct 9 21:28:45 PDT 2023; root:xnu-10002.41.9~6/RELEASE_ARM64_T6020', machine='arm64')

NVIDIA GPU info

No response

@benjaminvatterj benjaminvatterj added the bug Something isn't working label Dec 5, 2023
@benjaminvatterj
Author

One more observation: I get the same segmentation fault if I compute agg_mat outside the function and reference it from inside it.

import jax.numpy as jnp
import jax

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}

segment_id = model_data['segment_id']
val_id = model_data['val_id']
num_segments = model_data['num_segments']
agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
model_data['agg_mat'] = agg_mat

@jax.jit
def slow_version2(x):
    x = x[val_id]
    return agg_mat @ x

x = jnp.ones(10)
print(slow_version2(x))

This also results in a segmentation fault. Is there a better way to pass large constants? The key issue is that it's a matrix, so passing it as a static argument to jit doesn't seem to work either.
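Two possible workarounds, sketched under assumptions and not confirmed as fixes for the segfault: (1) pass the arrays to the jitted function as ordinary, non-static arguments, so they are traced inputs rather than constants embedded in the compiled program (static arguments must be hashable, which JAX arrays are not); (2) avoid building `agg_mat` at all. Since `agg_mat[s, j]` is `segment_id[j] == s`, the product `agg_mat @ x[val_id]` sums `x[val_id]` within each segment, which is what `jax.ops.segment_sum` computes. The function names and the small sizes below are hypothetical stand-ins for the real model.

```python
import jax
import jax.numpy as jnp

# Small illustrative sizes; the real model uses 11544 segments.
num_segments = 6
segment_id = jnp.repeat(jnp.arange(num_segments), 30)
val_id = jnp.tile(jnp.arange(10), 3 * num_segments)


@jax.jit
def dense_version(x, agg_mat, val_id):
    # Arrays passed as regular arguments are traced inputs, not baked-in constants.
    return agg_mat @ x[val_id]


@jax.jit
def segment_sum_version(x, segment_id, val_id):
    # Sums x[val_id] within each segment without building the dense matrix.
    # num_segments is a Python int closed over from the enclosing scope.
    return jax.ops.segment_sum(x[val_id], segment_id, num_segments=num_segments)


x = jnp.arange(10, dtype=jnp.float32)
agg_mat = jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1)
dense = dense_version(x, agg_mat, val_id)
sparse = segment_sum_version(x, segment_id, val_id)
print(dense)  # [135. 135. 135. 135. 135. 135.]
```

At the real sizes, option (1) still needs the roughly 4 GB boolean `agg_mat` in memory, while option (2) should only need the length-346320 index and value vectors.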
