Skip to content

Commit

Permalink
[Feature] Fine control over devices in collectors (#1835)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 30, 2024
1 parent 6277226 commit b1cc796
Show file tree
Hide file tree
Showing 66 changed files with 2,334 additions and 986 deletions.
4 changes: 0 additions & 4 deletions benchmarks/ecosystem/gym_env_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def make(envname=envname, gym_backend=gym_backend):
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
Expand Down Expand Up @@ -178,7 +177,6 @@ def make_env(envname=envname, gym_backend=gym_backend):
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
Expand Down Expand Up @@ -222,7 +220,6 @@ def make_env(
total_frames=num_workers * 10_000,
num_sub_threads=num_workers // num_collectors,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
Expand Down Expand Up @@ -260,7 +257,6 @@ def make_env(envname=envname, gym_backend=gym_backend):
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
Expand Down
16 changes: 11 additions & 5 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModuleBase
from tensordict.utils import expand_right, NestedKey

from torchrl.data.tensor_specs import (
Expand Down Expand Up @@ -229,6 +230,7 @@ def _step(self, tensordict):
"observation": n.clone(),
},
batch_size=[],
device=self.device,
)

def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
Expand All @@ -240,7 +242,9 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
done = self.counter >= self.max_val
done = torch.tensor([done], dtype=torch.bool, device=self.device)
return TensorDict(
{"done": done, "terminated": done.clone(), "observation": n}, []
{"done": done, "terminated": done.clone(), "observation": n},
[],
device=self.device,
)

def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
Expand Down Expand Up @@ -1374,8 +1378,9 @@ def _step(
return tensordict


class HeteroCountingEnvPolicy:
class HeterogeneousCountingEnvPolicy(TensorDictModuleBase):
def __init__(self, full_action_spec: TensorSpec, count: bool = True):
super().__init__()
self.full_action_spec = full_action_spec
self.count = count

Expand All @@ -1386,7 +1391,7 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
return td.update(action_td)


class HeteroCountingEnv(EnvBase):
class HeterogeneousCountingEnv(EnvBase):
"""A heterogeneous, counting Env."""

def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
Expand Down Expand Up @@ -1569,13 +1574,14 @@ def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)


class MultiKeyCountingEnvPolicy:
class MultiKeyCountingEnvPolicy(TensorDictModuleBase):
def __init__(
self,
full_action_spec: TensorSpec,
count: bool = True,
deterministic: bool = False,
):
super().__init__()
if not deterministic and not count:
raise ValueError("Not counting policy is always deterministic")

Expand Down
Loading

0 comments on commit b1cc796

Please sign in to comment.