
Commit

chore: bunch of minor changes and fixes
Louay-Ben-nessir committed Nov 5, 2024
1 parent 659a837 commit 7deb75b
Showing 4 changed files with 22 additions and 17 deletions.
4 changes: 2 additions & 2 deletions mava/configs/env/smac_gym.yaml
@@ -2,7 +2,7 @@
defaults:
- _self_

env_name: Starcraft # Used for logging purposes.
env_name: SMAC # Used for logging purposes.
scenario:
name: smaclite
task_name: smaclite/2s3z-v0 # smaclite/ + ['10m_vs_11m-v0', '27m_vs_30m-v0', '3s5z_vs_3s6z-v0', '2s3z-v0', '3s5z-v0', 'MMM-v0', 'MMM2-v0', '2c_vs_64zg-v0', 'bane_vs_bane-v0', 'corridor-v0', '2s_vs_1sc-v0', '3s_vs_5z-v0']
@@ -16,7 +16,7 @@ eval_metric: episode_return
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: False
log_win_rate: True

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True
14 changes: 8 additions & 6 deletions mava/evaluator.py
@@ -221,11 +221,12 @@ def get_sebulba_eval_fn(
Args:
----
env: an environment that conforms to the mava environment spec.
act_fn: a function that takes in params, timestep, key and optionally a state
env_maker: A function to create the environment instances.
act_fn: A function that takes in params, timestep, key and optionally a state
and returns actions and optionally a state (see `EvalActFn`).
config: the system config.
absolute_metric: whether or not this evaluator calculates the absolute_metric.
config: The system config.
np_rng: Random number generator for seeding environment.
absolute_metric: Whether or not this evaluator calculates the absolute_metric.
This determines how many evaluation episodes it does.
"""
n_devices = jax.device_count()
@@ -240,8 +241,8 @@
env = env_maker(config, n_parallel_envs)

act_fn = jax.jit(
act_fn, device=jax.devices("cpu")[0]
) # cpu so that we don't block actors/learners
act_fn, device=jax.local_devices()[config.arch.actor_device_ids[0]]
) # Evaluate using the first actor device

# Warnings if num eval episodes is not divisible by num parallel envs.
if eval_episodes % n_parallel_envs != 0:
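The change above moves the jitted `act_fn` off the CPU and onto the first actor device. Below is a minimal sketch of the same device-pinning pattern; the placeholder policy function, parameter shapes, and device index are illustrative, only the `jax.jit(..., device=...)` call mirrors the diff:

```python
import jax
import jax.numpy as jnp


def act_fn(params, obs):
    # Stand-in policy: a single linear layer (illustrative only).
    return obs @ params["w"]


# Pin the compiled function to a chosen device; the diff uses
# jax.local_devices()[config.arch.actor_device_ids[0]] rather than the CPU.
actor_device = jax.local_devices()[0]
jitted_act = jax.jit(act_fn, device=actor_device)

params = {"w": jnp.ones((4, 2))}
obs = jnp.ones((8, 4))
actions = jitted_act(params, obs)  # result is placed on actor_device
```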
@@ -264,6 +265,7 @@ def eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Met
def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
"""Simulates `num_envs` episodes."""

# Generate a list of random seeds within the 32-bit integer range, using a seeded RNG.
seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist()
ts = env.reset(seed=seeds)
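A short sketch of the seeding pattern added above: a NumPy `Generator`, seeded once at setup, hands each parallel environment its own 32-bit reset seed. The vectorised `env.reset` call is shown only as a comment since the environment object itself is not part of this snippet:

```python
import numpy as np

np_rng = np.random.default_rng(42)  # seeded once when the evaluator is built
n_parallel_envs = 4

# One independent seed per parallel environment, drawn from the int32 range.
seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist()
print(seeds)  # four reproducible int32-range seeds

# ts = env.reset(seed=seeds)  # each sub-environment receives its own seed
```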

2 changes: 1 addition & 1 deletion mava/utils/make_env.py
@@ -78,7 +78,7 @@
_gym_registry = {
"RobotWarehouse": GymWrapper,
"LevelBasedForaging": GymWrapper,
"Starcraft": SmacWrapper,
"SMAC": SmacWrapper,
}
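The registry rename has to stay in sync with `env_name` in `smac_gym.yaml`, which is why both are changed in this commit. A hedged sketch of how such a lookup might work; the `get_wrapper` helper below is hypothetical, not Mava's actual factory in `make_env.py`:

```python
# Illustrative stand-ins for the wrapper classes imported in make_env.py.
class GymWrapper: ...
class SmacWrapper(GymWrapper): ...

_gym_registry = {
    "RobotWarehouse": GymWrapper,
    "LevelBasedForaging": GymWrapper,
    "SMAC": SmacWrapper,
}


def get_wrapper(env_name: str) -> type:
    # Fails loudly if the yaml's env_name and the registry key drift apart.
    return _gym_registry[env_name]


print(get_wrapper("SMAC"))  # <class '__main__.SmacWrapper'>
```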


19 changes: 11 additions & 8 deletions mava/wrappers/gym.py
@@ -42,7 +42,7 @@

# needed to avoid host -> device transfers when calling TimeStep.last()
class StepType(IntEnum):
"""Coppy of Jumanji's step type but with numpy arrays"""
"""Copy of Jumanji's step type but with numpy arrays"""

FIRST = 0
MID = 1
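For context, a minimal sketch of the full enum and the host-side check it enables; the `LAST = 2` value follows Jumanji's convention and is an assumption here, as is the shape of the `.last()` comparison:

```python
from enum import IntEnum

import numpy as np


class StepType(IntEnum):
    """Copy of Jumanji's step type but with numpy arrays."""

    FIRST = 0
    MID = 1
    LAST = 2  # assumed value, matching Jumanji's convention


step_type = np.asarray(StepType.LAST)
# The comparison stays in NumPy on the host, so checking TimeStep.last()
# never forces a host -> device transfer.
is_last = bool(step_type == StepType.LAST)
print(is_last)  # True
```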
@@ -69,7 +69,7 @@ def last(self) -> bool:

class GymWrapper(gymnasium.Wrapper):
"""Base wrapper for multi-agent gym environments.
This wrapper works out of the box for RobotWarehouse and level based foraging.
This wrapper works out of the box for RobotWarehouse and level-based foraging.
"""

def __init__(
@@ -100,7 +100,7 @@ def reset(

agents_view, info = self._env.reset()

info = {"actions_mask": self.get_actions_mask(info)}
info = {"action_mask": self.get_action_mask(info)}
if self.add_global_state:
info["global_obs"] = self.get_global_obs(agents_view)

@@ -109,7 +109,7 @@ def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
agents_view, reward, terminated, truncated, info = self._env.step(actions)

info = {"actions_mask": self.get_actions_mask(info)}
info = {"action_mask": self.get_action_mask(info)}
if self.add_global_state:
info["global_obs"] = self.get_global_obs(agents_view)

@@ -120,7 +120,7 @@ def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]

return agents_view, reward, terminated, truncated, info

def get_actions_mask(self, info: Dict) -> NDArray:
def get_action_mask(self, info: Dict) -> NDArray:
if "action_mask" in info:
return np.array(info["action_mask"])
return np.ones((self.num_agents, self.num_actions), dtype=np.float32)
@@ -138,10 +138,11 @@ def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]
actions = [int(action) for action in actions]

agents_view, reward, terminated, truncated, info = super().step(actions)
info["won_episode"] = info["battle_won"]

return agents_view, reward, terminated, truncated, info

def get_actions_mask(self, info: Dict) -> NDArray:
def get_action_mask(self, info: Dict) -> NDArray:
return np.array(self._env.unwrapped.get_avail_actions())


@@ -232,7 +233,7 @@ def modify_space(self, space: spaces.Space) -> spaces.Space:


class GymToJumanji:
"""Converts from the Gym API to the dm_env API."""
"""Converts from the Gym API to the Jumanji API."""

def __init__(self, env: gymnasium.vector.VectorEnv):
self.env = env
@@ -269,7 +270,7 @@ def _format_observation(

# (N, B, O) -> (B, N, O)
obs = np.array(obs).swapaxes(0, 1)
action_mask = np.stack(info["actions_mask"])
action_mask = np.stack(info["action_mask"])
obs_data = {"agents_view": obs, "action_mask": action_mask}

if "global_obs" in info:
@@ -301,6 +302,8 @@ def close(self) -> None:

# Copied from Gymnasium/blob/main/gymnasium/vector/async_vector_env.py
# Modified to work with multiple agents
# Note: The worker handles auto-resetting the environments.
# Each environment resets when all of its agents have either terminated or been truncated.
def async_multiagent_worker( # CCR001
index: int,
env_fn: Callable,
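A hedged sketch of the auto-reset rule described in the new comments: inside the worker loop, a sub-environment is reset only once every one of its agents has terminated or been truncated. The loop skeleton below is illustrative, not the actual `async_multiagent_worker` body:

```python
import numpy as np


def step_with_autoreset(env, actions):
    """Step one sub-environment and auto-reset it when its episode is over."""
    obs, reward, terminated, truncated, info = env.step(actions)

    # terminated/truncated hold one flag per agent; the episode ends only
    # when every agent is done, and only then does the worker reset the env.
    if np.all(np.logical_or(terminated, truncated)):
        obs, info = env.reset()

    return obs, reward, terminated, truncated, info
```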
