Insights: google/jax
September 6, 2024 – September 13, 2024
Overview
2 Releases published by 1 person
jaxlib-v0.4.32 Jaxlib release v0.4.32
published
Sep 11, 2024 -
jax-v0.4.32 JAX release v0.4.32
published
Sep 11, 2024
89 Pull requests merged by 20 people
DOC: Improved documentation for jax.numpy.fromstring
#23553 merged
Sep 13, 2024 -
Fix a small typo for the condition of scipy.entr.
#23596 merged
Sep 13, 2024 -
[export] Improve the forward compatibility documentation
#23618 merged
Sep 13, 2024 -
Pallas Mosaic GPU now supports scratch buffers in SMEM
#23592 merged
Sep 13, 2024 -
ParsedPartitionSpec needs to check that it is the proper instance type
#23612 merged
Sep 13, 2024 -
Remove the device assignment check in _resolve_in_shardings since that's historical and not needed anymore
#23609 merged
Sep 13, 2024 -
Reverts 255c30303d32e7473262b2e35348175c87e4348f
#23610 merged
Sep 13, 2024 -
shmap in_spec None shouldn't require hashability
#23605 merged
Sep 13, 2024 -
[Mosaic TPU] Propagate the memory space change for memref bitcast and reshape.
#23562 merged
Sep 13, 2024 -
[Mosaic TPU] Better error message when shape of memref bitcast is invalid.
#23561 merged
Sep 12, 2024 -
Fix a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
#23581 merged
Sep 12, 2024 -
Fixed a static type error in Mosaic GPU lowering
#23589 merged
Sep 12, 2024 -
DOC: Improve docs for jax.numpy: diff and ediff1d
#23593 merged
Sep 12, 2024 -
Redisable one more complex arctan test.
#23603 merged
Sep 12, 2024 -
Redisable complex arctan tests on older jaxlibs.
#23602 merged
Sep 12, 2024 -
Lower the minimum jaxlib version to 0.4.31, since 0.4.32 was yanked from pypi.
#23601 merged
Sep 12, 2024 -
DOC: Improve doc for jnp.resize
#23586 merged
Sep 12, 2024 -
Delete jax.xla_computation since it's been 3 months since it was deprecated.
#23584 merged
Sep 12, 2024 -
Enable more layout tests on GPUs
#23591 merged
Sep 12, 2024 -
Refactor gpusolver kernel definitions into separate build target.
#23509 merged
Sep 12, 2024 -
Update complex arctan and arctanh accuracy tests
#23587 merged
Sep 12, 2024 -
[Pallas] Fix array indexing error when dimension size is not a multiple of stride
#23534 merged
Sep 12, 2024 -
DOC: Improve docs for jax.numpy arithmetic comparison operations.
#23572 merged
Sep 11, 2024 -
Flip jax_pmap_no_rank_reduction by default to True.
#23322 merged
Sep 11, 2024 -
Enable the enhanced TPU launch barrier on all TPU generations.
#23411 merged
Sep 11, 2024 -
Disable complex arctan test in mnegj.real part of plane.
#23580 merged
Sep 11, 2024 -
Update version numbers after v0.4.32 release.
#23578 merged
Sep 11, 2024 -
Add link to XLA documentation for building JAX with CUDA from sources.
#23577 merged
Sep 11, 2024 -
Remove custom partitioning pointer from pre-compiled HLO
#22702 merged
Sep 11, 2024 -
Fix typo in jax.typing module doc
#23566 merged
Sep 11, 2024 -
[Pallas] Fix TPU large array indexing tests.
#23524 merged
Sep 11, 2024 -
Improve docs for jnp.stack, jnp.concat, & related functions
#23546 merged
Sep 11, 2024 -
trying on another readme header
#23545 merged
Sep 11, 2024 -
Make sure to call the superclass' __init__() on a newly created instance in PositionalSharding._remake().
#23570 merged
Sep 11, 2024 -
Relax test tolerance in pinv test to fix a CI failure on Windows CPU.
#23567 merged
Sep 11, 2024 -
[NVIDIA] Relax the requirement for providing both query_seq_lengths and key_value_seq_lengths
#23415 merged
Sep 11, 2024 -
Internal change
#23571 merged
Sep 11, 2024 -
Update users of jax.tree.map() to be more careful about how they handle Nones.
#23558 merged
Sep 11, 2024 -
[Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlockSpec"
#23547 merged
Sep 11, 2024 -
[Pallas] Disable win32 gpu_ops_test.
#23552 merged
Sep 10, 2024 -
Update the docs for conv_general_dilated to clarify 'W' 'H'.
#23549 merged
Sep 10, 2024 -
Reverts 5e4250e64bb415be94ddc8a80dba6083a6a4123a
#23554 merged
Sep 10, 2024 -
Removed a sneaky comma in Pallas Mosaic GPU lowering
#23551 merged
Sep 10, 2024 -
Prepare for jax 0.4.32 release
#23550 merged
Sep 10, 2024 -
Clean up JAX backend for all backends to avoid dangling PyClient references.
#23206 merged
Sep 10, 2024 -
[Pallas] Simplify lowering and fix the test for lax.erf_inv_p
#23531 merged
Sep 10, 2024 -
DOC: Improved documentation for jax.numpy.kron and jax.numpy.outer
#23443 merged
Sep 10, 2024 -
Add the FFI functions and tutorial to the changelog.
#23478 merged
Sep 10, 2024 -
DOC: Improve docs for jax.numpy: float_power and nextafter
#23536 merged
Sep 10, 2024 -
DOC: Improve docs for jnp.trim_zeros
#23540 merged
Sep 10, 2024 -
Estimate the amount of required scratch SMEM automatically in Pallas Mosaic GPU lowering
#23538 merged
Sep 10, 2024 -
Disable two lax_scipy_test testcases that fail on TPU v6e.
#23541 merged
Sep 10, 2024 -
Make JAX extract the mesh from an AUTO in/out sharding.
#23515 merged
Sep 10, 2024 -
Improve the docstring for jax.Array.copy_to_host_async.
#23529 merged
Sep 9, 2024 -
Create optimal order for v5e:8 devices which is [0, 1, 2, 3, 7, 6, 5, 4]
#23528 merged
Sep 9, 2024 -
Change nightly install commands to include all packages.
#23526 merged
Sep 9, 2024 -
Disable a shard_map test case that fails on TPU v5e.
#23525 merged
Sep 9, 2024 -
[Pallas] Fully skip GPU attention tests on win32.
#23518 merged
Sep 9, 2024 -
doc: update examples of deprecation timelines
#23424 merged
Sep 9, 2024 -
array API: use latest array-api-tests commit hash
#23514 merged
Sep 9, 2024 -
Add array_api intersphinx & document jnp.permute_dims
#23512 merged
Sep 9, 2024 -
More explicit docstring on the limitations of spsolve
#23511 merged
Sep 9, 2024 -
Fixed validation in jax.debug.format
#23502 merged
Sep 9, 2024 -
Update TPU test configuration tags.
#23513 merged
Sep 9, 2024 -
[Pallas] Update Pallas docs with new figures and TPUCompilerParams
#23484 merged
Sep 9, 2024 -
Improve docs for jnp.split & related APIs
#23459 merged
Sep 9, 2024 -
DOC: Better docs for jax.numpy: minimum and maximum
#23469 merged
Sep 9, 2024 -
Disable cudnn_fusion_test from CI.
#23500 merged
Sep 9, 2024 -
[pallas] Fix test failures on Windows.
#23497 merged
Sep 9, 2024 -
Relax some test tolerances in for_loop_test.py.
#23499 merged
Sep 8, 2024 -
Skip symmetric product test on older jaxlibs.
#23498 merged
Sep 8, 2024 -
Fix triton capi_objects target to depend on MLIR CAPIIRObjects bazel
#23462 merged
Sep 8, 2024 -
Adds failing test for https://github.com/google/jax/issues/23476.
#23494 merged
Sep 8, 2024 -
tweak readme title to be more about what jax can do for you, dear user
#23490 merged
Sep 7, 2024 -
Removed dead code from add_jaxvals
#23481 merged
Sep 7, 2024 -
Fix tolerances for failing linalg tests.
#23483 merged
Sep 6, 2024 -
Update FFI target name for syrk operation to be consistent with other kernels.
#23482 merged
Sep 6, 2024 -
Make pltpu key derivation more robust.
#23241 merged
Sep 6, 2024 -
Port GPU kernel for symmetric eigendecomposition to GPU.
#23190 merged
Sep 6, 2024 -
Implement symmetric_product() to produce a symmetric matrix:
C = alpha * X @ X.T + beta * C
#23062 merged
Sep 6, 2024 -
[host_callback] Fix type promotion error
#23477 merged
Sep 6, 2024 -
[NFC] Remove unused argument, fix help string.
#23369 merged
Sep 6, 2024 -
[shape_poly] Improve handling of equality shape constraints
#23470 merged
Sep 6, 2024 -
Added a bit more error checking to Pallas Mosaic GPU pipelining logic
#23472 merged
Sep 6, 2024 -
Fix expression parentheses in shape polymorphism docs for division of symbolic dimensions.
#23454 merged
Sep 6, 2024
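PR #23062 describes `symmetric_product()` by the formula C = alpha * X @ X.T + beta * C. As a minimal sketch, assuming only NumPy, the reference computation below evaluates that formula directly so the result can be checked by hand. It is not the JAX implementation, and the helper name `symmetric_product_reference` is hypothetical.

```python
import numpy as np


def symmetric_product_reference(x, c, alpha=1.0, beta=0.0):
    """Hypothetical reference helper: C = alpha * X @ X.T + beta * C.

    Plain NumPy sketch of the formula only; not JAX's implementation.
    The output is symmetric whenever the input `c` is symmetric,
    because X @ X.T is always symmetric.
    """
    x = np.asarray(x, dtype=float)
    c = np.asarray(c, dtype=float)
    n = x.shape[0]
    # C must be square with one row/column per row of X.
    if c.shape != (n, n):
        raise ValueError(f"c must have shape ({n}, {n}), got {c.shape}")
    return alpha * (x @ x.T) + beta * c


# Usage: a 3x2 X and an identity C.
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
c = np.eye(3)
out = symmetric_product_reference(x, c, alpha=2.0, beta=0.5)
```

Here X @ X.T has first row [5, 11, 17], so `out[0, 0]` is 2 * 5 + 0.5 * 1 = 10.5 and `out[0, 1]` is 2 * 11 = 22.0.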
32 Pull requests opened by 11 people
Removed `named_shape=` from `core.ShapedArray`
#23473 opened
Sep 6, 2024 -
Support JAX shard map lowering with Shardy enabled.
#23480 opened
Sep 6, 2024 -
Add a small number of utils to make MLIR easier and nicer to use.
#23485 opened
Sep 6, 2024 -
Determine symmetric linear op for CG from abstract output dtype
#23486 opened
Sep 6, 2024 -
Fix Pyright issue with the type of NotMapped
#23492 opened
Sep 7, 2024 -
Fix and reenable cudnn_fusion_test.
#23505 opened
Sep 9, 2024 -
Pallas Mosaic GPU no longer pipelines stores for outputs which are invariant wrt sequential dimensions
#23510 opened
Sep 9, 2024 -
Do a single mbarrier.arrive.expect_tx per fetch in Pallas Mosaic GPU
#23516 opened
Sep 9, 2024 -
`NotMapped` is now a proper class instead of an alias to `types.NoneType`
#23527 opened
Sep 9, 2024 -
[pallas:mosaic_gpu] Fragmented array debug printing.
#23539 opened
Sep 10, 2024 -
DOC: Improved documentation for numpy.frombuffer
#23556 opened
Sep 10, 2024 -
Update integer overflow check to avoid numpy 2.0 warning
#23557 opened
Sep 10, 2024 -
Deprecate passing `NdArrays with ndim != 1` and `non-arraylike` inputs to `jnp.trim_zeros`
#23563 opened
Sep 11, 2024 -
Ported a few changes to FragmentArray by cperivol@
#23568 opened
Sep 11, 2024 -
Added a new `allow_approx` flag to Mosaic GPU params in Pallas
#23569 opened
Sep 11, 2024 -
Add support for setting a dot product "algorithm" for lax.dot_general.
#23574 opened
Sep 11, 2024 -
Remove code that existed to support jaxlib < 0.4.32.
#23582 opened
Sep 11, 2024 -
[Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
#23583 opened
Sep 11, 2024 -
Add float8_e4m3 type support
#23585 opened
Sep 12, 2024 -
Added a new primitive for copying GMEM->SMEM in Pallas Mosaic GPU kernels
#23595 opened
Sep 12, 2024 -
Disable flaky tsan tests temporarily.
#23598 opened
Sep 12, 2024 -
Allow Jax libraries to register atexit handlers before or after `clear_backends()`.
#23604 opened
Sep 12, 2024 -
Fix some layout test failures on gpu backend
#23606 opened
Sep 12, 2024 -
[ROCm] Fix invalid repo url for EL path
#23607 opened
Sep 12, 2024 -
Internal breakage.
#23608 opened
Sep 13, 2024 -
DOC: Update gpu_performance_tips.md
#23615 opened
Sep 13, 2024 -
[Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
#23617 opened
Sep 13, 2024 -
[callbacks] Relax the usage of io_callback in AD
#23619 opened
Sep 13, 2024 -
[io_callback] Adds test for io_callback being used inside custom partitioning
#23620 opened
Sep 13, 2024 -
Updated :gpu_test configuration
#23621 opened
Sep 13, 2024 -
[pallas] Clean up forward-compatibility conditionals in Pallas lowering
#23622 opened
Sep 13, 2024
13 Issues closed by 10 people
`host_callback.call` fails on multi-gpu machine
#5577 closed
Sep 13, 2024 -
[Pallas] Array indexing error when dimension size is not a multiple of stride
#23532 closed
Sep 12, 2024 -
Always breaks compilation
#23503 closed
Sep 11, 2024 -
⚠️ Nightly upstream-dev CI failed ⚠️
#23496 closed
Sep 11, 2024 -
Capture of small arrays causes JIT execution 10x longer
#23479 closed
Sep 10, 2024 -
Allow compilation cache to be saved from process indices that are not process index 0
#18819 closed
Sep 10, 2024 -
[Pallas GPU] The behavior of `jnp.sign(jnp.nan)` in Pallas GPU does not match that of JAX
#23504 closed
Sep 9, 2024 -
jax.debug.print() fails when passing value formatting (which docs say is supported)
#23475 closed
Sep 9, 2024 -
Force a physical axis layout rearrangement
#23471 closed
Sep 9, 2024 -
debug.print doesn't work
#23501 closed
Sep 9, 2024 -
Jax numpy reduceat function throws error
#23493 closed
Sep 8, 2024 -
Pallas flash attention fails on GPU
#23495 closed
Sep 8, 2024 -
Error exporting np.where when unrelated dimension variable is 0
#23456 closed
Sep 6, 2024
20 Issues opened by 18 people
Orthogonal Initializer raises gpusolverDnCreate(&handle) failed: cuSolver internal error
#23616 opened
Sep 13, 2024 -
io_callback does not work with custom_vjp
#23614 opened
Sep 13, 2024 -
make jax.debug.print work with non-jax types
#23600 opened
Sep 12, 2024 -
Make jax.distributed timeouts configurable via jax.config
#23599 opened
Sep 12, 2024 -
Allow tuple inputs to `scatter_dimension` in `jax.lax.psum_scatter`
#23594 opened
Sep 12, 2024 -
Wrong results on CPU since 0.4.32
#23590 opened
Sep 12, 2024 -
`closure_convert` doesn't work on VJPs
#23588 opened
Sep 12, 2024 -
Sparse matrix multiplication and output nse values
#23576 opened
Sep 11, 2024 -
Error compile: --bazel_options=--repo_env=LOCAL_CUDA_PATH="${CUDA_HOME}" issues with clang & gcc
#23575 opened
Sep 11, 2024 -
Error compile: make_batch_pointers.cu.cc failed: undeclared inclusion(s)
#23573 opened
Sep 11, 2024 -
[ROCM] Multi-device reduction causes segfault
#23565 opened
Sep 11, 2024 -
Unable to build jaxlib with debug symbols for GPU
#23564 opened
Sep 11, 2024 -
`fp32` or even `fp64` support for cudnn flashattention
#23560 opened
Sep 11, 2024 -
jax.debug.breakpoint gives UnexpectedTracerError when used with jax.lax.cond
#23555 opened
Sep 10, 2024 -
Unused vmap GPU memory allocation causes RESOURCE_EXHAUSTED for versions >0.4.14
#23548 opened
Sep 10, 2024 -
All-gather performed in fp32 instead of bf16 despite explicit datatype conversion
#23543 opened
Sep 10, 2024 -
Project Plan Github Like Pytorch for features and next releases?
#23537 opened
Sep 10, 2024 -
add support for QR decomposition gradient calculation for wide matrices (rows < columns)
#23533 opened
Sep 10, 2024 -
[ROCM] x64 mode crashes with "redzone_checker with block dimensions: 1024x1x1: hipError_t"
#23506 opened
Sep 9, 2024 -
vmap(SPMD axis)/shmap/(vmap with capture) pattern breaks batching
#23476 opened
Sep 6, 2024
32 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
[NVIDIA] Add the sliding window to `jax.nn.dot_product_attention`
#23247 commented on
Sep 13, 2024 • 13 new comments -
Adding `JAX_LOGGING_LEVEL` configuration option
#23468 commented on
Sep 13, 2024 • 9 new comments -
A few more pyright fixes
#23453 commented on
Sep 6, 2024 • 2 new comments -
Lower tan to StableHLO instead of CHLO.
#23261 commented on
Sep 13, 2024 • 2 new comments -
Create JAX Advanced Tutorials (section, move/create docs)
#23165 commented on
Sep 11, 2024 • 2 new comments -
Revert to nvtx_kern_sum to get correct profiling info and adding post processing logic
#23467 commented on
Sep 6, 2024 • 0 new comments -
CHLO -> StableHLO : use TanOp and StablehloCreateCompatibilityExpanderPass
#23408 commented on
Sep 12, 2024 • 0 new comments -
[Pallas] Fix integer array indexing
#23406 commented on
Sep 12, 2024 • 0 new comments -
Activate Hessenberg Decomposition to XLA's FFI
#23333 commented on
Sep 6, 2024 • 0 new comments -
Stackless [WIP]
#23299 commented on
Sep 12, 2024 • 0 new comments -
[WIP] Add a jaxlib_wheel rule to jax.bzl.
#23276 commented on
Sep 9, 2024 • 0 new comments -
Add a memory saving index rewrite step to vmap with ragged inputs over pallas_call.
#23252 commented on
Sep 13, 2024 • 0 new comments -
[JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
#23074 commented on
Sep 12, 2024 • 0 new comments -
When caching is enabled, also enable XLA caching features as well
#22899 commented on
Sep 11, 2024 • 0 new comments -
[Pallas TPU] Add lowering for `lax.erf_inv` 64 bit
#22310 commented on
Sep 11, 2024 • 0 new comments -
Port Tridiagonal Reduction to XLA's FFI
#21614 commented on
Sep 6, 2024 • 0 new comments -
Support Hessian of gamma-distributed samples
#21432 commented on
Sep 9, 2024 • 0 new comments -
Support (nonsymmetric) np.linalg.eig on GPU
#1259 commented on
Sep 13, 2024 • 0 new comments -
jax.jit(function).lower(x, y).cost_analysis() gives NONE while running with GPU as the device
#22713 commented on
Sep 12, 2024 • 0 new comments -
Error Building Jaxlib v0.4.30 on Jetson Orin
#22155 commented on
Sep 12, 2024 • 0 new comments -
``scipy.linalg.solve_banded`` and ``scipy.linalg.ldl``
#12779 commented on
Sep 11, 2024 • 0 new comments -
Rotation.concatenate does not work for two single rotations
#23202 commented on
Sep 11, 2024 • 0 new comments -
`jnp.arange` does not permit dynamic shaped arrays
#23423 commented on
Sep 10, 2024 • 0 new comments -
Deprecate jax.experimental.host_callback in favor of JAX external callbacks
#20385 commented on
Sep 10, 2024 • 0 new comments -
jax.numpy.digitize doesn't work with shape polymorphism
#22489 commented on
Sep 9, 2024 • 0 new comments -
AMD+Jax Docker container incompatible with Singularity
#19138 commented on
Sep 9, 2024 • 0 new comments -
Gradient of SVD with degenerate singular values becomes NaN
#2311 commented on
Sep 9, 2024 • 0 new comments -
_isolve tacitly checks for symmetric linear operator wrongly
#23403 commented on
Sep 8, 2024 • 0 new comments -
Functionality to chunk `vmap`.
#11319 commented on
Sep 8, 2024 • 0 new comments -
vmap of cond's predicate results in select, leading to unexpected compute/memory use
#8409 commented on
Sep 8, 2024 • 0 new comments -
jax.lax.cond() significantly slower when called inside jax.lax.scan()
#5986 commented on
Sep 8, 2024 • 0 new comments -
`jax.nn.dot_product_attention` does not respect `key_value_seq_lengths`
#23349 commented on
Sep 6, 2024 • 0 new comments