Skip to content

Commit

Permalink
[Features] PettingZoo possibility to choose reset strategy (pytorch#2048
Browse files Browse the repository at this point in the history
)

Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
matteobettini and vmoens authored Apr 8, 2024
1 parent 90542ef commit 79e2b07
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- cloudpickle
- gym
- gym-notices
- importlib-metadata
- six
- zipp
- pytest
Expand All @@ -20,4 +19,4 @@ dependencies:
- expecttest
- pyyaml
- autorom[accept-rom-license]
- pettingzoo[all]==1.24.1
- pettingzoo[all]==1.24.3
1 change: 0 additions & 1 deletion .github/unittest/linux_libs/scripts_smacv2/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ dependencies:
- cloudpickle
- gym
- gym-notices
- importlib-metadata
- zipp
- pytest
- pytest-cov
Expand Down
1 change: 0 additions & 1 deletion .github/unittest/linux_libs/scripts_vmas/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ dependencies:
- cloudpickle
- gym
- gym-notices
- importlib-metadata
- numpy
- pyglet==1.5.27
- six
Expand Down
58 changes: 54 additions & 4 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ def _make_spec( # noqa: F811

@pytest.mark.parametrize("categorical", [True, False])
def test_gym_spec_cast(self, categorical):

batch_size = [3, 4]
cat = DiscreteTensorSpec if categorical else OneHotDiscreteTensorSpec
cat_shape = batch_size if categorical else (*batch_size, 5)
Expand Down Expand Up @@ -543,7 +542,6 @@ def test_torchrl_to_gym(self, backend, numpy):
],
)
def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):

if env_name == PONG_VERSIONED() and not from_pixels:
# raise pytest.skip("already pixel")
# we don't skip because that would raise an exception
Expand Down Expand Up @@ -3126,7 +3124,6 @@ class TestPettingZoo:
def test_pistonball(
self, parallel, continuous_actions, use_mask, return_state, group_map
):

kwargs = {"n_pistons": 21, "continuous": continuous_actions}

env = PettingZooEnv(
Expand All @@ -3141,6 +3138,60 @@ def test_pistonball(

check_env_specs(env)

def test_dead_agents_done(self, seed=0):
scenario_args = {"n_walkers": 3, "terminate_on_fall": False}

env = PettingZooEnv(
task="multiwalker_v9",
parallel=True,
seed=seed,
use_mask=False,
done_on_any=False,
**scenario_args,
)
td_reset = env.reset(seed=seed)
with pytest.raises(
ValueError,
match="Dead agents found in the environment, "
"you need to set use_mask=True to allow this.",
):
env.rollout(
max_steps=500,
break_when_any_done=True, # This looks at root done set with done_on_any
auto_reset=False,
tensordict=td_reset,
)

for done_on_any in [True, False]:
env = PettingZooEnv(
task="multiwalker_v9",
parallel=True,
seed=seed,
use_mask=True,
done_on_any=done_on_any,
**scenario_args,
)
td_reset = env.reset(seed=seed)
td = env.rollout(
max_steps=500,
break_when_any_done=True, # This looks at root done set with done_on_any
auto_reset=False,
tensordict=td_reset,
)
done = td.get(("next", "walker", "done"))
mask = td.get(("next", "walker", "mask"))

if done_on_any:
assert not done[-1].all() # Done triggered on any
else:
assert done[-1].all() # Done triggered on all
assert not done[
mask
].any() # When mask is true (alive agent), all agents are not done
assert done[
~mask
].all() # When mask is false (dead agent), all agents are done

@pytest.mark.parametrize(
"wins_player_0",
[True, False],
Expand All @@ -3156,7 +3207,6 @@ def test_tic_tac_toe(self, wins_player_0):
)

class Policy:

action = 0
t = 0

Expand Down
57 changes: 46 additions & 11 deletions torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from typing import Dict, List, Tuple, Union

import packaging
import torch
from tensordict import TensorDictBase

Expand Down Expand Up @@ -90,12 +91,12 @@ class PettingZooWrapper(_EnvWrapper):
If the number of agents during the task varies, please set ``use_mask=True``.
``"mask"`` will be provided
as an output in each group and should be used to mask out dead agents.
The environment will be reset as soon as one agent is done.
The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).
In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
The environment will be reset only when all agents are done.
The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).
If there are any unavailable actions for an agent,
the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
Expand Down Expand Up @@ -156,6 +157,9 @@ class PettingZooWrapper(_EnvWrapper):
categorical_actions (bool, optional): if the enviornments actions are discrete, whether to transform
them to categorical or one-hot.
seed (int, optional): the seed. Defaults to ``None``.
done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
parallel environments and ``all()`` for AEC ones.
Examples:
>>> # Parallel env
Expand Down Expand Up @@ -204,6 +208,7 @@ def __init__(
use_mask: bool = False,
categorical_actions: bool = True,
seed: int | None = None,
done_on_any: bool | None = None,
**kwargs,
):
if env is not None:
Expand All @@ -214,6 +219,7 @@ def __init__(
self.seed = seed
self.use_mask = use_mask
self.categorical_actions = categorical_actions
self.done_on_any = done_on_any

super().__init__(**kwargs, allow_done_after_reset=True)

Expand Down Expand Up @@ -265,6 +271,13 @@ def _build_env(
):
import pettingzoo

if packaging.version.parse(pettingzoo.__version__).base_version != "1.24.3":
warnings.warn(
"PettingZoo in TorchRL is tested using version == 1.24.3 , "
"If you are using a different version and are experiencing compatibility issues,"
"please raise an issue in the TorchRL github."
)

self.parallel = isinstance(env, pettingzoo.utils.env.ParallelEnv)
if not self.parallel and not self.use_mask:
raise ValueError("For AEC environments you need to set use_mask=True")
Expand All @@ -283,6 +296,9 @@ def _make_specs(
"pettingzoo.utils.env.AECEnv", # noqa: F821
],
) -> None:
# Set default for done on any or all
if self.done_on_any is None:
self.done_on_any = self.parallel

