# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Any, Callable, Dict, Generic, Optional, Protocol, Tuple, TypeVar, Union
import chex
import jumanji.specs as specs
from flax.core.frozen_dict import FrozenDict
from jumanji.types import TimeStep
from tensorflow_probability.substrates.jax.distributions import Distribution
from typing_extensions import NamedTuple, TypeAlias
# Aliases for the core array-valued quantities passed through Mava systems.
Action: TypeAlias = chex.Array  # actions selected by the agents
Value: TypeAlias = chex.Array  # critic value estimates
Done: TypeAlias = chex.Array  # episode-termination flags
HiddenState: TypeAlias = chex.Array  # recurrent-network hidden state
# Can't know the exact type of State.
State: TypeAlias = Any
Metrics: TypeAlias = Dict[str, chex.Array]  # metric-name -> metric-array mapping
class MarlEnv(Protocol):
    """The API used by mava for environments.

    A mava environment simply uses the Jumanji env API with a few added attributes.
    For examples of how to add custom environments to Mava see `mava/wrappers/jumanji.py`.
    Jumanji API docs: https://instadeepai.github.io/jumanji/#basic-usage
    """

    num_agents: int  # number of agents acting in the environment
    time_limit: int  # maximum number of steps per episode
    action_dim: int  # size of each agent's action space

    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
        """Resets the environment to an initial state.

        Args:
            key: random key used to reset the environment.

        Returns:
            state: State object corresponding to the new state of the environment,
            timestep: TimeStep object corresponding the first timestep returned by the environment,
        """
        ...

    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
        """Run one timestep of the environment's dynamics.

        Args:
            state: State object containing the dynamics of the environment.
            action: Array containing the action to take.

        Returns:
            state: State object corresponding to the next state of the environment,
            timestep: TimeStep object corresponding the timestep returned by the environment,
        """
        ...

    # NOTE: specs are declared as `cached_property` to match Jumanji environments,
    # where spec construction can be expensive and is computed once per instance.
    @cached_property
    def observation_spec(self) -> specs.Spec:
        """Returns the observation spec.

        Returns:
            observation_spec: a NestedSpec tree of spec.
        """
        ...

    @cached_property
    def action_spec(self) -> specs.Spec:
        """Returns the action spec.

        Returns:
            action_spec: a NestedSpec tree of spec.
        """
        ...

    @cached_property
    def reward_spec(self) -> specs.Array:
        """Describes the reward returned by the environment. By default, this is assumed to be a
        single float.

        Returns:
            reward_spec: a `specs.Array` spec.
        """
        ...

    @cached_property
    def discount_spec(self) -> specs.BoundedArray:
        """Describes the discount returned by the environment. By default, this is assumed to be a
        single float between 0 and 1.

        Returns:
            discount_spec: a `specs.BoundedArray` spec.
        """
        ...

    @property
    def unwrapped(self) -> Any:
        """Returns: the innermost environment (without any wrappers applied)."""
        ...
class Observation(NamedTuple):
    """The observation that the agent sees.

    agents_view: the agent's view of the environment.
    action_mask: boolean array specifying, for each agent, which action is legal.
    step_count: the number of steps elapsed since the beginning of the episode.
    """

    agents_view: chex.Array  # (num_agents, num_obs_features)
    action_mask: chex.Array  # (num_agents, num_actions)
    # Optional so environments that do not track step counts can omit it.
    step_count: Optional[chex.Array] = None  # (num_agents, )
class ObservationGlobalState(NamedTuple):
    """The observation seen by agents in centralised systems.

    Extends `Observation` by adding a `global_state` attribute for centralised training.
    global_state: The global state of the environment, often a concatenation of agents' views.
    """

    agents_view: chex.Array  # (num_agents, num_obs_features)
    action_mask: chex.Array  # (num_agents, num_actions)
    global_state: chex.Array  # (num_agents, num_agents * num_obs_features)
    # Optional so environments that do not track step counts can omit it.
    step_count: Optional[chex.Array] = None  # (num_agents, )
# Observation paired with done flags, as consumed by recurrent (RNN) networks.
RNNObservation: TypeAlias = Tuple[Observation, Done]
# As above, but for centralised systems that also see the global state.
RNNGlobalObservation: TypeAlias = Tuple[ObservationGlobalState, Done]
# Either observation container accepted by Mava components.
MavaObservation: TypeAlias = Union[Observation, ObservationGlobalState]
# `MavaState` is the main type passed around in our systems. It is often used as a scan carry.
# Types like: `LearnerState` (mava/systems/<system_name>/types.py) are `MavaState`s.
MavaState = TypeVar("MavaState")
MavaTransition = TypeVar("MavaTransition")
class ExperimentOutput(NamedTuple, Generic[MavaState]):
    """Experiment output.

    learner_state: the final `MavaState` carried out of the learner function.
    episode_metrics: metrics gathered from environment episodes.
    train_metrics: metrics produced by the training step.
    """

    learner_state: MavaState
    episode_metrics: Metrics
    train_metrics: Metrics
# Function-type aliases for the learner and network-apply functions used by systems.
LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]]  # anakin-style learner
SebulbaLearnerFn = Callable[[MavaState, MavaTransition], Tuple[MavaState, Metrics]]  # sebulba-style
ActorApply = Callable[[FrozenDict, Observation], Distribution]  # params, obs -> action distribution
CriticApply = Callable[[FrozenDict, Observation], Value]  # params, obs -> value estimate
# Recurrent variants additionally thread a hidden state through each call.
RecActorApply = Callable[
    [FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Distribution]
]
RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]]