Skip to content

Commit

Permalink
feat: Added state info and stats to pz wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
KaleabTessera committed Nov 15, 2021
1 parent 496e2e4 commit 5669793
Showing 1 changed file with 70 additions and 23 deletions.
93 changes: 70 additions & 23 deletions mava/wrappers/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,10 @@ def reset(self) -> dm_env.TimeStep:
for agent in self.possible_agents
}

if type(observe) == tuple:
observe, env_extras = observe
if self._return_state_info and type(observe) == tuple:
observe, state = observe
else:
env_extras = {}
state = None

observations = self._convert_observations(
observe, {agent: False for agent in self.possible_agents}
Expand All @@ -419,16 +419,16 @@ def reset(self) -> dm_env.TimeStep:
}

# If we want state information and it has not been provided as part of
# env env_extras.
if (
self._return_state_info
and hasattr(self._environment, "get_state")
and "s_t" not in env_extras
):
state = self._environment.get_state()
env_extras["s_t"] = state

return parameterized_restart(rewards, self._discounts, observations), env_extras
# the env reset - e.g. smac.
if not state:
state = self.get_state()

if state is not None:
return parameterized_restart(rewards, self._discounts, observations), {
"s_t": state
}
else:
return parameterized_restart(rewards, self._discounts, observations)

def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep:
"""Steps in env.
Expand All @@ -450,29 +450,57 @@ def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep:
rewards = self._convert_reward(rewards)
observations = self._convert_observations(observations, dones)

if self._return_state_info and hasattr(self._environment, "get_state"):
state = self._environment.get_state()
else:
state = None
state = self.get_state()

if self.env_done():
self._step_type = dm_env.StepType.LAST
self._reset_next_step = True
# Terminal discount should be 0.0 as per dm_env
discount = {
agent: convert_np_type(self.discount_spec()[agent].dtype, 0.0)
for agent in self.possible_agents
}
else:
self._step_type = dm_env.StepType.MID
discount = self._discounts

timestep = dm_env.TimeStep(
observation=observations,
reward=rewards,
discount=self._discounts,
discount=discount,
step_type=self._step_type,
)

if state:
if state is not None:
return timestep, {"s_t": state}
else:
return timestep

def extra_spec(self) -> Dict[str, specs.BoundedArray]:
"""Function returns extra spec (format) of the env.
Returns:
Dict[str, specs.BoundedArray]: extra spec.
"""
if self._return_state_info and hasattr(
self._environment.unwrapped.env, "get_state"
):
minimum = list(self.observation_spec().values())[ # type:ignore
0
].observation._minimum[0]
maximum = list(self.observation_spec().values())[ # type:ignore
0
].observation._maximum[
0
] # type:ignore
state = self._environment.unwrapped.env.get_state()
return {
"s_t": specs.BoundedArray(
state.shape, np.float32, minimum=minimum, maximum=maximum
)
}
else:
return {}

def env_done(self) -> bool:
"""Check if env is done.
Expand Down Expand Up @@ -586,13 +614,32 @@ def discount_spec(self) -> Dict[str, specs.BoundedArray]:
)
return discount_specs

def extra_spec(self) -> Dict[str, specs.BoundedArray]:
"""Extra data spec.
def get_state(self) -> Optional[Dict]:
"""Retrieve state from environment.
Returns:
Dict[str, specs.BoundedArray]: spec for extra data.
environment state.
"""
return {}
if self._return_state_info and hasattr(
self._environment.unwrapped.env, "get_state"
):
state = self._environment.unwrapped.env.get_state()
else:
state = None
return state

def get_stats(self) -> Optional[Dict]:
"""Return extra stats to be logged.
Returns:
extra stats to be logged.
"""
if hasattr(self._environment, "get_stats"):
return self._environment.get_stats()
elif hasattr(self._environment.unwrapped.env, "get_stats"):
return self._environment.unwrapped.env.get_stats()
else:
return None

@property
def agents(self) -> List:
Expand Down

0 comments on commit 5669793

Please sign in to comment.