Skip to content

Commit

Permalink
fix: add unwrapped method to gigastep and jaxmarl wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a committed Oct 25, 2024
1 parent 3043a9d commit 41467f8
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 6 deletions.
4 changes: 3 additions & 1 deletion mava/networks/torsos.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def setup(self) -> None:
def __call__(self, observation: chex.Array) -> chex.Array:
"""Forward pass."""
x = observation
for channel, kernel, stride in zip(self.channel_sizes, self.kernel_sizes, self.strides):
for channel, kernel, stride in zip(
self.channel_sizes, self.kernel_sizes, self.strides, strict=False
):
x = nn.Conv(channel, (kernel, kernel), (stride, stride))(x)
if self.use_layer_norm:
x = nn.LayerNorm(use_scale=False)(x)
Expand Down
3 changes: 1 addition & 2 deletions mava/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import chex
import jumanji.specs as specs
from flax.core.frozen_dict import FrozenDict
from jumanji import Environment
from jumanji.types import TimeStep
from tensorflow_probability.substrates.jax.distributions import Distribution
from typing_extensions import NamedTuple, TypeAlias
Expand Down Expand Up @@ -103,7 +102,7 @@ def discount_spec(self) -> specs.BoundedArray:
...

@property
def unwrapped(self) -> Environment:
def unwrapped(self) -> Any:
"""Retuns: the innermost environment (without any wrappers applied)."""
...

Expand Down
2 changes: 1 addition & 1 deletion mava/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def log_dict(self, data: Dict, step: int, eval_step: int, event: LogEvent) -> No
for value in data.values():
value = value.item() if isinstance(value, jax.Array) else value
values.append(f"{value:.3f}" if isinstance(value, float) else str(value))
log_str = " | ".join([f"{k}: {v}" for k, v in zip(keys, values)])
log_str = " | ".join([f"{k}: {v}" for k, v in zip(keys, values, strict=False)])

self.logger.info(
f"{colour}{Style.BRIGHT}{event.value.upper()} - {log_str}{Style.RESET_ALL}"
Expand Down
4 changes: 4 additions & 0 deletions mava/wrappers/gigastep.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,7 @@ def adversary_policy(self, obs: Array, state: Tuple[Dict, Dict], key: PRNGKey) -
"""
return jax.random.randint(key, (obs.shape[0],), 0, self.action_dim)

@property
def unwrapped(self) -> GigastepEnv:
return self._env
4 changes: 4 additions & 0 deletions mava/wrappers/jaxmarl.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ def discount_spec(self) -> specs.BoundedArray:
name="discount",
)

@property
def unwrapped(self) -> MultiAgentEnv:
return self._env

@abstractmethod
def action_mask(self, wrapped_env_state: Any) -> Array:
"""Get action mask for each agent."""
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dynamic=["version", "dependencies", "optional-dependencies"]
license={file="LICENSE"}
description="Distributed Multi-Agent Reinforcement Learning in JAX."
readme ="README.md"
requires-python=">=3.9" # would be nice to upgrade this and jumanji at some point
requires-python=">=3.11"
keywords=["multi-agent", "reinforcement learning", "python", "jax", "anakin", "sebulba"]
classifiers=[
"Environment :: Console",
Expand All @@ -21,6 +21,7 @@ classifiers=[
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
Expand All @@ -36,7 +37,7 @@ optional-dependencies={dev={file=["requirements/requirements-dev.txt"]}}
"Bug Tracker"="https://github.com/instadeep/Mava/issues"

[tool.mypy]
python_version="3.10"
python_version="3.11"
warn_redundant_casts=true
disallow_untyped_defs=true
strict_equality=true
Expand Down

0 comments on commit 41467f8

Please sign in to comment.