
Mava logo

Distributed Multi-Agent Reinforcement Learning in JAX


👋 UPDATE - 25/8/2023: We have shifted the focus of Mava away from a framework and towards a lightweight, easy-to-use codebase for MARL. Mava is also now end-to-end JAX-based and henceforth we will only be supporting JAX-based environments. We currently provide native support for the Jumanji environment API. Mava now follows a similar design philosophy to CleanRL and PureJaxRL, where we allow for code duplication to enable readability and easy reuse. All algorithmic logic can be found in the file implementing a particular algorithm. If you would still like to use our deprecated TF2-based framework and systems, please install v0.1.3 of Mava (e.g. pip install id-mava==0.1.3).


Welcome to Mava! 🦁

Mava provides simplified code for quickly iterating on ideas in multi-agent reinforcement learning (MARL), with useful implementations of MARL algorithms in JAX that allow for easy parallelisation across devices using JAX's pmap. Mava is a project originating in the Research Team at InstaDeep.
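As a minimal sketch of the kind of device parallelism this refers to (the update function, learning rate and axis name below are hypothetical placeholders, not Mava's actual training code), JAX's pmap can replicate a training step across all available accelerators and average gradients between them:

import jax
import jax.numpy as jnp

# Hypothetical per-device update: each device holds a replica of the
# parameters and receives its own shard of the batch.
def update_step(params, batch):
    grads = jnp.mean(batch) * params                    # stand-in for real gradients
    grads = jax.lax.pmean(grads, axis_name="devices")   # average gradients across devices
    return params - 1e-3 * grads

p_update = jax.pmap(update_step, axis_name="devices")

n_devices = jax.local_device_count()
params = jnp.ones((n_devices,))        # one parameter replica per device
batch = jnp.ones((n_devices, 16))      # one batch shard per device
params = p_update(params, batch)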

To join us in these efforts, please feel free to reach out, raise issues or read our contribution guidelines (or just star 🌟 to stay up to date with the latest developments)!

Performance and Speed 🚀

All of the experiments below were performed using an NVIDIA Quadro RTX 4000 GPU with 8GB of memory.

In order to show the utility of end-to-end JAX-based MARL systems and JAX-based environments, we compare the speed of Mava against EPyMARL, as measured in total training wallclock time on simple Robotic Warehouse (RWARE) tasks with 2 and 4 agents. Our aim is to illustrate the speed increases that are possible by using end-to-end JAX-based systems; we do not necessarily make an effort to achieve optimal performance. For EPyMARL, we use the hyperparameters recommended by Papoudakis et al. (2020) and for Mava we performed a basic grid search. In both cases, systems were trained up to 20 million total environment steps using 16 vectorised environments.

Mava MAPPO performance on the `tiny-2ag`, `tiny-4ag` and `small-4ag` RWARE tasks.

📌 An important note on the differences in converged performance

In order to benefit from the wallclock speed-ups afforded by JAX-based systems, the environments themselves must also be written in JAX. For this reason, Mava does not use the exact same version of the RWARE environment as EPyMARL, but instead uses the JAX-based implementation of RWARE found in Jumanji, under the name RobotWarehouse. One notable difference in the underlying environment logic is that RobotWarehouse does not attempt to resolve agent collisions, but instead terminates an episode upon collision. In our experiments, this appeared to make the environment more challenging. For a more detailed discussion, please see the following page.

🧨 Steps per second experiments using vectorised environments

Furthermore, we illustrate the speed of Mava by showing the steps per second as the number of parallel environments is increased. These steps-per-second scaling plots were computed using a standard laptop GPU, specifically an RTX 3060 GPU with 6GB of memory.

Mava environment steps per second scaling with increased vectorised environments.

Here, we also note the system performance on a larger set of Robotic Warehouse environments. In all cases, systems were trained up to 200 million environment timesteps with 512 vectorised environments, and we give the average experiment wall-clock time on the x-axis.
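To make concrete what environment vectorisation looks like in JAX (the toy step function below is a placeholder, not the Jumanji RobotWarehouse API), a single jitted vmap call can step hundreds of environment instances in parallel:

import jax
import jax.numpy as jnp

# Toy stateless environment step; a placeholder, not Jumanji's actual API.
def env_step(state, action):
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward

# Vectorise the step over a batch of environments and JIT-compile it.
batched_step = jax.jit(jax.vmap(env_step))

num_envs = 512
states = jnp.zeros((num_envs,))
actions = jnp.ones((num_envs,))
states, rewards = batched_step(states, actions)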

Code Philosophy

The current code in Mava is adapted from PureJaxRL, which provides high-quality single-file implementations with research-friendly features. In turn, PureJaxRL is inspired by the code philosophy of CleanRL. In this vein of easy-to-use and understandable RL codebases, Mava is not designed to be a modular library and is not meant to be imported. Our repository focuses on simplicity and clarity in its implementations while utilising the advantages offered by JAX, such as pmap and vmap, making it an excellent resource for researchers and practitioners to build upon.

Overview 🦜

Mava currently offers the following building blocks for MARL research:

  • 🥑 Implementations of MARL algorithms: Implementations of multi-agent PPO systems that follow both the Centralised Training with Decentralised Execution (CTDE) and Decentralised Training with Decentralised Execution (DTDE) MARL paradigms.
  • 🍬 Environment Wrappers: Example wrapper for mapping a Jumanji environment to an environment usable in Mava (a simplified sketch of this idea follows this list). At the moment, we only support Robotic Warehouse but plan to support more environments soon.
  • 🎓 Educational Material: Quickstart notebook to demonstrate how Mava can be used and to highlight the added value of JAX-based MARL.
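As a rough, hypothetical sketch of what such a wrapper involves (the class, method names and timestep fields below are illustrative only and do not match Mava's or Jumanji's exact interfaces), the core idea is to adapt a Jumanji-style reset/step interface and reshape shared quantities such as rewards into per-agent form:

from typing import Any, NamedTuple
import jax.numpy as jnp

# Minimal stand-in for the timestep structure a MARL system might consume;
# the real Jumanji and Mava types differ in their exact fields.
class TimeStep(NamedTuple):
    observation: Any
    reward: jnp.ndarray

class MultiAgentWrapper:
    """Illustrative wrapper exposing per-agent rewards from a Jumanji-style env."""

    def __init__(self, env, num_agents):
        self._env = env
        self.num_agents = num_agents

    def reset(self, key):
        state, timestep = self._env.reset(key)
        return state, self._to_multi_agent(timestep)

    def step(self, state, actions):
        state, timestep = self._env.step(state, actions)
        return state, self._to_multi_agent(timestep)

    def _to_multi_agent(self, timestep):
        # Repeat a shared team reward so that each agent receives its own entry.
        reward = jnp.repeat(jnp.asarray(timestep.reward), self.num_agents)
        return TimeStep(observation=timestep.observation, reward=reward)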

Installation 🎬

At the moment, Mava is not meant to be installed as a library, but rather to be used as a research tool.

You can use Mava by cloning the repo and pip installing as follows:

git clone https://github.com/instadeepai/mava.git
cd mava
pip install -e .

We have tested Mava on Python 3.9. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the official installation guide). For more in-depth installation guides including Docker builds and virtual environments, please see our detailed installation guide.

Quickstart ⚡

We have a Quickstart notebook that can be used to quickly create and train your first multi-agent system.

Contributing 🤝

Please read our contributing docs for details on how to submit pull requests, our Contributor License Agreement and community guidelines.

Roadmap 🛤️

We plan to iteratively expand Mava in the following increments:

  • 🌴 Support for more multi-agent Jumanji environments.
  • 📊 Benchmarks on more environments.
  • 🦾 Support for off-policy algorithms.

Please do follow along as we develop this next phase!

Citing Mava 📚

If you use Mava in your work, please cite the accompanying technical report (to be updated soon to reflect our transition to JAX):

@article{pretorius2021mava,
    title={Mava: A Research Framework for Distributed Multi-Agent Reinforcement Learning},
    author={Arnu Pretorius and Kale-ab Tessera and Andries P. Smit and Kevin Eloff
    and Claude Formanek and St John Grimbly and Siphelele Danisa and Lawrence Francis
    and Jonathan Shock and Herman Kamper and Willie Brink and Herman Engelbrecht
    and Alexandre Laterre and Karim Beguir},
    year={2021},
    journal={arXiv preprint arXiv:2107.01460},
    url={https://arxiv.org/pdf/2107.01460.pdf},
}

See Also 🔎

The current version of Mava is based on code from the following projects:

  • 🤖 PureJaxRL for simple code implementations of end-to-end RL training in JAX.
  • 🌀 DeepMind Anakin for the Anakin podracer architecture to train RL agents at scale.
  • 🌴 Jumanji for a diverse suite of scalable RL environments written in JAX, including multi-agent environments.