Skip to content

Commit

Permalink
[BugFix] Fix envpool (pytorch#1530)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 14, 2023
1 parent 274cdfc commit da50587
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 320 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ And it is `functorch` and `torch.compile` compatible!
- A common [interface for environments](torchrl/envs)
which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and state-less execution
(e.g. Model-based environments).
The [batched environments](torchrl/envs/vec_env.py) containers allow parallel execution<sup>(2)</sup>.
The [batched environments](torchrl/envs/batched_envs.py) containers allow parallel execution<sup>(2)</sup>.
A common PyTorch-first class of [tensor-specification class](torchrl/data/tensor_specs.py) is also provided.
TorchRL's environments API is simple but stringent and specific. Check the
[documentation](https://pytorch.org/rl/reference/envs.html)
Expand Down
5 changes: 3 additions & 2 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
from tensordict import tensorclass
from torchrl._utils import implement_for, seed_generator

from torchrl.envs import ObservationNorm
from torchrl.envs import MultiThreadedEnv, ObservationNorm
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.envpool import _has_envpool
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.transforms import (
Compose,
RewardClipping,
ToTensorImage,
TransformedEnv,
)
from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnv, ParallelEnv, SerialEnv

# Specified for test_utils.py
__version__ = "0.3"
Expand Down
3 changes: 2 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@
ParallelEnv,
RenameTransform,
)
from torchrl.envs.batched_envs import SerialEnv
from torchrl.envs.libs.brax import _has_brax, BraxEnv
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
from torchrl.envs.libs.envpool import _has_envpool, MultiThreadedEnvWrapper
from torchrl.envs.libs.gym import (
_has_gym,
_is_from_pixels,
Expand All @@ -66,7 +68,6 @@
from torchrl.envs.libs.robohive import RoboHiveEnv
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator

_has_d4rl = importlib.util.find_spec("d4rl") is not None
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.batched_envs import _BatchedEnv
from torchrl.envs.common import EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
Expand All @@ -47,7 +48,6 @@
set_exploration_type,
step_mdp,
)
from torchrl.envs.vec_env import _BatchedEnv

_TIMEOUT = 1.0
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
Expand Down
3 changes: 2 additions & 1 deletion torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .batched_envs import ParallelEnv, SerialEnv
from .common import EnvBase, EnvMetaData, make_tensordict
from .env_creator import EnvCreator, get_env_metadata
from .gym_like import default_info_dict_reader, GymLikeEnv
from .libs.envpool import MultiThreadedEnv
from .model_based import ModelBasedEnvBase
from .transforms import (
ActionMask,
Expand Down Expand Up @@ -66,4 +68,3 @@
set_exploration_type,
step_mdp,
)
from .vec_env import MultiThreadedEnv, ParallelEnv, SerialEnv
Loading

0 comments on commit da50587

Please sign in to comment.