Skip to content

Commit

Permalink
chore: fix matplotlib and jax typing issues (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Jan 10, 2024
1 parent a695150 commit f6c9ef3
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ def test_shape_element_type_error(self) -> None:

def test_dtype_type_error(self) -> None:
with pytest.raises(TypeError):
specs.Array((1, 2, 3), "32")
specs.Array((1, 2, 3), "32") # type: ignore

def test_scalar_shape(self) -> None:
specs.Array((), jnp.int32)

def test_string_dtype_error(self) -> None:
specs.Array((1, 2, 3), "int32")
specs.Array((1, 2, 3), "int32") # type: ignore

def test_dtype(self) -> None:
specs.Array((1, 2, 3), int)
Expand Down
2 changes: 1 addition & 1 deletion jumanji/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,4 @@ def get_valid_dtype(dtype: Union[jnp.dtype, type]) -> jnp.dtype:
Returns:
dtype converted to the correct type precision.
"""
return jnp.empty((), dtype).dtype
return jnp.empty((), dtype).dtype # type: ignore
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ chex>=0.1.3
dm-env>=1.5
gym>=0.22.0
jax>=0.2.26
matplotlib>=3.3.4
matplotlib~=3.7.4
numpy>=1.19.5
Pillow>=9.0.0
typing-extensions>=4.0.0

0 comments on commit f6c9ef3

Please sign in to comment.