absl_py==1.3.0
chex==0.1.5
dm_haiku==0.0.9
jax==0.3.25
jaxlib==0.3.25
matplotlib==3.6.2
ml_collections==0.1.1
numpy==1.22.4
optax==0.1.4
rlax==0.1.4
scipy==1.9.3