Skip to content

Commit

Permalink
fix: smaclite win rate tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
Louay-Ben-nessir committed Nov 12, 2024
1 parent 0c4e83b commit 245aecc
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mava/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:

timesteps = jax.tree.map(lambda *x: np.stack(x), *timesteps)

metrics = timesteps.extras
metrics = timesteps.extras["episode_metrics"]
if config.env.log_win_rate:
metrics["won_episode"] = timesteps.extras["won_episode"]

Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/sebulba/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def act_fn(
timestep.reward,
log_prob,
obs_tpu,
timestep.extras,
timestep.extras["episode_metrics"],
)
)

Expand Down
2 changes: 1 addition & 1 deletion mava/utils/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
_gym_registry = {
"RobotWarehouse": UoeWrapper,
"LevelBasedForaging": UoeWrapper,
"SMAC": SmacWrapper,
"SMACLite": SmacWrapper,
}


Expand Down
26 changes: 17 additions & 9 deletions mava/wrappers/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def reset(

agents_view, info = self._env.reset()

info = {"action_mask": self.get_action_mask(info)}
info["action_mask"] = self.get_action_mask(info)
if self.add_global_state:
info["global_obs"] = self.get_global_obs(agents_view)

Expand All @@ -119,7 +119,7 @@ def reset(
def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
agents_view, reward, terminated, truncated, info = self._env.step(actions)

info = {"action_mask": self.get_action_mask(info)}
info["action_mask"] = self.get_action_mask(info)
if self.add_global_state:
info["global_obs"] = self.get_global_obs(agents_view)

Expand All @@ -143,6 +143,13 @@ def get_global_obs(self, obs: NDArray) -> NDArray:
class SmacWrapper(UoeWrapper):
"""A wrapper that converts actions step to integers."""

def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[NDArray, Dict]:
agents_view, info = super().reset()
info["won_episode"] = info["battle_won"]
return agents_view, info

def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
# Convert actions to integers before passing them to the environment
actions = [int(action) for action in actions]
Expand Down Expand Up @@ -181,9 +188,6 @@ def reset(
"is_terminal_step": False,
}

if "won_episode" in info:
metrics["won_episode"] = info["won_episode"]

info["metrics"] = metrics

return agents_view, info
Expand All @@ -199,8 +203,6 @@ def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Di
"episode_length": self.running_count_episode_length,
"is_terminal_step": np.logical_or(terminated, truncated).all().item(),
}
if "won_episode" in info:
metrics["won_episode"] = info["won_episode"]

info["metrics"] = metrics

Expand Down Expand Up @@ -294,7 +296,12 @@ def _create_timestep(
) -> TimeStep:
observation = self._format_observation(obs, info)
# Filter out the masks and auxiliary data
extras = {key: value for key, value in info["metrics"].items() if key[0] != "_"}
extras = {}
extras["episode_metrics"] = {
key: value for key, value in info["metrics"].items() if key[0] != "_"
}
if "won_episode" in info:
extras["won_episode"] = info["won_episode"]

return TimeStep(
step_type=step_type, # type: ignore
Expand Down Expand Up @@ -346,7 +353,8 @@ def async_multiagent_worker( # CCR001
info,
) = env.step(data)
if np.logical_or(terminated, truncated).all():
observation, _ = env.reset()
observation, new_info = env.reset()
info["action_mask"] = new_info["action_mask"]

if shared_memory:
write_to_shared_memory(observation_space, index, observation, shared_memory)
Expand Down

0 comments on commit 245aecc

Please sign in to comment.