diff --git a/setup.py b/setup.py
index 7f57cd99a1..790eee53df 100644
--- a/setup.py
+++ b/setup.py
@@ -105,7 +105,7 @@ def has_ext_modules(self):
],
keywords='tensorflow probability statistics bayesian machine learning',
extras_require={ # e.g. `pip install tfp-nightly[jax]`
- 'jax': ['jax', 'jaxlib'],
+ 'jax': ['jax<=0.2.11', 'jaxlib<=0.1.64'],
'tfds': [TFDS_PACKAGE],
}
)