From 56c34feef509bb165bc11fe9197664800ca6a470 Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Sat, 25 Nov 2023 16:18:57 +0000
Subject: [PATCH] [Docs] Docs website (#34)
* docs
* docs
* docs banner
* change
* change
* api
* amend
* amend
* amend
* amend
* empty
* amend
* amend
* amend
* amend
* amend
* amend
* amend
* amend
* amend
* amend
* amend
* amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Revert "Amend"
This reverts commit c88d311c8033741748c3ce265d984db368321aec.
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Amend
* Make utils private
* Amend
* Amend
* Docs
* Docs
* Docs
* Algorithm docs
* Algorithm docs
* Model docs
* Links
---
.gitignore | 5 +-
.readthedocs.yaml | 31 ++++
README.md | 23 ++-
benchmarl/__init__.py | 13 +-
benchmarl/algorithms/__init__.py | 22 +++
benchmarl/algorithms/common.py | 28 ++--
benchmarl/algorithms/iddpg.py | 12 ++
benchmarl/algorithms/ippo.py | 17 ++
benchmarl/algorithms/iql.py | 11 ++
benchmarl/algorithms/isac.py | 26 +++
benchmarl/algorithms/maddpg.py | 11 ++
benchmarl/algorithms/mappo.py | 17 ++
benchmarl/algorithms/masac.py | 25 +++
benchmarl/algorithms/qmix.py | 12 ++
benchmarl/algorithms/vdn.py | 11 ++
benchmarl/benchmark/__init__.py | 7 +
benchmarl/{ => benchmark}/benchmark.py | 17 ++
benchmarl/environments/common.py | 6 +-
benchmarl/environments/pettingzoo/common.py | 2 +
benchmarl/environments/smacv2/common.py | 2 +
benchmarl/environments/vmas/common.py | 2 +
benchmarl/eval_results.py | 70 +++++++-
benchmarl/experiment/__init__.py | 1 +
benchmarl/experiment/callback.py | 8 +-
benchmarl/experiment/experiment.py | 43 ++---
benchmarl/hydra_config.py | 55 ++++++-
benchmarl/models/__init__.py | 8 +-
benchmarl/models/common.py | 32 ++--
benchmarl/models/mlp.py | 16 ++
benchmarl/run.py | 13 ++
benchmarl/utils.py | 4 +-
docs/Makefile | 20 +++
docs/make.bat | 35 ++++
docs/requirements.txt | 6 +
docs/source/_templates/autosummary/class.rst | 9 ++
.../autosummary/class_no_inherit.rst | 8 +
.../_templates/autosummary/class_private.rst | 9 ++
.../autosummary/class_private_no_undoc.rst | 8 +
docs/source/_templates/breadcrumbs.html | 4 +
docs/source/concepts/benchmarks.rst | 25 +++
docs/source/concepts/components.rst | 111 +++++++++++++
docs/source/concepts/configuring.rst | 152 ++++++++++++++++++
docs/source/concepts/extending.rst | 27 ++++
docs/source/concepts/features.rst | 63 ++++++++
docs/source/concepts/reporting.rst | 43 +++++
docs/source/conf.py | 66 ++++++++
docs/source/index.rst | 76 +++++++++
docs/source/modules/algorithms.rst | 33 ++++
docs/source/modules/benchmark.rst | 17 ++
docs/source/modules/environments.rst | 54 +++++++
docs/source/modules/experiment.rst | 29 ++++
docs/source/modules/models.rst | 35 ++++
docs/source/modules/root.rst | 23 +++
docs/source/usage/installation.rst | 64 ++++++++
docs/source/usage/notebooks.rst | 6 +
docs/source/usage/running.rst | 91 +++++++++++
examples/plotting/README.md | 6 +-
57 files changed, 1495 insertions(+), 75 deletions(-)
create mode 100644 .readthedocs.yaml
create mode 100644 benchmarl/benchmark/__init__.py
rename benchmarl/{ => benchmark}/benchmark.py (75%)
create mode 100644 docs/Makefile
create mode 100644 docs/make.bat
create mode 100644 docs/requirements.txt
create mode 100644 docs/source/_templates/autosummary/class.rst
create mode 100644 docs/source/_templates/autosummary/class_no_inherit.rst
create mode 100644 docs/source/_templates/autosummary/class_private.rst
create mode 100644 docs/source/_templates/autosummary/class_private_no_undoc.rst
create mode 100644 docs/source/_templates/breadcrumbs.html
create mode 100644 docs/source/concepts/benchmarks.rst
create mode 100644 docs/source/concepts/components.rst
create mode 100644 docs/source/concepts/configuring.rst
create mode 100644 docs/source/concepts/extending.rst
create mode 100644 docs/source/concepts/features.rst
create mode 100644 docs/source/concepts/reporting.rst
create mode 100644 docs/source/conf.py
create mode 100644 docs/source/index.rst
create mode 100644 docs/source/modules/algorithms.rst
create mode 100644 docs/source/modules/benchmark.rst
create mode 100644 docs/source/modules/environments.rst
create mode 100644 docs/source/modules/experiment.rst
create mode 100644 docs/source/modules/models.rst
create mode 100644 docs/source/modules/root.rst
create mode 100644 docs/source/usage/installation.rst
create mode 100644 docs/source/usage/notebooks.rst
create mode 100644 docs/source/usage/running.rst
diff --git a/.gitignore b/.gitignore
index 475fb7ef..be2e5283 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,10 @@
**/outputs/
**/multirun/
-
+# Docs
+docs/output/
+docs/source/generated/
+docs/build/
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 00000000..58e9fa45
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,31 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the OS, Python version and other tools you might need
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.10"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+ fail_on_warning: true
+ configuration: docs/source/conf.py
+
+# Optionally build your docs in additional formats such as PDF and ePub
+formats:
+ - epub
+
+# Optional but recommended, declare the Python requirements required
+# to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
+python:
+ install:
+ - requirements: docs/requirements.txt
+ # Install our python package before building the docs
+ - method: pip
+ path: .
diff --git a/README.md b/README.md
index 910ed810..0a49f9e4 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,20 @@
-![BenchMARL](https://github.com/matteobettini/vmas-media/blob/main/media/benchmarl.png?raw=true)
+![BenchMARL](https://raw.githubusercontent.com/matteobettini/benchmarl_sphinx_theme/master/benchmarl_sphinx_theme/static/img/benchmarl.png?raw=true)
# BenchMARL
[![tests](https://github.com/facebookresearch/BenchMARL/actions/workflows/unit_tests.yml/badge.svg)](test)
[![codecov](https://codecov.io/github/facebookresearch/BenchMARL/coverage.svg?branch=main)](https://codecov.io/gh/facebookresearch/BenchMARL)
+[![Documentation Status](https://readthedocs.org/projects/benchmarl/badge/?version=latest)](https://benchmarl.readthedocs.io/en/latest/?badge=latest)
[![Python](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://www.python.org/downloads/)
[![Downloads](https://static.pepy.tech/personalized-badge/benchmarl?period=total&units=international_system&left_color=grey&right_color=blue&left_text=Downloads)](https://pepy.tech/project/benchmarl)
+[![Discord Shield](https://dcbadge.vercel.app/api/server/jEEWCn6T3p?style=flat)](https://discord.gg/jEEWCn6T3p)
```bash
python benchmarl/run.py algorithm=mappo task=vmas/balance
```
-
[![Examples](https://img.shields.io/badge/Examples-blue.svg)](examples) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/facebookresearch/BenchMARL/blob/main/notebooks/run.ipynb)
[![Static Badge](https://img.shields.io/badge/Benchmarks-Wandb-yellow)](https://wandb.ai/matteobettini/benchmarl-public/reportlist)
@@ -58,6 +59,7 @@ the domain and want to easily take a picture of the landscape.
* [Reporting and plotting](#reporting-and-plotting)
* [Extending](#extending)
* [Configuring](#configuring)
+ + [Experiment](#experiment)
+ [Algorithm](#algorithm)
+ [Task](#task)
+ [Model](#model)
@@ -280,10 +282,9 @@ Currently available ones are:
In the following, we report a table of the results:
-| **
Environment
** | **Sample efficiency curves (all tasks)
** | **Performance profile
** | **Aggregate scores
** |
-|---------------------------------------|-------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
-| VMAS | | | |
-
+| **Environment
** | **Sample efficiency curves (all tasks)
** | **Performance profile
** | **Aggregate scores
** |
+|---------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| VMAS | | | |
## Reporting and plotting
@@ -295,9 +296,9 @@ your benchmarks. No more struggling with matplotlib and latex!
[![Example](https://img.shields.io/badge/Example-blue.svg)](examples/plotting)
-![aggregate_scores](https://drive.google.com/uc?export=view&id=1q2So9V6sL8NHMtj6vL-S3KyzZi11Vfia)
-![sample_efficiancy](https://drive.google.com/uc?export=view&id=1fzfFn0q54gsALRAwmqD1hRTqQIadGPoE)
-![performace_profile](https://drive.google.com/uc?export=view&id=151pSR2sBluSpWiYxtq3jNX0tfE0vgAuR)
+![aggregate_scores](https://raw.githubusercontent.com/matteobettini/benchmarl_sphinx_theme/master/benchmarl_sphinx_theme/static/img/benchmarks/vmas/aggregate_scores.png)
+![sample_efficiancy](https://raw.githubusercontent.com/matteobettini/benchmarl_sphinx_theme/master/benchmarl_sphinx_theme/static/img/benchmarks/vmas/environemnt_sample_efficiency_curves.png)
+![performace_profile](https://raw.githubusercontent.com/matteobettini/benchmarl_sphinx_theme/master/benchmarl_sphinx_theme/static/img/benchmarks/vmas/performance_profile_figure.png)
## Extending
@@ -322,7 +323,6 @@ in the script itself or via [hydra](https://hydra.cc/docs/intro/).
We suggest to read the hydra documentation
to get familiar with all its functionalities.
-The project can be configured either the script itself or via hydra.
Each component in the project has a corresponding yaml configuration in the BenchMARL
[conf tree](benchmarl/conf).
Components' configurations are loaded from these files into python dataclasses that act
@@ -333,8 +333,7 @@ You can also directly load and validate configuration yaml files without using h
### Experiment
-Experiment configurations are in [`benchmarl/conf/config.yaml`](benchmarl/conf/config.yaml),
-with the experiment hyperparameters in [`benchmarl/conf/experiment`](benchmarl/conf/experiment).
+Experiment configurations are in [`benchmarl/conf/config.yaml`](benchmarl/conf/config.yaml).
Running custom experiments is extremely simplified by the [Hydra](https://hydra.cc/) configurations.
The default configuration for the library is contained in the [`benchmarl/conf`](benchmarl/conf) folder.
diff --git a/benchmarl/__init__.py b/benchmarl/__init__.py
index f953c5ca..96fd64ed 100644
--- a/benchmarl/__init__.py
+++ b/benchmarl/__init__.py
@@ -4,13 +4,22 @@
# LICENSE file in the root directory of this source tree.
#
+
+__version__ = "0.0.4"
+
import importlib
+import benchmarl.algorithms
+import benchmarl.benchmark
+import benchmarl.environments
+import benchmarl.experiment
+import benchmarl.models
+
_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
- def load_hydra_schemas():
+ def _load_hydra_schemas():
from hydra.core.config_store import ConfigStore
from benchmarl.algorithms import algorithm_config_registry
@@ -28,4 +37,4 @@ def load_hydra_schemas():
for task_schema_name, task_schema in _task_class_registry.items():
cs.store(name=task_schema_name, group="task", node=task_schema)
- load_hydra_schemas()
+ _load_hydra_schemas()
diff --git a/benchmarl/algorithms/__init__.py b/benchmarl/algorithms/__init__.py
index b9e18647..f0e2d20a 100644
--- a/benchmarl/algorithms/__init__.py
+++ b/benchmarl/algorithms/__init__.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
#
+from .common import Algorithm, AlgorithmConfig
from .iddpg import Iddpg, IddpgConfig
from .ippo import Ippo, IppoConfig
from .iql import Iql, IqlConfig
@@ -14,6 +15,27 @@
from .qmix import Qmix, QmixConfig
from .vdn import Vdn, VdnConfig
+classes = [
+ "Iddpg",
+ "IddpgConfig",
+ "Ippo",
+ "IppoConfig",
+ "Iql",
+ "IqlConfig",
+ "Isac",
+ "IsacConfig",
+ "Maddpg",
+ "MaddpgConfig",
+ "Mappo",
+ "MappoConfig",
+ "Masac",
+ "MasacConfig",
+ "Qmix",
+ "QmixConfig",
+ "Vdn",
+ "VdnConfig",
+]
+
# A registry mapping "algoname" to its config dataclass
# This is used to aid loading of algorithms from yaml
algorithm_config_registry = {
diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py
index 5e0b86f2..3702b75d 100644
--- a/benchmarl/algorithms/common.py
+++ b/benchmarl/algorithms/common.py
@@ -23,7 +23,7 @@
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater
from benchmarl.models.common import ModelConfig
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import _read_yaml_config, DEVICE_TYPING
class Algorithm(ABC):
@@ -32,7 +32,7 @@ class Algorithm(ABC):
This should be overridden by implemented algorithms
and all abstract methods should be implemented.
- Args:
+ Args:
experiment (Experiment): the experiment class
"""
@@ -104,14 +104,13 @@ def _check_specs(self):
def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater]:
"""
Get the LossModule and TargetNetUpdater for a specific group.
- This function calls the abstract self._get_loss() which needs to be implemented.
+ This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_loss()` which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Args:
group (str): agent group of the loss and updater
Returns: LossModule and TargetNetUpdater for the group
-
"""
if group not in self._losses_and_updaters.keys():
action_space = self.action_spec[group, "action"]
@@ -144,7 +143,7 @@ def get_replay_buffer(
) -> ReplayBuffer:
"""
Get the ReplayBuffer for a specific group.
- This function will check self.on_policy and create the buffer accordingly
+ This function will check ``self.on_policy`` and create the buffer accordingly
Args:
group (str): agent group of the loss and updater
@@ -165,7 +164,7 @@ def get_replay_buffer(
def get_policy_for_loss(self, group: str) -> TensorDictModule:
"""
Get the non-explorative policy for a specific group loss.
- This function calls the abstract self._get_policy_for_loss() which needs to be implemented.
+ This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_policy_for_loss()` which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Args:
@@ -192,7 +191,7 @@ def get_policy_for_loss(self, group: str) -> TensorDictModule:
def get_policy_for_collection(self) -> TensorDictSequential:
"""
Get the explorative policy for all groups together.
- This function calls the abstract self._get_policy_for_collection() which needs to be implemented.
+ This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_policy_for_collection()` which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Returns: TensorDictSequential representing all explorative policies
@@ -217,7 +216,7 @@ def get_policy_for_collection(self) -> TensorDictSequential:
def get_parameters(self, group: str) -> Dict[str, Iterable]:
"""
Get the dictionary mapping loss names to the relative parameters to optimize for a given group.
- This function calls the abstract self._get_parameters() which needs to be implemented.
+ This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_parameters()` which needs to be implemented.
Returns: a dictionary mapping loss names to a parameters' list
"""
@@ -323,13 +322,16 @@ class AlgorithmConfig:
Dataclass representing an algorithm configuration.
This should be overridden by implemented algorithms.
Implementors should:
- 1. add configuration parameters for their algorithm
- 2. implement all abstract methods
+
+ 1. add configuration parameters for their algorithm
+ 2. implement all abstract methods
+
"""
def get_algorithm(self, experiment) -> Algorithm:
"""
Main function to turn the config into the associated algorithm
+
Args:
experiment (Experiment): the experiment class
@@ -349,7 +351,7 @@ def _load_from_yaml(name: str) -> Dict[str, Any]:
/ "algorithm"
/ f"{name.lower()}.yaml"
)
- return read_yaml_config(str(yaml_path.resolve()))
+ return _read_yaml_config(str(yaml_path.resolve()))
@classmethod
def get_from_yaml(cls, path: Optional[str] = None):
@@ -359,7 +361,7 @@ def get_from_yaml(cls, path: Optional[str] = None):
Args:
path (str, optional): The full path of the yaml file to load from.
If None, it will default to
- benchmarl/conf/algorithm/self.associated_class().__name__
+ ``benchmarl/conf/algorithm/self.associated_class().__name__``
Returns: the loaded AlgorithmConfig
"""
@@ -370,7 +372,7 @@ def get_from_yaml(cls, path: Optional[str] = None):
)
)
else:
- return cls(**read_yaml_config(path))
+ return cls(**_read_yaml_config(path))
@staticmethod
@abstractmethod
diff --git a/benchmarl/algorithms/iddpg.py b/benchmarl/algorithms/iddpg.py
index 2bf1657c..742a49ef 100644
--- a/benchmarl/algorithms/iddpg.py
+++ b/benchmarl/algorithms/iddpg.py
@@ -19,6 +19,16 @@
class Iddpg(Algorithm):
+ """Same as :class:`~benchmarkl.algorithms.Maddpg` (from `https://arxiv.org/abs/1706.02275 `__) but with decentralized critics.
+
+ Args:
+ share_param_critic (bool): Whether to share the parameters of the critics withing agent groups
+ loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+ delay_value (bool): whether to separate the target value networks from the value networks used for
+ data collection.
+
+ """
+
def __init__(
self, share_param_critic: bool, loss_function: str, delay_value: bool, **kwargs
):
@@ -227,6 +237,8 @@ def get_value_module(self, group: str) -> TensorDictModule:
@dataclass
class IddpgConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Iddpg`."""
+
share_param_critic: bool = MISSING
loss_function: str = MISSING
delay_value: bool = MISSING
diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py
index f7190630..0dd9bfa2 100644
--- a/benchmarl/algorithms/ippo.py
+++ b/benchmarl/algorithms/ippo.py
@@ -22,6 +22,21 @@
class Ippo(Algorithm):
+ """Independent PPO (from `https://arxiv.org/abs/2011.09533 `__).
+
+ Args:
+ share_param_critic (bool): Whether to share the parameters of the critics withing agent groups
+ clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation.
+ entropy_coef (scalar): entropy multiplier when computing the total loss.
+ critic_coef (scalar): critic loss multiplier when computing the total
+ loss_critic_type (str): loss function for the value discrepancy.
+ Can be one of "l1", "l2" or "smooth_l1".
+ lmbda (float): The GAE lambda
+ scale_mapping (str): positive mapping function to be used with the std.
+ choices: "softplus", "exp", "relu", "biased_softplus_1";
+
+ """
+
def __init__(
self,
share_param_critic: bool,
@@ -270,6 +285,8 @@ def get_critic(self, group: str) -> TensorDictModule:
@dataclass
class IppoConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Ippo`."""
+
share_param_critic: bool = MISSING
clip_epsilon: float = MISSING
entropy_coef: float = MISSING
diff --git a/benchmarl/algorithms/iql.py b/benchmarl/algorithms/iql.py
index 8838c8fa..526c3d79 100644
--- a/benchmarl/algorithms/iql.py
+++ b/benchmarl/algorithms/iql.py
@@ -18,6 +18,15 @@
class Iql(Algorithm):
+ """Independent Q Learning (from `https://www.semanticscholar.org/paper/Multi-Agent-Reinforcement-Learning%3A-Independent-Tan/59de874c1e547399b695337bcff23070664fa66e `__).
+
+ Args:
+ loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+ delay_value (bool): whether to separate the target value networks from the value networks used for
+ data collection.
+
+ """
+
def __init__(self, delay_value: bool, loss_function: str, **kwargs):
super().__init__(**kwargs)
@@ -175,6 +184,8 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
@dataclass
class IqlConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Iql`."""
+
delay_value: bool = MISSING
loss_function: str = MISSING
diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py
index 20df1ac1..972b762f 100644
--- a/benchmarl/algorithms/isac.py
+++ b/benchmarl/algorithms/isac.py
@@ -26,6 +26,30 @@
class Isac(Algorithm):
+ """Independent Soft Actor Critic.
+
+ Args:
+ share_param_critic (bool): Whether to share the parameters of the critics withing agent groups
+ num_qvalue_nets (integer): number of Q-Value networks used.
+ loss_function (str): loss function to be used with
+ the value function loss.
+ delay_qvalue (bool): Whether to separate the target Q value
+ networks from the Q value networks used for data collection.
+ target_entropy (float or str, optional): Target entropy for the
+ stochastic policy. Default is "auto", where target entropy is
+ computed as :obj:`-prod(n_actions)`.
+ discrete_target_entropy_weight (float): weight for the target entropy term when actions are discrete
+ alpha_init (float): initial entropy multiplier.
+ min_alpha (float): min value of alpha.
+ max_alpha (float): max value of alpha.
+ fixed_alpha (bool): if ``True``, alpha will be fixed to its
+ initial value. Otherwise, alpha will be optimized to
+ match the 'target_entropy' value.
+ scale_mapping (str): positive mapping function to be used with the std.
+ choices: "softplus", "exp", "relu", "biased_softplus_1";
+
+ """
+
def __init__(
self,
share_param_critic: bool,
@@ -358,6 +382,8 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
@dataclass
class IsacConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Isac`."""
+
share_param_critic: bool = MISSING
num_qvalue_nets: int = MISSING
diff --git a/benchmarl/algorithms/maddpg.py b/benchmarl/algorithms/maddpg.py
index df79de41..60501708 100644
--- a/benchmarl/algorithms/maddpg.py
+++ b/benchmarl/algorithms/maddpg.py
@@ -19,6 +19,15 @@
class Maddpg(Algorithm):
+ """Multi Agent DDPG (from `https://arxiv.org/abs/1706.02275 `__).
+
+ Args:
+ share_param_critic (bool): Whether to share the parameters of the critics withing agent groups
+ loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+ delay_value (bool): whether to separate the target value networks from the value networks used for
+ data collection.
+ """
+
def __init__(
self, share_param_critic: bool, loss_function: str, delay_value: bool, **kwargs
):
@@ -283,6 +292,8 @@ def get_value_module(self, group: str) -> TensorDictModule:
@dataclass
class MaddpgConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Maddpg`."""
+
share_param_critic: bool = MISSING
loss_function: str = MISSING
diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py
index cae735e9..c856642f 100644
--- a/benchmarl/algorithms/mappo.py
+++ b/benchmarl/algorithms/mappo.py
@@ -21,6 +21,21 @@
class Mappo(Algorithm):
+ """Multi Agent PPO (from `https://arxiv.org/abs/2103.01955 `__).
+
+ Args:
+ share_param_critic (bool): Whether to share the parameters of the critics withing agent groups
+ clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation.
+ entropy_coef (scalar): entropy multiplier when computing the total loss.
+ critic_coef (scalar): critic loss multiplier when computing the total
+ loss_critic_type (str): loss function for the value discrepancy.
+ Can be one of "l1", "l2" or "smooth_l1".
+ lmbda (float): The GAE lambda
+ scale_mapping (str): positive mapping function to be used with the std.
+ choices: "softplus", "exp", "relu", "biased_softplus_1";
+
+ """
+
def __init__(
self,
share_param_critic: bool,
@@ -301,6 +316,8 @@ def get_critic(self, group: str) -> TensorDictModule:
@dataclass
class MappoConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Mappo`."""
+
share_param_critic: bool = MISSING
clip_epsilon: float = MISSING
entropy_coef: float = MISSING
diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py
index 95c67ae7..291a6588 100644
--- a/benchmarl/algorithms/masac.py
+++ b/benchmarl/algorithms/masac.py
@@ -20,6 +20,29 @@
class Masac(Algorithm):
+ """Multi Agent Soft Actor Critic.
+
+ Args:
+ share_param_critic (bool): Whether to share the parameters of the critics withing agent groups
+ num_qvalue_nets (integer): number of Q-Value networks used.
+ loss_function (str): loss function to be used with
+ the value function loss.
+ delay_qvalue (bool): Whether to separate the target Q value
+ networks from the Q value networks used for data collection.
+ target_entropy (float or str, optional): Target entropy for the
+ stochastic policy. Default is "auto", where target entropy is
+ computed as :obj:`-prod(n_actions)`.
+ discrete_target_entropy_weight (float): weight for the target entropy term when actions are discrete
+ alpha_init (float): initial entropy multiplier.
+ min_alpha (float): min value of alpha.
+ max_alpha (float): max value of alpha.
+ fixed_alpha (bool): if ``True``, alpha will be fixed to its
+ initial value. Otherwise, alpha will be optimized to
+ match the 'target_entropy' value.
+ scale_mapping (str): positive mapping function to be used with the std.
+ choices: "softplus", "exp", "relu", "biased_softplus_1";
+ """
+
def __init__(
self,
share_param_critic: bool,
@@ -434,6 +457,8 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
@dataclass
class MasacConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Masac`."""
+
share_param_critic: bool = MISSING
num_qvalue_nets: int = MISSING
diff --git a/benchmarl/algorithms/qmix.py b/benchmarl/algorithms/qmix.py
index fd3bd7cd..f4edc1fd 100644
--- a/benchmarl/algorithms/qmix.py
+++ b/benchmarl/algorithms/qmix.py
@@ -18,6 +18,16 @@
class Qmix(Algorithm):
+ """QMIX (from `https://arxiv.org/abs/1803.11485 `__).
+
+ Args:
+ mixing_embed_dim (int): hidden dimension of the mixing network
+ loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+ delay_value (bool): whether to separate the target value networks from the value networks used for
+ data collection.
+
+ """
+
def __init__(
self, mixing_embed_dim: int, delay_value: bool, loss_function: str, **kwargs
):
@@ -200,6 +210,8 @@ def get_mixer(self, group: str) -> TensorDictModule:
@dataclass
class QmixConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Qmix`."""
+
mixing_embed_dim: int = MISSING
delay_value: bool = MISSING
loss_function: str = MISSING
diff --git a/benchmarl/algorithms/vdn.py b/benchmarl/algorithms/vdn.py
index 8e250f83..4ac77ab8 100644
--- a/benchmarl/algorithms/vdn.py
+++ b/benchmarl/algorithms/vdn.py
@@ -18,6 +18,15 @@
class Vdn(Algorithm):
+ """Vdn (from `https://arxiv.org/abs/1706.05296 `__).
+
+ Args:
+ loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+ delay_value (bool): whether to separate the target value networks from the value networks used for
+ data collection.
+
+ """
+
def __init__(self, delay_value: bool, loss_function: str, **kwargs):
super().__init__(**kwargs)
@@ -189,6 +198,8 @@ def get_mixer(self, group: str) -> TensorDictModule:
@dataclass
class VdnConfig(AlgorithmConfig):
+ """Configuration dataclass for :class:`~benchmarl.algorithms.Vdn`."""
+
delay_value: bool = MISSING
loss_function: str = MISSING
diff --git a/benchmarl/benchmark/__init__.py b/benchmarl/benchmark/__init__.py
new file mode 100644
index 00000000..0fca8b53
--- /dev/null
+++ b/benchmarl/benchmark/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+from .benchmark import Benchmark
diff --git a/benchmarl/benchmark.py b/benchmarl/benchmark/benchmark.py
similarity index 75%
rename from benchmarl/benchmark.py
rename to benchmarl/benchmark/benchmark.py
index a2a296a0..6d59eaff 100644
--- a/benchmarl/benchmark.py
+++ b/benchmarl/benchmark/benchmark.py
@@ -13,6 +13,20 @@
class Benchmark:
+ """A benchmark.
+
+ Benchmarks are collections of experiments to compare.
+
+ Args:
+ algorithm_configs (list of AlgorithmConfig): the algorithms to benchmark
+ model_config (ModelConfig): the config of the policy model
+ tasks (list of Task): the tasks to benchmark
+ seeds (set of int): the seeds for the benchmark
+ experiment_config (ExperimentConfig): the experiment config
+ critic_model_config (ModelConfig, optional): the config of the critic model. Defaults to model_config
+
+ """
+
def __init__(
self,
algorithm_configs: Sequence[AlgorithmConfig],
@@ -36,9 +50,11 @@ def __init__(
@property
def n_experiments(self):
+ """The number of experiments in the benchmark."""
return len(self.algorithm_configs) * len(self.tasks) * len(self.seeds)
def get_experiments(self) -> Iterator[Experiment]:
+ """Yields one experiment at a time"""
for algorithm_config in self.algorithm_configs:
for task in self.tasks:
for seed in self.seeds:
@@ -52,6 +68,7 @@ def get_experiments(self) -> Iterator[Experiment]:
)
def run_sequential(self):
+ """Run all the experiments in the benchmark in a sequence."""
for i, experiment in enumerate(self.get_experiments()):
print(f"\nRunning experiment {i+1}/{self.n_experiments}.\n")
try:
diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py
index c7f49605..fd4fc2a6 100644
--- a/benchmarl/environments/common.py
+++ b/benchmarl/environments/common.py
@@ -17,7 +17,7 @@
from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase, RewardSum, Transform
-from benchmarl.utils import DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import _read_yaml_config, DEVICE_TYPING
def _load_config(name: str, config: Dict[str, Any]):
@@ -255,7 +255,7 @@ def __str__(self):
@staticmethod
def _load_from_yaml(name: str) -> Dict[str, Any]:
yaml_path = Path(__file__).parent.parent / "conf" / "task" / f"{name}.yaml"
- return read_yaml_config(str(yaml_path.resolve()))
+ return _read_yaml_config(str(yaml_path.resolve()))
def get_from_yaml(self, path: Optional[str] = None) -> Task:
"""
@@ -273,4 +273,4 @@ def get_from_yaml(self, path: Optional[str] = None) -> Task:
Task._load_from_yaml(str(Path(self.env_name()) / Path(task_name)))
)
else:
- return self.update_config(**read_yaml_config(path))
+ return self.update_config(**_read_yaml_config(path))
diff --git a/benchmarl/environments/pettingzoo/common.py b/benchmarl/environments/pettingzoo/common.py
index 372616b8..fdc4eb61 100644
--- a/benchmarl/environments/pettingzoo/common.py
+++ b/benchmarl/environments/pettingzoo/common.py
@@ -15,6 +15,8 @@
class PettingZooTask(Task):
+ """Enum for PettingZoo tasks."""
+
MULTIWALKER = None
WATERWORLD = None
SIMPLE_ADVERSARY = None
diff --git a/benchmarl/environments/smacv2/common.py b/benchmarl/environments/smacv2/common.py
index f9dcc691..47043396 100644
--- a/benchmarl/environments/smacv2/common.py
+++ b/benchmarl/environments/smacv2/common.py
@@ -17,6 +17,8 @@
class Smacv2Task(Task):
+ """Enum for SMACv2 tasks."""
+
PROTOSS_5_VS_5 = None
PROTOSS_10_VS_10 = None
PROTOSS_10_VS_11 = None
diff --git a/benchmarl/environments/vmas/common.py b/benchmarl/environments/vmas/common.py
index ba00c94d..0ee22ac8 100644
--- a/benchmarl/environments/vmas/common.py
+++ b/benchmarl/environments/vmas/common.py
@@ -15,6 +15,8 @@
class VmasTask(Task):
+ """Enum for VMAS tasks."""
+
BALANCE = None
SAMPLING = None
NAVIGATION = None
diff --git a/benchmarl/eval_results.py b/benchmarl/eval_results.py
index 83f7a572..94f2174d 100644
--- a/benchmarl/eval_results.py
+++ b/benchmarl/eval_results.py
@@ -28,6 +28,24 @@
def get_raw_dict_from_multirun_folder(multirun_folder: str) -> Dict:
+ """Get the ``marl-eval`` input dictionary from the folder of a hydra multirun.
+
+ Examples:
+ .. code-block:: python
+
+ from benchmarl.eval_results import get_raw_dict_from_multirun_folder, Plotting
+ raw_dict = get_raw_dict_from_multirun_folder(
+ multirun_folder="some_prefix/multirun/2023-09-22/17-21-34"
+ )
+ processed_data = Plotting.process_data(raw_dict)
+
+ Args:
+ multirun_folder (str): the absolute path to the multirun folder
+
+ Returns:
+ the dict obtained by merging all the json files in the multirun
+
+ """
return load_and_merge_json_dicts(_get_json_files_from_multirun(multirun_folder))
@@ -43,6 +61,18 @@ def _get_json_files_from_multirun(multirun_folder: str) -> List[str]:
def load_and_merge_json_dicts(
json_input_files: List[str], json_output_file: Optional[str] = None
) -> Dict:
+ """Loads and merges json dictionaries to form the ``marl-eval`` input dictionary .
+
+ Args:
+ json_input_files (list of str): a list containing the absolute paths to the json files
+ json_output_file (str, optional): if specified, the merged dictionary will be also written
+ to the file in this absolute path
+
+ Returns:
+ the dict obtained by merging all the json files
+
+ """
+
def update(d, u):
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
@@ -67,13 +97,49 @@ def update(d, u):
class Plotting:
+ """Class containing static utilities for plotting in ``marl-eval``.
+
+ Examples:
+ >>> from benchmarl.eval_results import get_raw_dict_from_multirun_folder, Plotting
+ >>> raw_dict = get_raw_dict_from_multirun_folder(
+ ... multirun_folder="some_prefix/multirun/2023-09-22/17-21-34"
+ ... )
+ >>> processed_data = Plotting.process_data(raw_dict)
+ ... (
+ ... environment_comparison_matrix,
+ ... sample_efficiency_matrix,
+ ... ) = Plotting.create_matrices(processed_data, env_name="vmas")
+ >>> Plotting.performance_profile_figure(
+ ... environment_comparison_matrix=environment_comparison_matrix
+ ... )
+ >>> Plotting.aggregate_scores(
+ ... environment_comparison_matrix=environment_comparison_matrix
+ ... )
+ >>> Plotting.environemnt_sample_efficiency_curves(
+ ... sample_effeciency_matrix=sample_efficiency_matrix
+ ... )
+ >>> Plotting.task_sample_efficiency_curves(
+ ... processed_data=processed_data, env="vmas", task="navigation"
+ ... )
+ >>> plt.show()
+
+ """
METRICS_TO_NORMALIZE = ["return"]
METRIC_TO_PLOT = "return"
@staticmethod
- def process_data(raw_data: Dict):
- # Call data_process_pipeline to normalize the choosen metrics and to clean the data
+ def process_data(raw_data: Dict) -> Dict:
+ """Call ``data_process_pipeline`` to normalize the chosen metrics and to clean the data
+
+ Args:
+ raw_data (dict): the input data
+
+ Returns:
+ the processed dict
+
+ """
+
return data_process_pipeline(
raw_data=raw_data, metrics_to_normalize=Plotting.METRICS_TO_NORMALIZE
)
diff --git a/benchmarl/experiment/__init__.py b/benchmarl/experiment/__init__.py
index 0e954fb2..775bb0da 100644
--- a/benchmarl/experiment/__init__.py
+++ b/benchmarl/experiment/__init__.py
@@ -4,4 +4,5 @@
# LICENSE file in the root directory of this source tree.
#
+from .callback import Callback
from .experiment import Experiment, ExperimentConfig
diff --git a/benchmarl/experiment/callback.py b/benchmarl/experiment/callback.py
index c93b85d8..d5c32f13 100644
--- a/benchmarl/experiment/callback.py
+++ b/benchmarl/experiment/callback.py
@@ -76,11 +76,11 @@ def __init__(self, experiment, callbacks: List[Callback]):
for callback in self.callbacks:
callback.experiment = experiment
- def on_batch_collected(self, batch: TensorDictBase):
+ def _on_batch_collected(self, batch: TensorDictBase):
for callback in self.callbacks:
callback.on_batch_collected(batch)
- def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
+ def _on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
train_td = None
for callback in self.callbacks:
td = callback.on_train_step(batch, group)
@@ -91,10 +91,10 @@ def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
train_td.update(td)
return train_td
- def on_train_end(self, training_td: TensorDictBase, group: str):
+ def _on_train_end(self, training_td: TensorDictBase, group: str):
for callback in self.callbacks:
callback.on_train_end(training_td, group)
- def on_evaluation_end(self, rollouts: List[TensorDictBase]):
+ def _on_evaluation_end(self, rollouts: List[TensorDictBase]):
for callback in self.callbacks:
callback.on_evaluation_end(rollouts)
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index ba43a3df..ba550fcd 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -30,7 +30,7 @@
from benchmarl.experiment.callback import Callback, CallbackNotifier
from benchmarl.experiment.logger import Logger
from benchmarl.models.common import ModelConfig
-from benchmarl.utils import read_yaml_config
+from benchmarl.utils import _read_yaml_config
_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
@@ -44,7 +44,7 @@ class ExperimentConfig:
This class acts as a schema for loading and validating yaml configurations.
Parameters in this class aim to be agnostic of the algorithm, task or model used.
- To know their meaning, please check out the descriptions in benchmarl/conf/experiment/base_experiment.yaml
+ To know their meaning, please check out the descriptions in ``benchmarl/conf/experiment/base_experiment.yaml``
"""
sampling_device: str = MISSING
@@ -111,7 +111,7 @@ def train_minibatch_size(self, on_policy: bool) -> int:
"""
The minibatch size of tensors used for training.
On-policy algorithms are trained by splitting the train_batch_size (equal to the collected frames) into minibatches.
- Off-policy algorithms do not go through this process and thus have the train_minibatch_size==train_batch_size
+ Off-policy algorithms do not go through this process and thus have the ``train_minibatch_size==train_batch_size``
Args:
on_policy (bool): is the algorithms on_policy
@@ -168,8 +168,8 @@ def n_envs_per_worker(self, on_policy: bool) -> int:
"""
Number of environments used for collection
- In vectorized environments, this will be the vectorized batch_size.
- In other environments, this will be emulated by running them sequentially.
+ - In vectorized environments, this will be the vectorized batch_size.
+ - In other environments, this will be emulated by running them sequentially.
Args:
on_policy (bool): is the algorithms on_policy
@@ -233,9 +233,10 @@ def get_from_yaml(path: Optional[str] = None):
Args:
path (str, optional): The full path of the yaml file to load from.
If None, it will default to
- benchmarl/conf/experiment/base_experiment.yaml
+ ``benchmarl/conf/experiment/base_experiment.yaml``
- Returns: the loaded ExperimentConfig
+ Returns:
+ the loaded :class:`~benchmarl.experiment.ExperimentConfig`
"""
if path is None:
yaml_path = (
@@ -244,9 +245,9 @@ def get_from_yaml(path: Optional[str] = None):
/ "experiment"
/ "base_experiment.yaml"
)
- return ExperimentConfig(**read_yaml_config(str(yaml_path.resolve())))
+ return ExperimentConfig(**_read_yaml_config(str(yaml_path.resolve())))
else:
- return ExperimentConfig(**read_yaml_config(path))
+ return ExperimentConfig(**_read_yaml_config(path))
def validate(self, on_policy: bool):
"""
@@ -282,16 +283,15 @@ class Experiment(CallbackNotifier):
"""
Main experiment class in BenchMARL.
-
Args:
task (Task): the task configuration
algorithm_config (AlgorithmConfig): the algorithm configuration
model_config (ModelConfig): the policy model configuration
seed (int): the seed for the experiment
- config (ExperimentConfig):
+ config (ExperimentConfig): the experiment config
critic_model_config (ModelConfig, optional): the policy model configuration.
If None, it defaults to model_config
- callbacks (list of Callback, optional): list of benchmarl.experiment.callbacks.Callback for this experiment
+ callbacks (list of Callback, optional): callbacks for this experiment
"""
def __init__(
@@ -330,7 +330,7 @@ def __init__(
@property
def on_policy(self) -> bool:
- """Weather the algorithm has to be run on policy"""
+ """Whether the algorithm has to be run on policy."""
return self.algorithm_config.on_policy()
def _setup(self):
@@ -538,7 +538,7 @@ def _collection_loop(self):
pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
# Callback
- self.on_batch_collected(batch)
+ self._on_batch_collected(batch)
# Loop over groups
training_start = time.time()
@@ -561,7 +561,7 @@ def _collection_loop(self):
)
# Callback
- self.on_train_end(training_td, group)
+ self._on_train_end(training_td, group)
# Exploration update
if isinstance(self.group_policies[group], TensorDictSequential):
@@ -651,7 +651,7 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
if self.target_updaters[group] is not None:
self.target_updaters[group].step()
- callback_loss = self.on_train_step(subdata, group)
+ callback_loss = self._on_train_step(subdata, group)
if callback_loss is not None:
training_td.update(callback_loss)
@@ -720,11 +720,11 @@ def callback(env, td):
total_frames=self.total_frames,
)
# Callback
- self.on_evaluation_end(rollouts)
+ self._on_evaluation_end(rollouts)
# Saving experiment state
def state_dict(self) -> OrderedDict:
- """Get the state_dict for the experiment"""
+ """Get the state_dict for the experiment."""
state = OrderedDict(
total_time=self.total_time,
total_frames=self.total_frames,
@@ -743,7 +743,12 @@ def state_dict(self) -> OrderedDict:
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
- """Load the state_dict for the experiment"""
+ """Load the state_dict for the experiment.
+
+ Args:
+ state_dict (dict): the state dict
+
+ """
for group in self.group_map.keys():
self.losses[group].load_state_dict(state_dict[f"loss_{group}"])
self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"])
diff --git a/benchmarl/hydra_config.py b/benchmarl/hydra_config.py
index e1465e2a..3f9cb100 100644
--- a/benchmarl/hydra_config.py
+++ b/benchmarl/hydra_config.py
@@ -3,8 +3,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
-
-from omegaconf import DictConfig, OmegaConf
+import importlib
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, task_config_registry
@@ -12,8 +11,23 @@
from benchmarl.models import model_config_registry
from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig
+_has_hydra = importlib.util.find_spec("hydra") is not None
+
+if _has_hydra:
+ from omegaconf import DictConfig, OmegaConf
+
def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
+ """Creates an :class:`~benchmarl.experiment.Experiment` from hydra config.
+
+ Args:
+ cfg (DictConfig): the config dictionary from hydra main
+ task_name (str): the name of the task to load
+
+ Returns:
+ :class:`~benchmarl.experiment.Experiment`
+
+ """
algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
experiment_config = load_experiment_config_from_hydra(cfg.experiment)
task_config = load_task_config_from_hydra(cfg.task, task_name)
@@ -31,20 +45,57 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> Task:
+ """Returns a :class:`~benchmarl.environments.Task` from hydra config.
+
+ Args:
+ cfg (DictConfig): the task config dictionary from hydra
+ task_name (str): the name of the task to load
+
+ Returns:
+ :class:`~benchmarl.environments.Task`
+
+ """
return task_config_registry[task_name].update_config(
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
)
def load_experiment_config_from_hydra(cfg: DictConfig) -> ExperimentConfig:
+ """Returns a :class:`~benchmarl.experiment.ExperimentConfig` from hydra config.
+
+ Args:
+ cfg (DictConfig): the experiment config dictionary from hydra
+
+ Returns:
+ :class:`~benchmarl.experiment.ExperimentConfig`
+
+ """
return OmegaConf.to_object(cfg)
def load_algorithm_config_from_hydra(cfg: DictConfig) -> AlgorithmConfig:
+ """Returns a :class:`~benchmarl.algorithms.AlgorithmConfig` from hydra config.
+
+ Args:
+ cfg (DictConfig): the algorithm config dictionary from hydra
+
+ Returns:
+ :class:`~benchmarl.algorithms.AlgorithmConfig`
+
+ """
return OmegaConf.to_object(cfg)
def load_model_config_from_hydra(cfg: DictConfig) -> ModelConfig:
+ """Returns a :class:`~benchmarl.models.ModelConfig` from hydra config.
+
+ Args:
+ cfg (DictConfig): the model config dictionary from hydra
+
+ Returns:
+ :class:`~benchmarl.models.ModelConfig`
+
+ """
if "layers" in cfg.keys():
model_configs = [
load_model_config_from_hydra(cfg.layers[f"l{i}"])
diff --git a/benchmarl/models/__init__.py b/benchmarl/models/__init__.py
index b437626a..fa71cc1b 100644
--- a/benchmarl/models/__init__.py
+++ b/benchmarl/models/__init__.py
@@ -4,6 +4,12 @@
# LICENSE file in the root directory of this source tree.
#
-from .mlp import MlpConfig
+from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
+from .mlp import Mlp, MlpConfig
+
+classes = [
+ "Mlp",
+ "MlpConfig",
+]
model_config_registry = {"mlp": MlpConfig}
diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py
index 7110005d..a0ec5b43 100644
--- a/benchmarl/models/common.py
+++ b/benchmarl/models/common.py
@@ -15,7 +15,7 @@
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase
-from benchmarl.utils import class_from_name, DEVICE_TYPING, read_yaml_config
+from benchmarl.utils import _class_from_name, _read_yaml_config, DEVICE_TYPING
def _check_spec(tensordict, spec):
@@ -28,7 +28,7 @@ def parse_model_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
kwargs = {}
for key, value in cfg.items():
if key.endswith("class") and value is not None:
- value = class_from_name(cfg[key])
+ value = _class_from_name(cfg[key])
kwargs.update({key: value})
return kwargs
@@ -60,12 +60,12 @@ class Model(TensorDictModuleBase, ABC):
output_spec (CompositeSpec): the output spec of the model
agent_group (str): the name of the agent group the model is for
n_agents (int): the number of agents this module is for
- device (str): the mdoel's device
+ device (str): the model's device
input_has_agent_dim (bool): This tells the model if the input will have a multi-agent dimension or not.
For example, the input of policies will always have this set to true,
but critics that use a global state have this set to false as the state is shared by all agents
centralised (bool): This tells the model if it has full observability.
- This will always be true when self.input_has_agent_dim==False,
+ This will always be true when ``self.input_has_agent_dim==False``,
but in cases where the input has the agent dimension, this parameter is
used to distinguish between a decentralised model (where each agent's data
is processed separately) and a centralized model, where the model pools all data together
@@ -114,8 +114,8 @@ def __init__(
def output_has_agent_dim(self) -> bool:
"""
This is a dynamically computed attribute that indicates if the output will have the agent dimension.
- This will be false when share_params==True and centralised==True, and true in all other cases.
- When output_has_agent_dim is true, your model's output should contain the multiagent dimension,
+ This will be false when ``share_params==True and centralised==True``, and true in all other cases.
+ When output_has_agent_dim is true, your model's output should contain the multi-agent dimension,
and the dimension should be absent otherwise
"""
return output_has_agent_dim(self.share_params, self.centralised)
@@ -170,6 +170,12 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
class SequenceModel(Model):
+ """A sequence of :class:`~benchmarl.models.Model`
+
+ Args:
+ models (list of Model): the models in the sequence
+ """
+
def __init__(
self,
models: List[Model],
@@ -194,11 +200,13 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@dataclass
class ModelConfig(ABC):
"""
- Dataclass representing an model configuration.
+ Dataclass representing a :class:`~benchmarl.models.Model` configuration.
This should be overridden by implemented models.
Implementors should:
- 1. add configuration parameters for their algorithm
- 2. implement all abstract methods
+
+ 1. add configuration parameters for their algorithm
+ 2. implement all abstract methods
+
"""
def get_model(
@@ -280,7 +288,7 @@ def _load_from_yaml(name: str) -> Dict[str, Any]:
/ "layers"
/ f"{name.lower()}.yaml"
)
- cfg = read_yaml_config(str(yaml_path.resolve()))
+ cfg = _read_yaml_config(str(yaml_path.resolve()))
return parse_model_config(cfg)
@classmethod
@@ -302,11 +310,13 @@ def get_from_yaml(cls, path: Optional[str] = None):
)
)
else:
- return cls(**parse_model_config(read_yaml_config(path)))
+ return cls(**parse_model_config(_read_yaml_config(path)))
@dataclass
class SequenceModelConfig(ModelConfig):
+ """Dataclass for a :class:`~benchmarl.models.SequenceModel`."""
+
model_configs: Sequence[ModelConfig]
intermediate_sizes: Sequence[int]
diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py
index 76e25eef..dfd49131 100644
--- a/benchmarl/models/mlp.py
+++ b/benchmarl/models/mlp.py
@@ -18,6 +18,20 @@
class Mlp(Model):
+ """Multi layer perceptron model.
+
+ Args:
+ num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
+ an integer is provided, every layer will have the same number of cells. If an iterable is provided,
+ the linear layers out_features will match the content of num_cells.
+ layer_class (Type[nn.Module]): class to be used for the linear layers;
+ activation_class (Type[nn.Module]): activation class to be used.
+ activation_kwargs (dict, optional): kwargs to be used with the activation class;
+ norm_class (Type, optional): normalization class, if any.
+ norm_kwargs (dict, optional): kwargs to be used with the normalization layers;
+
+ """
+
def __init__(
self,
**kwargs,
@@ -106,6 +120,8 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@dataclass
class MlpConfig(ModelConfig):
+ """Dataclass config for a :class:`~benchmarl.models.Mlp`."""
+
num_cells: Sequence[int] = MISSING
layer_class: Type[nn.Module] = MISSING
diff --git a/benchmarl/run.py b/benchmarl/run.py
index a2cfb299..459b8f6f 100644
--- a/benchmarl/run.py
+++ b/benchmarl/run.py
@@ -13,6 +13,19 @@
@hydra.main(version_base=None, config_path="conf", config_name="config")
def hydra_experiment(cfg: DictConfig) -> None:
+ """Runs an experiment loading its config from hydra.
+
+ This function is decorated as ``@hydra.main`` and is called by running
+
+ .. code-block:: console
+
+ python benchmarl/run.py algorithm=mappo task=vmas/balance
+
+
+ Args:
+ cfg (DictConfig): the hydra config dictionary
+
+ """
hydra_choices = HydraConfig.get().runtime.choices
task_name = hydra_choices.task
algorithm_name = hydra_choices.algorithm
diff --git a/benchmarl/utils.py b/benchmarl/utils.py
index 3e3fcacc..b1f80141 100644
--- a/benchmarl/utils.py
+++ b/benchmarl/utils.py
@@ -13,7 +13,7 @@
DEVICE_TYPING = Union[torch.device, str, int]
-def read_yaml_config(config_file: str) -> Dict[str, Any]:
+def _read_yaml_config(config_file: str) -> Dict[str, Any]:
with open(config_file) as config:
yaml_string = config.read()
config_dict = yaml.safe_load(yaml_string)
@@ -22,7 +22,7 @@ def read_yaml_config(config_file: str) -> Dict[str, Any]:
return config_dict
-def class_from_name(name: str):
+def _class_from_name(name: str):
name_split = name.split(".")
module_name = ".".join(name_split[:-1])
class_name = name_split[-1]
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..d0c3cbf1
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000..9534b018
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 00000000..897e4720
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,6 @@
+git+https://github.com/matteobettini/benchmarl_sphinx_theme.git
+torchrl>=0.2.0
+torch
+tqdm
+hydra-core
+vmas>=1.2.10
diff --git a/docs/source/_templates/autosummary/class.rst b/docs/source/_templates/autosummary/class.rst
new file mode 100644
index 00000000..139e7c16
--- /dev/null
+++ b/docs/source/_templates/autosummary/class.rst
@@ -0,0 +1,9 @@
+{{ fullname | escape | underline }}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+ :show-inheritance:
+ :members:
+ :undoc-members:
+ :inherited-members:
diff --git a/docs/source/_templates/autosummary/class_no_inherit.rst b/docs/source/_templates/autosummary/class_no_inherit.rst
new file mode 100644
index 00000000..08b5ed83
--- /dev/null
+++ b/docs/source/_templates/autosummary/class_no_inherit.rst
@@ -0,0 +1,8 @@
+{{ fullname | escape | underline }}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+ :show-inheritance:
+ :members:
+ :undoc-members:
diff --git a/docs/source/_templates/autosummary/class_private.rst b/docs/source/_templates/autosummary/class_private.rst
new file mode 100644
index 00000000..e9f2f9de
--- /dev/null
+++ b/docs/source/_templates/autosummary/class_private.rst
@@ -0,0 +1,9 @@
+{{ fullname | escape | underline }}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+ :show-inheritance:
+ :members:
+ :undoc-members:
+ :private-members:
diff --git a/docs/source/_templates/autosummary/class_private_no_undoc.rst b/docs/source/_templates/autosummary/class_private_no_undoc.rst
new file mode 100644
index 00000000..191ccbcf
--- /dev/null
+++ b/docs/source/_templates/autosummary/class_private_no_undoc.rst
@@ -0,0 +1,8 @@
+{{ fullname | escape | underline }}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+ :show-inheritance:
+ :members:
+ :private-members:
diff --git a/docs/source/_templates/breadcrumbs.html b/docs/source/_templates/breadcrumbs.html
new file mode 100644
index 00000000..4ecb013f
--- /dev/null
+++ b/docs/source/_templates/breadcrumbs.html
@@ -0,0 +1,4 @@
+{%- extends "sphinx_rtd_theme/breadcrumbs.html" %}
+
+{% block breadcrumbs_aside %}
+{% endblock %}
diff --git a/docs/source/concepts/benchmarks.rst b/docs/source/concepts/benchmarks.rst
new file mode 100644
index 00000000..d7eca9e0
--- /dev/null
+++ b/docs/source/concepts/benchmarks.rst
@@ -0,0 +1,25 @@
+Public benchmarks
+=================
+
+.. warning::
+ This section is under a work in progress. We are constantly working on fine-tuning
+ our experiments to enable our users to have access to state-of-the-art benchmarks.
+ If you would like to collaborate in this effort, please reach out to us.
+
+In the `fine_tuned `__
+folder we are collecting some tested hyperparameters for
+specific environments to enable users to bootstrap their benchmarking.
+You can just run the scripts in this folder to automatically use the proposed hyperparameters.
+
+We will tune benchmarks for you and publish the config and benchmarking plots on
+:wandb:`null` `Wandb `__ publicly.
+
+Currently available ones are:
+
+VMAS
+----
+`Conf `__ | :wandb:`null` `Wandb `__
+
+.. raw:: html
+
+