Skip to content

Commit

Permalink
typing: Fixed mypy issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
KaleabTessera committed Nov 22, 2021
1 parent 3df706d commit 89ace75
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions mava/wrappers/flatland.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


import types as tp
import typing
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

Expand Down Expand Up @@ -145,7 +146,7 @@ def possible_agents(self) -> List[str]:
"""Return list of all possible agents."""
return self._possible_agents

def render(self, mode: str = "human") -> np.array:
def render(self, mode: str = "human") -> np.ndarray:
"""Renders the environment."""
if mode == "human":
show = True
Expand All @@ -170,9 +171,7 @@ def reset(self) -> dm_env.TimeStep:

self._reset_next_step = False
self._agents = self.possible_agents[:]
self._discounts = {
agent: np.dtype("float32").type(1.0) for agent in self.agents
}

observe, info = self._environment.reset()
observations = self._create_observations(
observe, info, self._environment.dones
Expand Down Expand Up @@ -240,13 +239,13 @@ def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep:
# of legal actions must be converted to a legal actions mask.
def _convert_observations(
self,
observes: Dict[str, Tuple[np.array, np.ndarray]],
observes: Dict[str, Tuple[np.ndarray, np.ndarray]],
dones: Dict[str, bool],
) -> Observation:
return convert_dm_compatible_observations(
observes,
observes, # type: ignore
dones,
self.observation_spec(),
self.observation_spaces, # type:ignore
self.env_done(),
self.possible_agents,
)
Expand All @@ -255,17 +254,17 @@ def _convert_observations(
# to be a tuple of the observation from the env and the agent info
def _collate_obs_and_info(
self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]]
) -> Dict[str, Tuple[np.array, np.ndarray]]:
observations: Dict[str, Tuple[np.array, np.ndarray]] = {}
) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
observations: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
observes = self.preprocessor(observes)
for agent, obs in observes.items():
agent_id = get_agent_id(agent)
agent_info = np.array(
[info[k][agent] for k in sort_str_num(info.keys())],
dtype=np.float32,
)
obs = (obs, agent_info) if self._include_agent_info else obs
observations[agent_id] = obs
obs = (obs, agent_info) if self._include_agent_info else obs # type: ignore # noqa: E501
observations[agent_id] = obs # type: ignore

return observations

Expand Down Expand Up @@ -324,7 +323,7 @@ def _pre_step(self) -> None:
def observation_spec(self) -> Dict[str, OLT]:
"""Return observation spec."""
observation_specs = {}
for agent in self.possible_agents:
for agent in self.agents:
observation_specs[agent] = OLT(
observation=tuple(
(
Expand Down Expand Up @@ -475,6 +474,7 @@ def min_gt(seq: Sequence, val: Any) -> Any:
idx -= 1
return min

@typing.no_type_check
def norm_obs_clip(
obs: np.ndarray,
clip_min: int = -1,
Expand Down Expand Up @@ -529,6 +529,7 @@ def _split_node_into_feature_groups(

return data, distance, agent_data

@typing.no_type_check
def _split_subtree_into_feature_groups(
node: Node, current_tree_depth: int, max_tree_depth: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -569,9 +570,9 @@ def split_tree_into_feature_groups(
sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(
tree.childs[direction], 1, max_tree_depth
)
data = np.concatenate((data, sub_data))
distance = np.concatenate((distance, sub_distance))
agent_data = np.concatenate((agent_data, sub_agent_data))
data = np.concatenate((data, sub_data)) # type: ignore
distance = np.concatenate((distance, sub_distance)) # type: ignore
agent_data = np.concatenate((agent_data, sub_agent_data)) # type: ignore

return data, distance, agent_data

Expand All @@ -592,7 +593,9 @@ def normalize_observation(
distance = norm_obs_clip(distance, normalize_to_range=True)
agent_data = np.clip(agent_data, -1, 1)
normalized_obs = np.array(
np.concatenate((np.concatenate((data, distance)), agent_data)),
np.concatenate(
(np.concatenate((data, distance)), agent_data)
), # type:ignore
dtype=np.float32,
)
return normalized_obs

0 comments on commit 89ace75

Please sign in to comment.