Commit
merge: Merge remote-tracking branch 'origin/feature/eval-intervals' into feature/smac-env-upgrades
KaleabTessera committed Nov 19, 2021
2 parents 98fae98 + 52cd450 commit 98f7e99
Showing 52 changed files with 494 additions and 144 deletions.
2 changes: 2 additions & 0 deletions Dockerfile.tf
@@ -16,6 +16,8 @@ ENV TF_CPP_MIN_LOG_LEVEL=3
COPY . /home/app/mava
RUN python -m pip uninstall -y enum34
RUN python -m pip install --upgrade pip
# For box2d
RUN apt-get install swig -y
## Install core dependencies.
RUN python -m pip install -e .[tf,reverb,launchpad]
## Optional install for screen recording.
2 changes: 1 addition & 1 deletion examples/README.md
@@ -16,7 +16,7 @@ We include a number of systems running on continuous control tasks.
a MADDPG system running on the continuous action space simple_spread MPE environment.
- *Feedforward*:
- [decentralised](debugging/simple_spread/feedforward/decentralised/run_maddpg.py), [decentralised record agents](debugging/simple_spread/feedforward/decentralised/run_maddpg_record.py) (***Example recording agents acting in the environment***), [decentralised scaling](debugging/simple_spread/feedforward/decentralised/run_maddpg_scaling.py) (***Example scaling to 4 executors***), [decentralised custom loggers](debugging/simple_spread/feedforward/decentralised/run_maddpg_custom_logging.py) (***Example using custom logging***), [decentralised lr scheduling](debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py) (***Example using lr schedule***),
[centralised](debugging/simple_spread/feedforward/centralised/run_maddpg.py), [networked](debugging/simple_spread/feedforward/networked/run_maddpg.py) (***Example using a fully-connected, networked architecture***), [networked with custom architecture](debugging/simple_spread/feedforward/networked/run_maddpg_custom_network.py) (***Example using a custom, sparse, networked architecture***) and [state_based](debugging/simple_spread/feedforward/state_based/run_maddpg.py) .
[centralised](debugging/simple_spread/feedforward/centralised/run_maddpg.py), [networked](debugging/simple_spread/feedforward/networked/run_maddpg.py) (***Example using a fully-connected, networked architecture***), [networked with custom architecture](debugging/simple_spread/feedforward/networked/run_maddpg_custom_network.py) (***Example using a custom, sparse, networked architecture***), [state_based](debugging/simple_spread/feedforward/state_based/run_maddpg.py) and [decentralised evaluator intervals](debugging/simple_spread/feedforward/decentralised/run_mad4pg_evaluator_interval.py) (***Example running the evaluation loop at intervals***)
- *Recurrent*
- [decentralised](debugging/simple_spread/recurrent/decentralised/run_maddpg.py) and [state_based](debugging/simple_spread/recurrent/state_based/run_maddpg.py).

108 changes: 108 additions & 0 deletions examples/debugging/simple_spread/feedforward/decentralised/run_mad4pg_evaluator_interval.py
@@ -0,0 +1,108 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running MAD4PG on debug MPE environments, using an evaluation schedule."""
import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.systems.tf import mad4pg
from mava.utils import lp_utils
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS
flags.DEFINE_string(
"env_name",
"simple_spread",
"Debugging environment name (str).",
)
flags.DEFINE_string(
"action_space",
"continuous",
"Environment action space type (str).",
)
flags.DEFINE_string(
"mava_id",
str(datetime.now()),
"Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.")


def main(_: Any) -> None:

# Environment.
environment_factory = functools.partial(
debugging_utils.make_environment,
env_name=FLAGS.env_name,
action_space=FLAGS.action_space,
)

# Networks.
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
vmin=-10,
vmax=50,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

# Log every [log_every] seconds.
log_every = 10
logger_factory = functools.partial(
logger_utils.make_logger,
directory=FLAGS.base_dir,
to_terminal=True,
to_tensorboard=True,
time_stamp=FLAGS.mava_id,
time_delta=log_every,
)

# Distributed program.
program = mad4pg.MAD4PG(
environment_factory=environment_factory,
network_factory=network_factory,
logger_factory=logger_factory,
num_executors=1,
policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
checkpoint_subpath=checkpoint_dir,
max_gradient_norm=40.0,
# Run evaluation loop every 100 executor_steps.
evaluator_interval={"executor_steps": 100},
).build()

# Ensure only trainer runs on gpu, while other processes run on cpu.
local_resources = lp_utils.to_device(
program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
)

# Launch.
lp.launch(
program,
lp.LaunchType.LOCAL_MULTI_PROCESSING,
terminal="current_terminal",
local_resources=local_resources,
)


if __name__ == "__main__":
app.run(main)
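
(Note, not part of the commit: the only substantive difference between this example and the plain MAD4PG debugging example is the evaluator_interval argument. Below is a minimal sketch of varying the schedule, reusing the factories defined above; whether counter keys other than "executor_steps", such as "executor_episodes", are accepted by the interval check is an assumption.)

# Hypothetical variation, assuming other recorded counter keys are valid interval keys.
program = mad4pg.MAD4PG(
    environment_factory=environment_factory,
    network_factory=network_factory,
    logger_factory=logger_factory,
    num_executors=1,
    checkpoint_subpath=checkpoint_dir,
    # Evaluate roughly every 50 executor episodes instead of every 100 executor steps.
    evaluator_interval={"executor_episodes": 50},
).build()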
6 changes: 3 additions & 3 deletions mava/adders/reverb/transition.py
@@ -131,15 +131,15 @@ def __init__(

def _write(self) -> None:
# Convenient getters for use in tree operations.
def get_first(x: np.array) -> np.array:
def get_first(x: np.ndarray) -> np.ndarray:
return x[self._first_idx]

def get_last(x: np.array) -> np.array:
def get_last(x: np.ndarray) -> np.ndarray:
return x[self._last_idx]

# Note: this getter is meant to be used on a TrajectoryWriter.history to
# obtain its numpy values.
def get_all_np(x: np.array) -> np.array:
def get_all_np(x: np.ndarray) -> np.ndarray:
return x[self._first_idx : self._last_idx].numpy()

# Get the state, action, next_state, as well as possibly extras for the
118 changes: 87 additions & 31 deletions mava/environment_loop.py
@@ -25,6 +25,7 @@

import mava
from mava.types import Action
from mava.utils.training_utils import check_count_condition
from mava.utils.wrapper_utils import (
SeqTimestepDict,
convert_seq_timestep_and_actions_to_parallel,
@@ -67,10 +68,10 @@ def __init__(
self.num_agents = self._environment.num_agents

# keeps track of previous actions and timesteps
self._prev_action: Dict[str, Action] = {
self._prev_action: Dict[str, Optional[Action]] = {
a: None for a in self._environment.possible_agents
}
self._prev_timestep: Dict[str, dm_env.TimeStep] = {
self._prev_timestep: Dict[str, Optional[dm_env.TimeStep]] = {
a: None for a in self._environment.possible_agents
}
self._agent_action_timestep: Dict[str, Tuple[Action, dm_env.TimeStep]] = {}
@@ -145,7 +146,10 @@ def _perform_turn(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:

# save action, timestep pairs for current agent
timestep = self._set_step_type(timestep, self._step_type[agent])
self._agent_action_timestep[agent] = (self._prev_action[agent], timestep)
self._agent_action_timestep[agent] = (
self._prev_action[agent],
timestep,
) # type: ignore

self._prev_timestep[agent] = timestep

@@ -176,7 +180,10 @@ def _collect_last_timesteps(self, timestep: dm_env.TimeStep) -> None:
for _ in range(self.num_agents):
agent = self._environment.current_agent
timestep = self._set_step_type(timestep, dm_env.StepType.LAST)
self._agent_action_timestep[agent] = (self._prev_action[agent], timestep)
self._agent_action_timestep[agent] = (
self._prev_action[agent],
timestep,
) # type: ignore

timestep = self._environment.step(
generate_zeros_from_spec(self._environment.action_spec()[agent])
@@ -321,6 +328,9 @@ def __init__(
self._should_update = should_update
self._running_statistics: Dict[str, float] = {}

# We need this to schedule evaluation/test runs
self._last_evaluator_run_t = -1

def _get_actions(self, timestep: dm_env.TimeStep) -> Any:
return self._executor.select_actions(timestep.observation)

@@ -338,6 +348,36 @@ def _compute_episode_statistics(
) -> None:
pass

def get_counts(self) -> Any:
if hasattr(self._executor, "_counts"):
counts = self._executor._counts
else:
counts = self._counter.get_counts()
return counts

def record_counts(self, episode_steps: int) -> Any:
# Record counts.
if hasattr(self._executor, "_counts"):
loop_type = "evaluator" if self._executor._evaluator else "executor"

if hasattr(self._executor, "_variable_client"):
self._executor._variable_client.add_async(
[f"{loop_type}_episodes", f"{loop_type}_steps"],
{
f"{loop_type}_episodes": 1,
f"{loop_type}_steps": episode_steps,
},
)
else:
self._executor._counts[f"{loop_type}_episodes"] += 1
self._executor._counts[f"{loop_type}_steps"] += episode_steps

counts = self._executor._counts
else:
counts = self._counter.increment(episodes=1, steps=episode_steps)

return counts

def run_episode(self) -> loggers.LoggingData:
"""Run one episode.
Each episode is a loop which interacts first with the environment to get a
@@ -429,27 +469,8 @@ def run_episode(self) -> loggers.LoggingData:
if self._get_running_stats():
return self._get_running_stats()
else:
# Record counts.
if hasattr(self._executor, "_counts"):
loop_type = "executor"
if "_" not in self._loop_label:
loop_type = "evaluator"

if hasattr(self._executor, "_variable_client"):
self._executor._variable_client.add_async(
[f"{loop_type}_episodes", f"{loop_type}_steps"],
{
f"{loop_type}_episodes": 1,
f"{loop_type}_steps": episode_steps,
},
)
else:
self._executor._counts[f"{loop_type}_episodes"] += 1
self._executor._counts[f"{loop_type}_steps"] += episode_steps

counts = self._executor._counts
else:
counts = self._counter.increment(episodes=1, steps=episode_steps)

counts = self.record_counts(episode_steps)

# Collect the results and combine with counts.
steps_per_second = episode_steps / (time.time() - start_time)
@@ -459,7 +480,6 @@
"steps_per_second": steps_per_second,
}
result.update(counts)

return result

def run(
@@ -488,10 +508,46 @@ def should_terminate(episode_count: int, step_count: int) -> bool:
num_steps is not None and step_count >= num_steps
)

def should_run_loop(eval_condition: Tuple) -> bool:
"""Check if the eval loop should run at the current step.
Args:
eval_condition: tuple containing the interval key and count.
Returns:
a bool indicating if eval should run.
"""
should_run_loop = False
eval_interval_key, eval_interval_count = eval_condition
counts = self.get_counts()
if counts:
count = counts.get(eval_interval_key)
# We run eval loops around every eval_interval_count (not exactly every
# eval_interval_count due to latency in getting updated counts).
should_run_loop = (
(count - self._last_evaluator_run_t) / eval_interval_count
) >= 1.0
if should_run_loop:
self._last_evaluator_run_t = int(count)
return should_run_loop

episode_count, step_count = 0, 0

# Currently, we only use intervals for eval loops.
environment_loop_schedule = (
self._executor._evaluator and self._executor._interval
)
if environment_loop_schedule:
eval_condition = check_count_condition(self._executor._interval)

while not should_terminate(episode_count, step_count):
result = self.run_episode()
episode_count += 1
step_count += result["episode_length"]
# Log the given results.
self._logger.write(result)
if (not environment_loop_schedule) or should_run_loop(eval_condition):
result = self.run_episode()
episode_count += 1
step_count += result["episode_length"]
# Log the given results.
self._logger.write(result)

# We need to get the latest counts if we are using eval intervals.
if environment_loop_schedule:
self._executor.update()
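
(Note, not part of the diff: check_count_condition, imported at the top of this file, is not shown in the hunks above. A rough, hypothetical reconstruction, inferred only from how should_run_loop unpacks its result, is sketched below; the counts it is compared against are the executor/evaluator episode and step counters recorded by record_counts.)

from typing import Dict, Optional, Tuple


def check_count_condition(condition: Optional[Dict]) -> Tuple:
    """Hypothetical sketch: validate an interval dict such as
    {"executor_steps": 100} and return its single (key, count) pair."""
    assert condition is not None and len(condition) == 1
    key, count = list(condition.items())[0]
    assert isinstance(count, int) and count > 0
    return key, count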
4 changes: 2 additions & 2 deletions mava/environment_loops/debugging_envs.py
@@ -23,8 +23,8 @@


def get_good_simple_spread_action(
agent_id: int, obs: np.array, environment: DebuggingEnvWrapper
) -> Union[int, np.array]:
agent_id: int, obs: np.ndarray, environment: DebuggingEnvWrapper
) -> Union[int, np.ndarray]:
import gym

diff = np.array(obs[5:7])
1 change: 1 addition & 0 deletions mava/environment_loops/open_spiel_environment_loop.py
@@ -10,6 +10,7 @@

class OpenSpielSequentialEnvironmentLoop(SequentialEnvironmentLoop):
"""A Sequential MARL environment loop.
This takes `Environment` and `Executor` instances and coordinates their
interaction. Executors are updated if `should_update=True`. This can be used as:
loop = EnvironmentLoop(environment, executor)
5 changes: 5 additions & 0 deletions mava/systems/tf/dial/execution.py
@@ -46,6 +46,7 @@ def __init__(
trainer: MADQNTrainer = None,
fingerprint: bool = False,
evaluator: bool = False,
interval: Optional[dict] = None,
):
"""Initialise the system executor
@@ -69,6 +70,7 @@ def __init__(
stabilise experience replay. Defaults to False.
evaluator (bool, optional): whether the executor will be used for
evaluation. Defaults to False.
interval: the interval at which evaluations are run.
"""

# Store these for later use.
Expand All @@ -85,6 +87,9 @@ def __init__(
self._states: Dict[str, Any] = {}
self._messages: Dict[str, Any] = {}

self._evaluator = evaluator
self._interval = interval

@tf.function
def _policy(
self,