# Create and check group map
if self.group_map is None:
Expand Down Expand Up @@ -582,7 +598,6 @@ def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:

if self.parallel:
(
observation_dict,
Expand Down Expand Up @@ -651,16 +666,33 @@ def _step(
value, device=self.device
)

elif not self.use_action_mask:
elif self.use_mask:
if agent in self.agents:
raise ValueError(
f"Dead agent {agent} not found in step observation but still available in {self.agents}"
)
# Dead agent
terminated = (
terminations_dict[agent] if agent in terminations_dict else True
)
truncated = (
truncations_dict[agent] if agent in truncations_dict else True
)
done = terminated or truncated
group_done[index] = done
group_terminated[index] = terminated
group_truncated[index] = truncated

else:
# Dead agent, if we are not masking it out, this is not allowed
raise ValueError(
"Dead agents found in the environment,"
" you need to set use_action_mask=True to allow this."
" you need to set use_mask=True to allow this."
)

# set done values
done, terminated, truncated = self._aggregate_done(
tensordict_out, use_any=self.parallel
tensordict_out, use_any=self.done_on_any
)

tensordict_out.set("done", done)
Expand All @@ -673,7 +705,7 @@ def _aggregate_done(self, tensordict_out, use_any):
truncated = False if use_any else True
terminated = False if use_any else True
for key in self.done_keys:
if isinstance(key, tuple):
if isinstance(key, tuple): # Only look at group keys
if use_any:
if key[-1] == "done":
done = done | tensordict_out.get(key).any()
Expand Down Expand Up @@ -719,7 +751,6 @@ def _step_aec(
self,
tensordict: TensorDictBase,
) -> Tuple[Dict, Dict, Dict, Dict, Dict]:

for group, agents in self.group_map.items():
if self.agent_selection in agents:
agent_index = agents.index(self._env.agent_selection)
Expand Down Expand Up @@ -747,7 +778,6 @@ def _step_aec(
)

def _update_action_mask(self, td, observation_dict, info_dict):

# Since we remove the action_mask keys we need to copy the data
observation_dict = copy.deepcopy(observation_dict)
info_dict = copy.deepcopy(info_dict)
Expand Down Expand Up @@ -821,15 +851,15 @@ class PettingZooEnv(PettingZooWrapper):
If the number of agents during the task varies, please set ``use_mask=True``.
``"mask"`` will be provided
as an output in each group and should be used to mask out dead agents.
The environment will be reset as soon as one agent is done.
The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).
For wrapping ``pettingzoo.AECEnv`` provide the name of your petting zoo task (in the ``task`` argument)
and specify ``parallel=False``. This will construct the ``pettingzoo.AECEnv`` version of that task
and wrap it for torchrl.
In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
The environment will be reset only when all agents are done.
The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).
If there are any unavailable actions for an agent,
the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
Expand Down Expand Up @@ -892,6 +922,9 @@ class PettingZooEnv(PettingZooWrapper):
categorical_actions (bool, optional): if the enviornments actions are discrete, whether to transform
them to categorical or one-hot.
seed (int, optional): the seed. Defaults to ``None``.
done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
parallel environments and ``all()`` for AEC ones.
Examples:
>>> # Parallel env
Expand Down Expand Up @@ -930,6 +963,7 @@ def __init__(
use_mask: bool = False,
categorical_actions: bool = True,
seed: int | None = None,
done_on_any: bool | None = None,
**kwargs,
):
if not _has_pettingzoo:
Expand All @@ -944,6 +978,7 @@ def __init__(
kwargs["use_mask"] = use_mask
kwargs["categorical_actions"] = categorical_actions
kwargs["seed"] = seed
kwargs["done_on_any"] = done_on_any

super().__init__(**kwargs)

Expand Down

0 comments on commit 79e2b07

Please sign in to comment.