Skip to content

Commit

Permalink
[Feature] More flexibility in loading PettingZoo (#1817)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Jan 19, 2024
1 parent e98ee38 commit a10cdbf
Showing 1 changed file with 57 additions and 4 deletions.
61 changes: 57 additions & 4 deletions torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import copy
import importlib
import warnings
from typing import Dict, List, Tuple, Union

import torch
Expand All @@ -27,11 +28,54 @@
def _get_envs():
if not _has_pettingzoo:
raise ImportError("PettingZoo is not installed in your virtual environment.")
from pettingzoo.utils.all_modules import all_environments
try:
from pettingzoo.utils.all_modules import all_environments
except ModuleNotFoundError as err:
warnings.warn(
f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
)
all_environments = _load_available_envs()

return list(all_environments.keys())


def _load_available_envs() -> Dict:
all_environments = {}
try:
from pettingzoo.mpe.all_modules import mpe_environments

all_environments.update(mpe_environments)
except ModuleNotFoundError as err:
warnings.warn(f"MPE environments failed to load with error message {err}.")
try:
from pettingzoo.sisl.all_modules import sisl_environments

all_environments.update(sisl_environments)
except ModuleNotFoundError as err:
warnings.warn(f"SISL environments failed to load with error message {err}.")
try:
from pettingzoo.classic.all_modules import classic_environments

all_environments.update(classic_environments)
except ModuleNotFoundError as err:
warnings.warn(f"Classic environments failed to load with error message {err}.")
try:
from pettingzoo.atari.all_modules import atari_environments

all_environments.update(atari_environments)
except ModuleNotFoundError as err:
warnings.warn(f"Atari environments failed to load with error message {err}.")
try:
from pettingzoo.butterfly.all_modules import butterfly_environments

all_environments.update(butterfly_environments)
except ModuleNotFoundError as err:
warnings.warn(
f"Butterfly environments failed to load with error message {err}."
)
return all_environments


class PettingZooWrapper(_EnvWrapper):
"""PettingZoo environment wrapper.
Expand Down Expand Up @@ -834,7 +878,8 @@ class PettingZooEnv(PettingZooWrapper):
neural network.
Args:
task (str): the name of the pettingzoo task to create (for example, "multiwalker_v9").
task (str): the name of the pettingzoo task to create in the "<env>/<task>" format (for example, "sisl/multiwalker_v9")
or "<task>" format (for example, "multiwalker_v9").
parallel (bool): if to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``.
return_state (bool, optional): whether to return the global state from pettingzoo
(not available in all environments). Defaults to ``False``.
Expand Down Expand Up @@ -919,7 +964,13 @@ def _build_env(
]:
self.task_name = task

from pettingzoo.utils.all_modules import all_environments
try:
from pettingzoo.utils.all_modules import all_environments
except ModuleNotFoundError as err:
warnings.warn(
f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
)
all_environments = _load_available_envs()

if task not in all_environments:
# Try looking at the literal translation of values
Expand All @@ -929,7 +980,9 @@ def _build_env(
task_module = value
break
if task_module is None:
raise RuntimeError(f"Specified task not in {_get_envs()}")
raise RuntimeError(
f"Specified task not in available environments {all_environments}"
)
else:
task_module = all_environments[task]

Expand Down

0 comments on commit a10cdbf

Please sign in to comment.