forked from instadeepai/Mava
-
Notifications
You must be signed in to change notification settings - Fork 0
/
specs.py
78 lines (68 loc) · 2.95 KB
/
specs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Objects which specify the input/output spaces of an environment from the perspective
of each agent in a multi-agent environment.
This module provides `MAEnvironmentSpec`, which collects an `EnvironmentSpec`
(imported from `acme.specs`) for every agent of a given environment, together
with any extra environment specs. An `MAEnvironmentSpec` instance can be built
from a `dm_env.Environment` instance, or directly from a precomputed dict of
per-agent specs.
"""
from typing import Dict, List, Optional

import dm_env
from acme.specs import EnvironmentSpec
# TODO: Why define specs with this class when the specs could simply be
# updated on the wrappers themselves?
class MAEnvironmentSpec:
    """Collection of per-agent `EnvironmentSpec`s for a multi-agent environment.

    Agent ids are treated as ``"<type>_<suffix>"`` strings: everything before
    the first underscore is the agent's type (e.g. ``"agent_0"`` -> ``"agent"``).
    """

    def __init__(
        self,
        environment: dm_env.Environment,
        specs: Optional[Dict[str, EnvironmentSpec]] = None,
        extra_specs: Optional[Dict] = None,
    ):
        """Build the spec collection.

        Args:
            environment: environment to derive specs from. Only used when
                `specs` is None or empty.
            specs: optional precomputed mapping from agent id to that agent's
                `EnvironmentSpec`. When None or empty, specs are derived
                from `environment` instead.
            extra_specs: extra environment specs. Only stored when `specs` is
                provided; otherwise `environment.extra_spec()` is used.
        """
        if not specs:
            # Also sets self.extra_specs from environment.extra_spec().
            specs = self._make_ma_environment_spec(environment)
        else:
            self.extra_specs = extra_specs
        # Keep agent ids sorted so every derived view has a deterministic order.
        self._keys = sorted(specs.keys())
        self._specs = {key: specs[key] for key in self._keys}

    def _make_ma_environment_spec(
        self, environment: dm_env.Environment
    ) -> Dict[str, EnvironmentSpec]:
        """Return an `EnvironmentSpec` for each agent in `environment.possible_agents`.

        `environment` must provide `observation_spec()`, `action_spec()`,
        `reward_spec()` and `discount_spec()`, each indexable by agent id.
        It must also provide `extra_spec()` and `possible_agents`.

        Side effect: stores `environment.extra_spec()` on `self.extra_specs`.
        """
        observation_specs = environment.observation_spec()
        action_specs = environment.action_spec()
        reward_specs = environment.reward_spec()
        discount_specs = environment.discount_spec()
        self.extra_specs = environment.extra_spec()
        return {
            agent: EnvironmentSpec(
                observations=observation_specs[agent],
                actions=action_specs[agent],
                rewards=reward_specs[agent],
                discounts=discount_specs[agent],
            )
            for agent in environment.possible_agents
        }

    @staticmethod
    def _agent_type(agent_id: str) -> str:
        """Return the type of `agent_id`: the text before its first underscore."""
        return agent_id.split("_")[0]

    def get_extra_specs(self) -> Dict[str, EnvironmentSpec]:
        """Return the extra environment specs stored at construction time."""
        return self.extra_specs  # type: ignore

    def get_agent_specs(self) -> Dict[str, EnvironmentSpec]:
        """Return the mapping from agent id to `EnvironmentSpec`, in sorted id order."""
        return self._specs

    def get_agent_type_specs(self) -> Dict[str, EnvironmentSpec]:
        """Return one representative `EnvironmentSpec` per agent type.

        The representative is agent ``"<type>_0"`` when it exists. Otherwise
        it is the first agent id (in sorted order) of that type, rather than
        raising `KeyError`. Only one agent's spec is returned per type, so
        this is only meaningful if agents of the same type share specs.
        """
        specs = {}
        for agent_type, agents in self.get_agents_by_type().items():
            representative = f"{agent_type}_0"
            if representative not in self._specs:
                # Ids of this type do not start at 0; fall back to the first one.
                representative = agents[0]
            specs[agent_type] = self._specs[representative]
        return specs

    def get_agent_ids(self) -> List[str]:
        """Return all agent ids in sorted order."""
        return self._keys

    def get_agent_types(self) -> List[str]:
        """Return the distinct agent types in sorted (deterministic) order."""
        return sorted({self._agent_type(agent) for agent in self._keys})

    def get_agents_by_type(self) -> Dict[str, List[str]]:
        """Group agent ids by their exact agent type.

        Returns:
            Mapping from each type in `get_agent_types()` to the ids of the
            agents of exactly that type, in sorted id order. Types are
            matched exactly, not by substring, so ``"tanker_0"`` is never
            listed under ``"tank"``.
        """
        agents_by_type: Dict[str, List[str]] = {
            agent_type: [] for agent_type in self.get_agent_types()
        }
        for agent in self.get_agent_ids():
            agents_by_type[self._agent_type(agent)].append(agent)
        return agents_by_type