Skip to content

Commit

Permalink
[Feature] Execute rollouts with regular nn.Module instances (pytorch#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 22, 2024
1 parent c3bda41 commit 40e9900
Show file tree
Hide file tree
Showing 22 changed files with 379 additions and 278 deletions.
2 changes: 1 addition & 1 deletion benchmarks/ecosystem/gym_env_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from torchrl.collectors import (
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
from torchrl.envs.utils import RandomPolicy

if __name__ == "__main__":
avail_devices = ("cpu",)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/test_collectors_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from torchrl.collectors.collectors import (
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
)
from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.utils import RandomPolicy


def single_collector_setup():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@
def main():
import gym
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.data import BoundedTensorSpec
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
device_str = "device" if num_workers == 1 else "devices"
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/delayed_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@
def main():
import gym
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.data import BoundedTensorSpec
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
device_str = "device" if num_workers == 1 else "devices"
Expand Down
7 changes: 2 additions & 5 deletions examples/distributed/collectors/multi_nodes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.collectors.collectors import (
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
from torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
7 changes: 2 additions & 5 deletions examples/distributed/collectors/multi_nodes/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.collectors.collectors import (
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
from torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.distributed import RPCDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
7 changes: 2 additions & 5 deletions examples/distributed/collectors/multi_nodes/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.collectors.collectors import (
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
from torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.distributed import DistributedSyncDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
from torchrl.collectors.collectors import (
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
3 changes: 2 additions & 1 deletion examples/distributed/collectors/single_machine/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.collectors import SyncDataCollector
from torchrl.collectors.distributed import RPCDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
7 changes: 2 additions & 5 deletions examples/distributed/collectors/single_machine/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.collectors.collectors import (
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
from torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.distributed import DistributedSyncDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
_Interruptor,
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
Expand All @@ -67,6 +66,7 @@
_aggregate_end_of_traj,
check_env_specs,
PARTIAL_MISSING_ERR,
RandomPolicy,
)
from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule

Expand Down
2 changes: 1 addition & 1 deletion test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from torchrl.collectors.collectors import (
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
from torchrl.collectors.distributed import (
Expand All @@ -43,6 +42,7 @@
RPCDataCollector,
)
from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG
from torchrl.envs.utils import RandomPolicy

TIMEOUT = 200

Expand Down
17 changes: 17 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,6 +2592,23 @@ def make_env(seed, device=device):
p_env.close()


@pytest.mark.skipif(not _has_gym, reason="Gym required for this test")
def test_non_td_policy():
env = GymEnv("CartPole-v1", categorical_action_encoding=True)

class ArgMaxModule(nn.Module):
def forward(self, values):
return values.argmax(-1)

policy = nn.Sequential(
nn.Linear(env.observation_spec["observation"].shape[-1], env.action_spec.n),
ArgMaxModule(),
)
env.rollout(10, policy)
env = SerialEnv(2, lambda: GymEnv("CartPole-v1", categorical_action_encoding=True))
env.rollout(10, policy)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 7 additions & 2 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)
from torch import nn
from torchrl._utils import implement_for
from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.collectors import SyncDataCollector
from torchrl.data import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
Expand Down Expand Up @@ -105,7 +105,12 @@
from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
from torchrl.envs.utils import (
check_env_specs,
ExplorationType,
MarlGroupMapType,
RandomPolicy,
)
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator

_has_d4rl = importlib.util.find_spec("d4rl") is not None
Expand Down
3 changes: 2 additions & 1 deletion torchrl/collectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torchrl.envs.utils import RandomPolicy

from .collectors import (
aSyncDataCollector,
DataCollectorBase,
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
SyncDataCollector,
)
Loading

0 comments on commit 40e9900

Please sign in to comment.