[Algorithm] QMixer loss and multiagent models (#1378)
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini authored Jul 14, 2023
1 parent 9c95e1d commit 574dbf1
Showing 15 changed files with 1,691 additions and 40 deletions.
14 changes: 14 additions & 0 deletions docs/source/reference/modules.rst
@@ -335,6 +335,20 @@ algorithms, such as DQN, DDPG or Dreamer.
    RSSMPrior
    RSSMPosterior

Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~

These networks implement models that can be used in
multi-agent contexts.

.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    MultiAgentMLP
    QMixer
    VDNMixer


Exploration
-----------
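For readers unfamiliar with the mixers listed above, the following is a minimal sketch of the two mixing ideas in plain Python. It does not use or reproduce the `VDNMixer` or `QMixer` classes added by this commit; the function names `vdn_mix` and `monotonic_mix` are hypothetical. It assumes the usual descriptions of the algorithms: VDN sums per-agent values, while QMIX combines them with non-negative weights so the joint value never decreases when an agent's value increases. QMIX as usually described also produces those weights from the global state with hypernetworks and stacks more than one layer, which this sketch deliberately omits.

```python
from typing import Sequence


def vdn_mix(agent_qs: Sequence[float]) -> float:
    """VDN-style mixing: the joint value is the plain sum of per-agent values."""
    return sum(agent_qs)


def monotonic_mix(
    agent_qs: Sequence[float],
    raw_weights: Sequence[float],
    bias: float = 0.0,
) -> float:
    """Simplified QMIX-style mixing with a single linear layer.

    Taking abs() of each raw weight keeps every weight non-negative,
    which makes the output monotonic in each agent's value.
    The fixed weights here stand in for state-conditioned ones (assumption).
    """
    if len(agent_qs) != len(raw_weights):
        raise ValueError("need exactly one weight per agent")
    return sum(abs(w) * q for w, q in zip(raw_weights, agent_qs)) + bias


# Three agents, each with the value of its chosen action.
agent_qs = [1.0, 2.0, 3.0]
q_tot_vdn = vdn_mix(agent_qs)
q_tot_mono = monotonic_mix(agent_qs, raw_weights=[-0.5, 2.0, 1.0], bias=0.25)
print(q_tot_vdn, q_tot_mono)
```

The non-negativity constraint is the design point: it lets each agent pick its action greedily from its own values while still maximising the joint value.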
15 changes: 15 additions & 0 deletions docs/source/reference/objectives.rst
@@ -185,6 +185,21 @@ Dreamer
    DreamerModelLoss
    DreamerValueLoss

Multi-agent objectives
----------------------
.. currentmodule:: torchrl.objectives.multiagent

These objectives are specific to multi-agent algorithms.

QMixer
~~~~~~

.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    QMixerLoss


Returns
-------
2 changes: 1 addition & 1 deletion setup.py
@@ -235,6 +235,7 @@ def _main(argv):
            "checkpointing": [
                "torchsnapshot",
            ],
            "marl": ["vmas"],
        },
        zip_safe=False,
        classifiers=[
@@ -254,5 +255,4 @@ def _main(argv):


if __name__ == "__main__":

    _main(sys.argv[1:])

1 comment on commit 574dbf1

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

| Benchmark suite | Current: 574dbf1 | Previous: 9c95e1d | Ratio |
|-|-|-|-|
| `benchmarks/test_objectives_benchmarks.py::test_reinforce_speed` | `103.36318234699311` iter/sec (`stddev: 0.0011253947426459006`) | `215.8205136415671` iter/sec (`stddev: 0.00021539867449864384`) | `2.09` |
| `benchmarks/test_objectives_benchmarks.py::test_iql_speed` | `20.80897970496085` iter/sec (`stddev: 0.0033384474096849206`) | `42.24455036474581` iter/sec (`stddev: 0.0013206046382742407`) | `2.03` |

This comment was automatically generated by workflow using github-action-benchmark.
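
The reported ratios can be checked by hand. The sketch below assumes that, for a throughput metric in iter/sec, the ratio is the previous value divided by the current one, so a ratio above the threshold of 2 means the current commit ran less than half as fast. That reading reproduces the figures in the alert; the helper name `slowdown_ratio` is hypothetical and is not part of github-action-benchmark.

```python
THRESHOLD = 2.0  # alert threshold quoted in the comment

# (current iter/sec at 574dbf1, previous iter/sec at 9c95e1d), copied from the alert
results = {
    "test_reinforce_speed": (103.36318234699311, 215.8205136415671),
    "test_iql_speed": (20.80897970496085, 42.24455036474581),
}


def slowdown_ratio(current: float, previous: float) -> float:
    """How many times slower the current run is (higher iter/sec is better)."""
    return previous / current


ratios = {name: slowdown_ratio(cur, prev) for name, (cur, prev) in results.items()}

for name, ratio in ratios.items():
    status = "ALERT" if ratio > THRESHOLD else "ok"
    print(f"{name}: {ratio:.2f}x slower [{status}]")
```

Both benchmarks land just above the threshold, and their standard deviations are noticeably larger in the current run, so re-running the benchmark before drawing conclusions seems prudent.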

CC: @vmoens
