Feature/Multiple trainers for MA-DDPG #253

Merged
merged 164 commits into from
Oct 25, 2021
f543f5f
feature: Save work on seperate variable source.
DriesSmit Jun 15, 2021
5333e74
fix: A general inter-node variable communicator module is now impleme…
DriesSmit Jun 16, 2021
ab10d87
fix: Cleanup variable_utils and some other files.
DriesSmit Jun 16, 2021
a84946e
Merge remote-tracking branch 'origin' into feature/mava-scaling
DriesSmit Jun 16, 2021
329f95f
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jun 18, 2021
5d93108
Add scaled mad4pg example.
DriesSmit Jun 18, 2021
58926f5
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jun 24, 2021
0a1119b
feature: Save latest code.
DriesSmit Jun 25, 2021
e259dd1
feature: Save latest code.
DriesSmit Jun 25, 2021
e9d6a55
fix: A lot of bugfixes.
DriesSmit Jun 25, 2021
9f2f2cd
fix: A lot of bugfixes.
DriesSmit Jun 25, 2021
6648245
fix: Last save for today.
DriesSmit Jun 25, 2021
6cda7fc
fix: Fix code so that other algorithms can still run in their normal …
DriesSmit Jun 25, 2021
bc2a013
Fix some bugs in the debugging 3 trainer example.
DriesSmit Jun 28, 2021
acbfd4e
fix: Fix some more bugs.
DriesSmit Jun 28, 2021
9aa7834
Fix makefile.
DriesSmit Jun 28, 2021
a678bf3
fix: First attempt running.
DriesSmit Jun 28, 2021
e5037f9
fix: Fix bug where trainers did not update variable source.
DriesSmit Jun 29, 2021
ad76e16
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jun 29, 2021
ad3a184
fix: Small fix.
DriesSmit Jun 29, 2021
8c159b9
fix: Resolve merge conflict.
DriesSmit Jun 30, 2021
fb29faa
fix: Fix environment_loop for when no variable_client is presented.
DriesSmit Jun 30, 2021
0607dca
fix: Fix error where no count variable crashes the environment_loop.
DriesSmit Jun 30, 2021
efec9d4
fix: Small fix.
DriesSmit Jun 30, 2021
ce939e7
fix: Small fix.
DriesSmit Jun 30, 2021
2bf2e46
Fix: Small fix.
DriesSmit Jun 30, 2021
28cb0e9
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jun 30, 2021
8b24be2
Merge branch 'bugfix/logging-running-stats' into feature/mava-scaling
DriesSmit Jun 30, 2021
e9fc43e
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jun 30, 2021
7a44219
fix: Fix environment loop.
DriesSmit Jun 30, 2021
2be9993
fix: Remove comment.
DriesSmit Jun 30, 2021
e235a66
fix: Small fixes.
DriesSmit Jun 30, 2021
d318c8c
fix: Small fixes.
DriesSmit Jun 30, 2021
6832841
Start with network_key conversion
DriesSmit Jul 1, 2021
6d25b66
fix: Update scaled example networks.
DriesSmit Jul 1, 2021
bc87cdc
fix: Fix scaled maddpg system.
DriesSmit Jul 2, 2021
1a06216
Merge network-keys branch.
DriesSmit Jul 2, 2021
0263b54
fix: Net key bugfixes in scaled maddpg system.
DriesSmit Jul 2, 2021
c7f8d72
Start with setting up population based training example.
DriesSmit Jul 2, 2021
37fa724
Merge feature/network-keys.
DriesSmit Jul 2, 2021
d1e8d89
Merge feature/mava-scaling.
DriesSmit Jul 2, 2021
a49e84e
Merge branch feature/network-keys.
DriesSmit Jul 3, 2021
8b468e0
Merge branch 'feature/mava-scaling' into feature/population-based-tra…
DriesSmit Jul 3, 2021
07ead24
Small changes.
DriesSmit Jul 5, 2021
f09e2c1
Save changes.
DriesSmit Jul 5, 2021
bd4aafc
Merge branch 'feature/network-keys' into feature/population-based-tra…
DriesSmit Jul 5, 2021
9dcbb8f
feature: Update decentralised architecture to work for both methods o…
DriesSmit Jul 6, 2021
3d1c2dd
block: PBT is blocked for the moment as the reverb version must first…
DriesSmit Jul 7, 2021
0424702
Merge remote-tracking branch 'origin/develop' into feature/population…
DriesSmit Jul 7, 2021
9791529
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jul 7, 2021
ab1586f
Merge branch 'feature/network-keys' into feature/mava-scaling
DriesSmit Jul 7, 2021
fc63367
fix: Delete shared weights for network_factory in examples.
DriesSmit Jul 7, 2021
ee4f97d
Merge remote-tracking branch 'origin/feature/network-keys' into featu…
DriesSmit Jul 7, 2021
8c04e74
feature: Add better function descriptions.
DriesSmit Jul 7, 2021
7e82914
Merge feature/mava-scaling.
DriesSmit Jul 7, 2021
0ad3cc1
Merge acme adder update branch.
DriesSmit Jul 12, 2021
8ac2ab2
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jul 12, 2021
2ee0bf2
Merge feature/network-keys.
DriesSmit Jul 12, 2021
19aaa98
feature: Delete original maddpg and rename maddpg_scaled to maddpg. f…
DriesSmit Jul 13, 2021
c222bf9
fix: Add deleted README back again.
DriesSmit Jul 13, 2021
c12cead
Merge branch 'feature/network-keys' into feature/mava-scaling
DriesSmit Jul 13, 2021
bf12cdf
fix: Small fix to MAD4PG executor.
DriesSmit Jul 13, 2021
dbedde8
Merge feature/mava-scaling.
DriesSmit Jul 13, 2021
ac5dff3
Merge remote-tracking branch 'origin/feature/upgrade-acme-version' in…
DriesSmit Jul 13, 2021
6706068
fix: General updates to adders for PBT.
DriesSmit Jul 13, 2021
6ecae40
fix: Fix trainer for PBT.
DriesSmit Jul 14, 2021
62b1a37
Merge branch 'develop' into feature/mava-scaling
DriesSmit Jul 14, 2021
e656ef0
fix: Correct table names in builder and system.
DriesSmit Jul 14, 2021
1f8c422
fix: A few small bugfixes. The PBT general frameworks seems to be wor…
DriesSmit Jul 15, 2021
e0ae92a
fix: Small fixes to mad4pg.
DriesSmit Jul 16, 2021
26a284e
Save code
DriesSmit Jul 19, 2021
b7ab9c6
Save code
DriesSmit Jul 19, 2021
cb7c6b1
Save work
DriesSmit Jul 19, 2021
633774c
Merge adders branch.
DriesSmit Jul 19, 2021
5d363b2
Save changes
DriesSmit Jul 20, 2021
96f2bb1
Merge uprade-acme-version.
DriesSmit Jul 20, 2021
25ca993
Save changes.
DriesSmit Jul 20, 2021
401dfbc
fix: Small fixes to PBT setup.
DriesSmit Jul 20, 2021
3ff3ffb
fix: Small fixes to PBT setup.
DriesSmit Jul 20, 2021
bb91b79
Merge branch 'develop' into feature/mava-scaling
DriesSmit Jul 20, 2021
8b55ac9
fix: Implement custom sort algorithm.
DriesSmit Jul 21, 2021
b6e6495
Merge branch 'develop' into feature/mava-scaling
DriesSmit Jul 21, 2021
c51d19b
Start work on scaled recurrent setup.
DriesSmit Jul 22, 2021
bfa15bf
Merge develop.
DriesSmit Jul 22, 2021
1a258f0
Merge upgrade-acme branch.
DriesSmit Jul 22, 2021
5cd4257
Merge branch 'feature/population-based-training' into feature/mava-sc…
DriesSmit Jul 22, 2021
8dbf5a4
Merge upgrade-acme branch.
DriesSmit Jul 22, 2021
d54ee4c
Add old maddpg system for debugging purposes.
DriesSmit Jul 22, 2021
00e35c4
fix: Fix recurrent maddpg/mad4pg.
DriesSmit Jul 22, 2021
542ed65
fix: Fix run_maddpg.py.
DriesSmit Jul 22, 2021
56eeb95
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Jul 22, 2021
1da03d7
fix: SMAC imports.
DriesSmit Jul 22, 2021
fd4a255
fix: Fix mypy issues.
DriesSmit Jul 22, 2021
19b1903
Fix MA-DDPG executor to work for non indentity observation networks.
DriesSmit Jul 23, 2021
7b31b55
Reduce variable_server checkpoint waiting time.
DriesSmit Jul 24, 2021
16b8c7b
Fix .keys() and .values() error.
DriesSmit Jul 24, 2021
2bfedce
fix: Small fixes.
DriesSmit Jul 26, 2021
4cb7d99
update: Change supersuit import so that it does not break Mava if the…
DriesSmit Jul 26, 2021
0562341
update: Update sequential and transition adder with better machanics.…
DriesSmit Jul 27, 2021
b86c8a8
fix: Fix adders to be backwards compatible.
DriesSmit Jul 27, 2021
b3331d5
fix: Small fixes.
DriesSmit Jul 27, 2021
c12f85d
fix: Small fixes.
DriesSmit Jul 27, 2021
0084f75
fix: Small fix.
DriesSmit Jul 27, 2021
d6edc46
Small fixes.
DriesSmit Jul 27, 2021
4e8981e
Merge develop.
DriesSmit Jul 28, 2021
df23cb1
fix: is Dict bugs.
DriesSmit Jul 28, 2021
bfe6946
Merge develop.
DriesSmit Aug 3, 2021
167ad44
Address PR comments.
DriesSmit Aug 3, 2021
1a86fd0
Address PR comments.
DriesSmit Aug 3, 2021
b849293
Merge branch 'develop' into feature/mava-scaling
DriesSmit Aug 4, 2021
182f1e9
Merge branch 'develop' into feature/mava-scaling
DriesSmit Aug 6, 2021
c98e460
Adders PR comments.
DriesSmit Aug 6, 2021
12da764
Update mava scalling example.
DriesSmit Aug 6, 2021
3e6a006
Merge branch 'develop' into feature/mava-scaling
DriesSmit Aug 9, 2021
1b70136
Fix slow training times.
DriesSmit Aug 12, 2021
dc9348d
Add some more async methods in the trainer.
DriesSmit Aug 12, 2021
00361fa
Add tf.function to wrapper steps.
DriesSmit Aug 12, 2021
34b0925
fix: Fix the evaluator and executor loggers.
DriesSmit Aug 18, 2021
55e4f2b
fix: Small changes.
DriesSmit Aug 18, 2021
2341d10
Merge develop.
DriesSmit Aug 19, 2021
f0ef242
Add init.
DriesSmit Aug 21, 2021
9842000
Merge Develop. Update MAD4PG vmin and vmax parameters.
DriesSmit Aug 24, 2021
5244bde
Small fixes.
DriesSmit Aug 24, 2021
acb127c
Small fixes.
DriesSmit Aug 24, 2021
7ce2141
Fix tests.
DriesSmit Aug 24, 2021
f7c0bcd
Fix tests.
DriesSmit Aug 24, 2021
16903f8
Fix tests.
DriesSmit Aug 24, 2021
8bf30f3
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Aug 26, 2021
8a5681a
Update the base.py adder file with some comments.
DriesSmit Aug 30, 2021
fa59d5b
Update multiple trainer example.
DriesSmit Aug 30, 2021
852da27
Merge branch 'develop' into feature/mava-scaling
DriesSmit Aug 30, 2021
3cc5831
Merge develop.
DriesSmit Aug 31, 2021
a74ab9f
Added enums for the trainer setup to make it easier to setup multiple…
DriesSmit Sep 9, 2021
a6b8e89
Merge remote-tracking branch 'origin/develop' into feature/mava-scaling
DriesSmit Sep 9, 2021
f4fb448
Convert network_sampling_setup enums.
DriesSmit Sep 9, 2021
70717da
Made a lot of corrections to the docstrings.
DriesSmit Sep 9, 2021
7a94465
Minor corrections to docstrings
arnupretorius Sep 9, 2021
4d455b1
Address PR comments.
DriesSmit Sep 10, 2021
6bf7623
Small fix.
DriesSmit Sep 10, 2021
e478f27
fix: Address PR comment.
DriesSmit Sep 10, 2021
4a160c1
Small fix to the checkpointer.
DriesSmit Sep 10, 2021
163a1d7
Small fix to trainer_networks setup.
DriesSmit Sep 10, 2021
4663856
Address PR comments.
DriesSmit Sep 11, 2021
266db23
Address PR comments.
DriesSmit Sep 14, 2021
3bb8a3c
Add termination_conditions.
DriesSmit Sep 14, 2021
41429a4
Change write flush back.
DriesSmit Sep 15, 2021
ed271be
Address PR comments.
DriesSmit Sep 15, 2021
2a0d7c1
Update trainer naming to match dev branch for one trainer.
DriesSmit Sep 16, 2021
91bc297
Quick fix.
DriesSmit Sep 16, 2021
7b0db2a
Small updates to make the trainers and executors faster.
DriesSmit Sep 17, 2021
e478da0
Small fixes.
DriesSmit Sep 17, 2021
df9ea64
Small fix to the get and set async functions.
DriesSmit Sep 18, 2021
bd97bcc
Add a buffer to the async_add function if it is called to often.
DriesSmit Sep 18, 2021
0bdb98b
Fix mypy issue?
DriesSmit Sep 21, 2021
d1a51df
Small changes.
DriesSmit Sep 30, 2021
2a82347
Fix error where get variables was not getting pulled in the trainer.
DriesSmit Oct 1, 2021
b4f664f
Remove sigma.
DriesSmit Oct 1, 2021
62cd1fa
Add state based critic with one one action.
DriesSmit Oct 4, 2021
71e4117
Small fixes.
DriesSmit Oct 5, 2021
19a867e
Fix problem with tf.function and num_step not incrementing.
DriesSmit Oct 7, 2021
d3a272f
Address PR comments.
DriesSmit Oct 7, 2021
e39daca
Small fix.
DriesSmit Oct 7, 2021
fb97273
Add comments on multiple trainer being experimental.
DriesSmit Oct 25, 2021
be2a4df
Merge develop.
DriesSmit Oct 25, 2021
fix: A general inter-node variable communicator module is now implement and working.
DriesSmit committed Jun 16, 2021
commit 5333e746678c3a72261dea07b9124303c02fd2ba
37 changes: 26 additions & 11 deletions mava/systems/tf/executors.py
@@ -16,6 +16,11 @@

tfd = tfp.distributions

import time

# Delete me
import numpy as np


class FeedForwardExecutor(core.Executor):
"""A feed-forward executor.
@@ -115,18 +120,28 @@ def select_actions(
return actions

def update(self, wait: bool = False) -> None:
# TODO (dries): Add counter here or in variable_utils to pull every n times
if self._variable_client:
import numpy as np
#print("self._variable_client._variables", self._variable_client._variables["policies"]["agent"][1])
var_sum = np.sum(self._variable_client._variables["policies"]["agent"][1])
#print("Executor client: ", var_sum, ". local: ", "None")
self._variable_client.get_async(wait)

# for agent in self._policy_networks.keys():
# self._policy_networks[agent].variables = self._variable_client._variables["policies"][agent]



# if self._variable_client._call_counter==0:

self._variable_client.get_async()
for agent in self._policy_networks.keys():
for i in range(len(self._policy_networks[agent].variables)):
self._policy_networks[agent].variables[i].assign(
self._variable_client._variables["policies"][agent][i]
)
# if self._variable_client._future is not None:

# # print("Actually getting :) ")
# while not self._variable_client._future.done():
# # print("Getting..")
# time.sleep(0.001)
# self._variable_client.get_async()
# # print("Done.")

# var_sum = np.sum(self._variable_client._variables["policies"]["agent"][1])
# print("Executor client: ", var_sum, ". local: ", "None")
# self._variable_client.get_async(wait)


class RecurrentExecutor(core.Executor):
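
The `update` change above copies each fetched array into the matching policy variable by index. A minimal sketch of that copy step follows, under stated assumptions: `FakeVariable` and `copy_into_networks` are hypothetical names that are not part of Mava or Acme, plain lists stand in for NumPy arrays, and the length check is an addition that the diff does not contain.

```python
from typing import Any, Dict, List


class FakeVariable:
    """Hypothetical stand-in for a network variable that supports assign()."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def assign(self, new_value: Any) -> None:
        # Overwrite the stored value in place, as a framework variable would.
        self.value = new_value


def copy_into_networks(
    policy_variables: Dict[str, List[FakeVariable]],
    fetched: Dict[str, List[Any]],
) -> None:
    """Copy fetched values into each agent's variables, position by position."""
    for agent, variables in policy_variables.items():
        new_values = fetched[agent]
        # Assumption: fail loudly on a length mismatch instead of copying partially.
        if len(variables) != len(new_values):
            raise ValueError(f"Variable count mismatch for {agent}")
        for variable, new_value in zip(variables, new_values):
            variable.assign(new_value)


# Usage: one agent with two variables, refreshed from a fetched snapshot.
networks = {"agent_0": [FakeVariable([0.0]), FakeVariable([0.0, 0.0])]}
copy_into_networks(networks, {"agent_0": [[1.0], [2.0, 3.0]]})
```

Iterating with `zip` after a length check avoids the index arithmetic of `range(len(...))` and surfaces a stale or mismatched snapshot early.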
34 changes: 19 additions & 15 deletions mava/systems/tf/maddpg_scaled/builder.py
@@ -17,25 +17,26 @@
import copy
import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Type, Union
from acme.core import VariableSource
from numpy.core.fromnumeric import var

import reverb
import sonnet as snt
from acme import datasets
from acme.core import VariableSource
from acme.specs import EnvironmentSpec
from acme.utils import counting, loggers
from acme.tf import utils as tf2_utils
from acme.utils import counting, loggers
from dm_env import specs as dm_specs
from mava.systems.tf import variable_utils
from mava.systems.tf import savers as tf2_savers
from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource
from numpy.core.fromnumeric import var

from mava import adders, core, specs, types
from mava.adders import reverb as reverb_adders
from mava.systems.builders import SystemBuilder
from mava.systems.tf import executors
from mava.systems.tf import savers as tf2_savers
from mava.systems.tf import variable_utils
from mava.systems.tf.maddpg_scaled import training
from mava.systems.tf.maddpg_scaled.execution import MADDPGFeedForwardExecutor
from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource
from mava.wrappers import DetailedTrainerStatistics, NetworkStatisticsActorCritic

BoundedArray = dm_specs.BoundedArray
@@ -273,13 +274,15 @@ def make_variable_server(
# Set all the network variables inside the variable source
networks_vars = {}
for net_key in networks.keys():
networks_vars[net_key] = {agent: networks[net_key][agent].variables for agent in self._agent_types}


variable_source.set_variables(networks.keys(), tf2_utils.to_numpy(networks_vars))
networks_vars[net_key] = {
agent: networks[net_key][agent].variables for agent in self._agent_types
}

variable_source.set_variables(
networks.keys(), tf2_utils.to_numpy(networks_vars)
)

return variable_source


def make_executor(
self,
@@ -311,7 +314,7 @@ def make_executor(
variable_client = variable_utils.VariableClient(
client=variable_source,
variables={"policies": variables},
update_period=self._config.executor_variable_update_period,
get_period=self._config.executor_variable_update_period,
)
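
The `get_period` argument above, together with the TODO in `executors.py` about pulling "every n times", suggests a call counter that only fetches on every n-th call. The sketch below illustrates that idea only. `PeriodicGetter` and its `fetch` callback are hypothetical, and the real `VariableClient` in `variable_utils` may define `get_period` differently.

```python
from typing import Any, Callable, Optional


class PeriodicGetter:
    """Hypothetical helper: run fetch() on every get_period-th call to get()."""

    def __init__(self, fetch: Callable[[], Any], get_period: int) -> None:
        if get_period < 1:
            raise ValueError("get_period must be at least 1")
        self._fetch = fetch
        self._get_period = get_period
        self._call_counter = 0
        self.variables: Optional[Any] = None

    def get(self) -> bool:
        """Return True if this call actually fetched new variables."""
        should_fetch = self._call_counter % self._get_period == 0
        if should_fetch:
            self.variables = self._fetch()
        self._call_counter += 1
        return should_fetch


# Usage: with get_period=2, calls 0, 2 and 4 fetch; calls 1 and 3 are skipped.
fetch_log = []


def fake_fetch() -> dict:
    fetch_log.append(len(fetch_log))
    return {"policies": {"agent_0": [float(len(fetch_log))]}}


getter = PeriodicGetter(fake_fetch, get_period=2)
results = [getter.get() for _ in range(5)]
```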

# Make sure not to use a random policy after checkpoint restoration by
@@ -366,9 +369,10 @@ def make_trainer(

variables = {}
for net_key in networks.keys():
variables[net_key] = {agent: networks[net_key][agent].variables for agent in self._agent_types}
variables = tf2_utils.to_numpy(variables)

variables[net_key] = {
agent: networks[net_key][agent].variables for agent in self._agent_types
}

variable_client = variable_utils.VariableClient(
client=variable_source,
variables=variables,
21 changes: 11 additions & 10 deletions mava/systems/tf/maddpg_scaled/training.py
@@ -19,7 +19,7 @@
import copy
import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import reverb
@@ -34,7 +34,8 @@
import mava
from mava import core
from mava.components.tf.losses.sequence import recurrent_n_step_critic_loss
from mava.systems.tf import savers as tf2_savers, variable_sources
from mava.systems.tf import savers as tf2_savers
from mava.systems.tf import variable_sources
from mava.utils import training_utils as train_utils

train_utils.set_growing_gpu_memory()
@@ -164,7 +165,7 @@ def __init__(
# Expose the variables.
policy_networks_to_expose = {}
self._system_network_variables: Dict[str, Dict[str, snt.Module]] = {
"critic": {},
"critics": {},
"policies": {},
}
for agent_key in self.unique_net_keys:
@@ -175,7 +176,7 @@ def __init__(
]
)
policy_networks_to_expose[agent_key] = policy_network_to_expose
self._system_network_variables["critic"][
self._system_network_variables["critics"][
agent_key
] = target_critic_networks[agent_key].variables
self._system_network_variables["policies"][
@@ -237,7 +238,7 @@ def _update_target_networks(self) -> None:
for src, dest in zip(online_variables, target_variables):
dest.assign(src)
self._num_steps.assign_add(1)

def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]:
variables: Dict[str, Dict[str, np.ndarray]] = {}
for network_type in names:
@@ -458,7 +459,10 @@ def step(self) -> None:

# Update variable source variables
# TODO (dries): Do this call asynchronous
self._variable_source.set_variables(self._system_network_variables.keys(), tf2_utils.to_numpy(self._system_network_variables))
self._variable_source.set_variables(
self._system_network_variables.keys(),
tf2_utils.to_numpy(self._system_network_variables),
)

# Checkpoint and attempt to write the logs.
if self._checkpoint:
@@ -498,7 +502,6 @@ def __init__(
logger: loggers.Logger = None,
checkpoint: bool = True,
checkpoint_subpath: str = "~/mava/",

):
"""Initializes the learner.
Args:
@@ -1046,7 +1049,7 @@ def __init__(
]
)
policy_networks_to_expose[agent_key] = policy_network_to_expose
self._system_network_variables["critic"][
self._system_network_variables["critics"][
agent_key
] = target_critic_networks[agent_key].variables
self._system_network_variables["policies"][
@@ -1412,8 +1415,6 @@ def step(self) -> None:
if self._logger:
self._logger.write(fetches)




class MADDPGDecentralisedRecurrentTrainer(MADDPGBaseRecurrentTrainer):
"""MADDPG trainer.
22 changes: 15 additions & 7 deletions mava/systems/tf/variable_sources.py
@@ -1,26 +1,34 @@
from typing import Dict, Sequence
from acme.tf import utils as tf2_utils
from typing import Dict, Sequence

import numpy as np
import sonnet as snt
from acme.tf import utils as tf2_utils

class VariableSource():

class VariableSource:
def __init__(self) -> None:
# Init the variable dictionary
self.variables = {}
self.variables = {}

def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]:
# import tensorflow as tf
# tf.print("Getting variable inside source.")
if type(names) == str:
return self.variables[names]
else:
variables: Dict[str, Dict[str, np.ndarray]] = {}
for var_key in names:
variables[var_key] = self.variables[var_key]
return variables

def set_variables(self, names: Sequence[str], vars: Dict[str, np.ndarray]) -> Dict[str, Dict[str, np.ndarray]]:

def set_variables(
self, names: Sequence[str], vars: Dict[str, np.ndarray]
) -> Dict[str, Dict[str, np.ndarray]]:
# import tensorflow as tf
# tf.print("Setting variable inside source.")
if type(names) == str:
vars = {names: vars}
names = [names]

for var_key in names:
self.variables[var_key] = vars[var_key]
self.variables[var_key] = vars[var_key]
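
The `VariableSource` in this diff is a keyed store that accepts either a single name or a sequence of names. A minimal sketch of the same behaviour follows, with three changes that are assumptions rather than anything in the commit: `isinstance` replaces `type(names) == str`, the parameter is named `new_vars` so it does not shadow the built-in `vars`, and `set_variables` is annotated as returning `None` because it returns nothing. `InMemoryVariableSource` is a hypothetical name, and plain lists stand in for NumPy arrays.

```python
from typing import Any, Dict, Sequence, Union


class InMemoryVariableSource:
    """Hypothetical stand-in for the VariableSource class shown in the diff."""

    def __init__(self) -> None:
        # Maps a network type such as "policies" to its per-agent variables.
        self.variables: Dict[str, Any] = {}

    def get_variables(self, names: Union[str, Sequence[str]]) -> Any:
        # A single name returns that entry directly, not wrapped in a dict.
        if isinstance(names, str):
            return self.variables[names]
        return {key: self.variables[key] for key in names}

    def set_variables(self, names: Union[str, Sequence[str]], new_vars: Any) -> None:
        # Normalise the single-name case to the multi-name layout.
        if isinstance(names, str):
            new_vars = {names: new_vars}
            names = [names]
        for key in names:
            self.variables[key] = new_vars[key]


# Usage: set two network types at once, then overwrite one by name.
source = InMemoryVariableSource()
source.set_variables(
    ["policies", "critics"],
    {"policies": {"agent_0": [0.1, 0.2]}, "critics": {"agent_0": [1.0]}},
)
source.set_variables("policies", {"agent_0": [0.3, 0.4]})
```

Note that a single string must be special-cased: iterating over `"policies"` as a sequence would otherwise yield one key per character.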