Skip to content

Commit

Permalink
feat: Set executor update period to 0, when using eval intervals.
Browse files Browse the repository at this point in the history
  • Loading branch information
KaleabTessera committed Nov 29, 2021
1 parent ef42512 commit 643b9c6
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 13 deletions.
4 changes: 4 additions & 0 deletions mava/environment_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,10 @@ def should_run_loop(eval_condtion: Tuple) -> bool:
) >= 1.0
if should_run_loop:
self._last_evaluator_run_t = int(count)
print(
"Running eval loop at executor step: "
+ f"{self._last_evaluator_run_t}"
)
return should_run_loop

episode_count, step_count = 0, 0
Expand Down
9 changes: 7 additions & 2 deletions mava/systems/tf/dial/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def make_executor(
"""

agent_net_keys = self._config.agent_net_keys
evaluator_interval = self._config.evaluator_interval if evaluator else None

variable_client = None
if variable_source:
Expand All @@ -308,7 +309,11 @@ def make_executor(
variable_client = variable_utils.VariableClient(
client=variable_source,
variables={"q_network": variables},
update_period=self._config.executor_variable_update_period,
# If we are using evaluator_intervals,
# we should always get the latest variables.
update_period=0
if evaluator_interval
else self._config.executor_variable_update_period,
)

# Make sure not to use a random policy after checkpoint restoration by
Expand All @@ -329,7 +334,7 @@ def make_executor(
communication_module=communication_module,
evaluator=evaluator,
fingerprint=fingerprint,
interval=self._config.evaluator_interval if evaluator else None,
interval=evaluator_interval,
)

def make_trainer(
Expand Down
9 changes: 7 additions & 2 deletions mava/systems/tf/maddpg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,14 +462,19 @@ def make_executor(
get_keys.extend(count_names)
counts = {name: variables[name] for name in count_names}

evaluator_interval = self._config.evaluator_interval if evaluator else None
variable_client = None
if variable_source:
# Get new policy variables
variable_client = variable_utils.VariableClient(
client=variable_source,
variables=variables,
get_keys=get_keys,
update_period=self._config.executor_variable_update_period,
# If we are using evaluator_intervals,
# we should always get the latest variables.
update_period=0
if evaluator_interval
else self._config.executor_variable_update_period,
)

# Make sure not to use a random policy after checkpoint restoration by
Expand All @@ -487,7 +492,7 @@ def make_executor(
variable_client=variable_client,
adder=adder,
evaluator=evaluator,
interval=self._config.evaluator_interval if evaluator else None,
interval=evaluator_interval,
)

def make_trainer(
Expand Down
10 changes: 7 additions & 3 deletions mava/systems/tf/madqn/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def make_executor(
"""

agent_net_keys = self._config.agent_net_keys

evaluator_interval = self._config.evaluator_interval if evaluator else None
variable_client = None
if variable_source:
# Create policy variables
Expand All @@ -319,7 +319,11 @@ def make_executor(
variable_client = variable_utils.VariableClient(
client=variable_source,
variables={"q_network": variables},
update_period=self._config.executor_variable_update_period,
# If we are using evaluator_intervals,
# we should always get the latest variables.
update_period=0
if evaluator_interval
else self._config.executor_variable_update_period,
)

# Make sure not to use a random policy after checkpoint restoration by
Expand All @@ -340,7 +344,7 @@ def make_executor(
communication_module=communication_module,
evaluator=evaluator,
fingerprint=fingerprint,
interval=self._config.evaluator_interval if evaluator else None,
interval=evaluator_interval,
)

def make_trainer(
Expand Down
9 changes: 7 additions & 2 deletions mava/systems/tf/mappo/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def make_executor(
"""

variable_client = None
evaluator_interval = self._config.evaluator_interval if evaluator else None
if variable_source:
# Create policy variables.
variables = {
Expand All @@ -243,7 +244,11 @@ def make_executor(
variable_client = variable_utils.VariableClient(
client=variable_source,
variables={"policy": variables},
update_period=self._config.executor_variable_update_period,
# If we are using evaluator_intervals,
# we should always get the latest variables.
update_period=0
if evaluator_interval
else self._config.executor_variable_update_period,
)

# Make sure not to use a random policy after checkpoint restoration by
Expand All @@ -257,7 +262,7 @@ def make_executor(
variable_client=variable_client,
adder=adder,
evaluator=evaluator,
interval=self._config.evaluator_interval if evaluator else None,
interval=evaluator_interval,
)

def make_trainer(
Expand Down
4 changes: 2 additions & 2 deletions mava/systems/tf/variable_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from acme.tf import utils as tf2_utils

from mava.systems.tf import savers as tf2_savers
from mava.utils.training_utils import check_count_condition
from mava.utils.training_utils import check_count_condition, non_blocking_sleep


class VariableSource:
Expand Down Expand Up @@ -135,7 +135,7 @@ def run(self) -> None:
# Checkpoints every 5 minutes
while True:
# Wait 10 seconds before checking again
time.sleep(10)
non_blocking_sleep(10)

# Add 1 extra second just to make sure that the checkpointer
# is ready to save.
Expand Down
5 changes: 3 additions & 2 deletions mava/utils/lp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

import functools
import inspect
import time
from typing import Any, Callable, Dict, List, Optional

import launchpad as lp
from absl import flags, logging
from acme.utils import counting
from launchpad.nodes.python.local_multi_processing import PythonProcess

from mava.utils.training_utils import non_blocking_sleep

FLAGS = flags.FLAGS


Expand Down Expand Up @@ -119,4 +120,4 @@ def run(self) -> None:
lp.stop()

# Don't spam the counter.
time.sleep(10.0)
non_blocking_sleep(10)
13 changes: 13 additions & 0 deletions mava/utils/training_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import os
import time
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import sonnet as snt
import tensorflow as tf
import trfl


def non_blocking_sleep(time_in_seconds: int) -> None:
"""Function to sleep for time_in_seconds, without hanging lp program.
Args:
time_in_seconds : number of seconds to sleep for.
"""
for _ in range(time_in_seconds):
# Do not sleep for a long period of time to avoid LaunchPad program
# termination hangs (time.sleep is not interruptible).
time.sleep(1)


def check_count_condition(condition: Optional[dict]) -> Tuple:
"""Checks if condition is valid. These conditions are used for termination
or to run evaluators in intervals.
Expand Down

0 comments on commit 643b9c6

Please sign in to comment.