Skip to content

Commit

Permalink
Fix error where the get variables were not being pulled into the trainer.
Browse files Browse the repository at this point in the history
  • Loading branch information
DriesSmit committed Oct 1, 2021
1 parent d1a51df commit 2a82347
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 34 deletions.
3 changes: 0 additions & 3 deletions mava/systems/tf/maddpg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,6 @@ def make_trainer(
get_keys.append(f"{net_key}_{net_type_key}")

variables = self.create_counter_variables(variables)
num_steps = variables["trainer_steps"]

count_names = [
"trainer_steps",
"trainer_walltime",
Expand Down Expand Up @@ -583,7 +581,6 @@ def make_trainer(
"variable_client": variable_client,
"dataset": dataset,
"counts": counts,
"num_steps": num_steps,
"logger": logger,
}
if connection_spec:
Expand Down
29 changes: 4 additions & 25 deletions mava/systems/tf/maddpg/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: tf.Variable,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand Down Expand Up @@ -98,8 +97,6 @@ def __init__(
network.
variable_client: The client used to manage the variables.
counts: step counter object.
num_steps: Use to track the number of steps before the target networks
are updated.
agent_net_keys: specifies what network each agent uses.
max_gradient_norm: maximum allowed norm for gradients
before clipping is applied.
Expand Down Expand Up @@ -142,7 +139,7 @@ def __init__(
self._max_gradient_norm = tf.convert_to_tensor(1e10)

# Necessary to track when to update target networks.
self._num_steps = num_steps
self._num_steps = 0
self._target_averaging = target_averaging
self._target_update_period = target_update_period
self._target_update_rate = target_update_rate
Expand Down Expand Up @@ -220,7 +217,7 @@ def _update_target_networks(self) -> None:
if tf.math.mod(self._num_steps, self._target_update_period) == 0:
for src, dest in zip(online_variables, target_variables):
dest.assign(src)
self._num_steps.assign_add(1)
self._num_steps += 1

def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]:
"""Depreciated method."""
Expand Down Expand Up @@ -534,7 +531,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: int,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand All @@ -561,7 +557,6 @@ def __init__(
logger=logger,
variable_client=variable_client,
counts=counts,
num_steps=num_steps,
)


Expand All @@ -587,7 +582,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: int,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand All @@ -614,7 +608,6 @@ def __init__(
logger=logger,
variable_client=variable_client,
counts=counts,
num_steps=num_steps,
)

def _get_critic_feed(
Expand Down Expand Up @@ -680,7 +673,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: int,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand All @@ -707,7 +699,6 @@ def __init__(
logger=logger,
variable_client=variable_client,
counts=counts,
num_steps=num_steps,
)
self._connection_spec = connection_spec

Expand Down Expand Up @@ -790,7 +781,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: int,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand All @@ -817,7 +807,6 @@ def __init__(
logger=logger,
variable_client=variable_client,
counts=counts,
num_steps=num_steps,
)

def _get_critic_feed(
Expand Down Expand Up @@ -885,7 +874,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: tf.Variable,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand Down Expand Up @@ -918,8 +906,6 @@ def __init__(
network.
variable_client: The client used to manage the variables.
counts: step counter object.
num_steps: Use to track the number of steps before the target networks
are updated.
agent_net_keys: specifies what network each agent uses.
max_gradient_norm: maximum allowed norm for gradients
before clipping is applied.
Expand Down Expand Up @@ -963,7 +949,7 @@ def __init__(
self._max_gradient_norm = tf.convert_to_tensor(1e10)

# Necessary to track when to update target networks.
self._num_steps = num_steps
self._num_steps = 0
self._target_averaging = target_averaging
self._target_update_period = target_update_period
self._target_update_rate = target_update_rate
Expand Down Expand Up @@ -1042,7 +1028,7 @@ def _update_target_networks(self) -> None:
if tf.math.mod(self._num_steps, self._target_update_period) == 0:
for src, dest in zip(online_variables, target_variables):
dest.assign(src)
self._num_steps.assign_add(1)
self._num_steps += 1

def _transform_observations(
self, observations: Dict[str, np.ndarray]
Expand Down Expand Up @@ -1177,7 +1163,6 @@ def _step(
Returns:
losses
"""

# Update the target networks
self._update_target_networks()

Expand Down Expand Up @@ -1454,7 +1439,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: tf.Variable,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand Down Expand Up @@ -1482,7 +1466,6 @@ def __init__(
logger=logger,
variable_client=variable_client,
counts=counts,
num_steps=num_steps,
bootstrap_n=bootstrap_n,
)

Expand Down Expand Up @@ -1512,7 +1495,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: tf.Variable,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand Down Expand Up @@ -1540,7 +1522,6 @@ def __init__(
logger=logger,
variable_client=variable_client,
counts=counts,
num_steps=num_steps,
bootstrap_n=bootstrap_n,
)

Expand Down Expand Up @@ -1610,7 +1591,6 @@ def __init__(
target_observation_networks: Dict[str, snt.Module],
variable_client: VariableClient,
counts: Dict[str, Any],
num_steps: tf.Variable,
agent_net_keys: Dict[str, str],
max_gradient_norm: float = None,
logger: loggers.Logger = None,
Expand Down Expand Up @@ -1638,7 +1618,6 @@ def __init__(
logger=logger,
variable_client=variable_client,
counts=counts,
num_steps=num_steps,
bootstrap_n=bootstrap_n,
)

Expand Down
3 changes: 1 addition & 2 deletions mava/systems/tf/variable_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _adjust_and_request(self) -> None:
self._set_keys,
tf2_utils.to_numpy({key: self._variables[key] for key in self._set_keys}),
)
self._client.get_variables(self._get_keys)
self._copy(self._client.get_variables(self._get_keys))

def get_async(self) -> None:
"""Asynchronously updates the get variables with the latest copy from source."""
Expand Down Expand Up @@ -129,7 +129,6 @@ def set_and_get_async(self) -> None:
# Track the number of calls (we only update periodically).
if self._set_get_call_counter < self._update_period:
self._set_get_call_counter += 1

period_reached: bool = self._set_get_call_counter >= self._update_period

if period_reached and self._set_get_future is None: # type: ignore
Expand Down
6 changes: 2 additions & 4 deletions mava/wrappers/system_trainer_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def __init__(
def step(self) -> None:
# Run the learning step.
fetches = self._step()

if self._require_loggers:
self._create_loggers(list(fetches.keys()))
self._require_loggers = False
Expand All @@ -174,13 +173,12 @@ def step(self) -> None:
self._timestamp: float = timestamp

# Update our counts and record it.
print("self._counts: ", self._counts)
exit()
self._variable_client.add_async(
["trainer_steps", "trainer_walltime"],
{"trainer_steps": 1, "trainer_walltime": elapsed_time},
)
# Update the variable source and the trainer

# Set and get the latest variables
self._variable_client.set_and_get_async()

fetches.update(self._counts)
Expand Down

0 comments on commit 2a82347

Please sign in to comment.