[Feature] Dynamic specs #2143

Merged: 27 commits, May 31, 2024
Changes from 1 commit
amend
vmoens committed May 31, 2024
commit 3e579c1a91630025645801ffbdddaac0579c6f79
13 changes: 7 additions & 6 deletions test/test_collector.py
@@ -798,12 +798,8 @@ def make_env(seed):


@pytest.mark.parametrize("num_env", [1, 2])
@pytest.mark.parametrize(
"env_name",
[
"vec",
],
) # 1226: for efficiency, we just test vec, not "conv"
# 1226: for efficiency, we just test vec, not "conv"
@pytest.mark.parametrize("env_name", ["vec"])
def test_collector_batch_size(
num_env, env_name, seed=100, num_workers=2, frames_per_batch=20
):
@@ -1428,6 +1424,7 @@ def env_fn(seed):
device=device,
storing_device=storing_device,
)
assert collector._use_buffers
batch = next(collector.iterator())
assert batch.device == torch.device(storing_device)
collector.shutdown()
@@ -2581,6 +2578,7 @@ def test_unique_traj_sync(self, cat_results):
try:
for d in c:
buffer.extend(d)
assert c._use_buffers
traj_ids = buffer[:].get(("collector", "traj_ids"))
# check that we have as many trajs as expected (no skip)
assert traj_ids.unique().numel() == traj_ids.max() + 1
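The comment above captures the reasoning behind the check: trajectory ids are expected to be 0-based and gapless, so the number of distinct ids must equal the largest id plus one. A minimal illustration with made-up id tensors (not part of the test suite):

```python
import torch

# hypothetical trajectory ids collected without skipping any trajectory
traj_ids = torch.tensor([0, 0, 1, 2, 2, 2, 3])
# 0-based, gapless numbering: number of distinct ids == max id + 1
assert traj_ids.unique().numel() == traj_ids.max() + 1

# if an id were skipped (no trajectory 2 here), the identity would break
gappy = torch.tensor([0, 1, 3, 3])
assert gappy.unique().numel() != gappy.max() + 1
```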
@@ -2611,6 +2609,7 @@ def test_dynamic_sync_collector(self):
)
for data in collector:
assert isinstance(data, LazyStackedTensorDict)
assert data.names[-1] == "time"

def test_dynamic_multisync_collector(self):
env = EnvWithDynamicSpec
@@ -2625,6 +2624,7 @@ def test_dynamic_multisync_collector(self):
)
for data in collector:
assert isinstance(data, LazyStackedTensorDict)
assert data.names[-1] == "time"

def test_dynamic_multiasync_collector(self):
env = EnvWithDynamicSpec
@@ -2638,6 +2638,7 @@ def test_dynamic_multiasync_collector(self):
)
for data in collector:
assert isinstance(data, LazyStackedTensorDict)
assert data.names[-1] == "time"


if __name__ == "__main__":
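For reference, the three `test_dynamic_*_collector` tests above share one pattern: build a collector over an environment whose observation shape changes between steps, then check that each collected batch is a lazy stack whose last dimension is named "time". A rough sketch of that usage, mirroring the sync test (EnvWithDynamicSpec is the helper environment used by the tests, not a public torchrl class, and the argument values are illustrative):

```python
from tensordict import LazyStackedTensorDict
from torchrl.collectors import SyncDataCollector

# EnvWithDynamicSpec is the test helper; any env whose observation shape
# varies from step to step exercises the same code path.
collector = SyncDataCollector(
    EnvWithDynamicSpec,      # passed as an env factory, as in the test
    policy=None,             # a random policy is used when none is given
    frames_per_batch=20,
    total_frames=100,
)
for data in collector:
    # dynamic shapes cannot live in a dense preallocated buffer, so the
    # batch comes back as a lazy stack with a trailing "time" dimension
    assert isinstance(data, LazyStackedTensorDict)
    assert data.names[-1] == "time"
collector.shutdown()
```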
21 changes: 16 additions & 5 deletions torchrl/collectors/collectors.py
@@ -200,7 +200,7 @@ def update_policy_weights_(
self.policy_weights.data.update_(self.get_weights_fn())

def __iter__(self) -> Iterator[TensorDictBase]:
return self.iterator()
yield from self.iterator()

def next(self):
try:
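The one-line change to `__iter__` above is easy to miss: `return self.iterator()` simply hands back the generator produced by `iterator()`, while `yield from self.iterator()` makes `__iter__` itself a generator that delegates to it, so the inner generator is only created once iteration actually starts and `close()`/`throw()` calls are forwarded to it. A small, torchrl-independent illustration (hypothetical class, not the collector):

```python
from typing import Iterator


class Source:
    def iterator(self) -> Iterator[int]:
        # stand-in for the collector's rollout loop
        yield from range(3)

    def __iter__(self) -> Iterator[int]:
        # generator-style delegation: iterator() is not called until the
        # first next(), and close()/throw() propagate to the inner generator
        yield from self.iterator()


it = iter(Source())  # nothing has run yet
print(list(it))      # [0, 1, 2]
```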
@@ -796,6 +796,8 @@ def filter_policy(value_output, value_input, value_input_clone):
)
self._final_rollout.refine_names(..., "time")

assert self._final_rollout.names[-1] == "time"

def _set_truncated_keys(self):
self._truncated_keys = []
if self.set_truncated:
@@ -1080,25 +1082,30 @@ def rollout(self) -> TensorDictBase:
)
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
assert result.names[-1] == "time"
break
else:
if self._use_buffers:
result = self._final_rollout
try:
self._final_rollout = torch.stack(
result = torch.stack(
tensordicts,
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
assert result.names[-1] == "time"

except RuntimeError:
with self._final_rollout.unlock_():
self._final_rollout = torch.stack(
result = torch.stack(
tensordicts,
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
assert result.names[-1] == "time"
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
result.refine_names(..., "time")

return self._maybe_set_truncated(result)
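The new `else` branch above is the heart of the dynamic-spec support in `rollout()`: when no preallocated buffer is available, the per-step tensordicts are combined with `TensorDict.maybe_dense_stack`, which falls back to a lazy stack when shapes disagree, and the stacking dimension is then named "time". A small sketch of that behaviour with made-up shapes (stacking at dim 0 here, which for steps without batch dimensions matches the collector's dim=-1):

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict

# two steps whose observation length differs, as with a dynamic spec
step0 = TensorDict({"obs": torch.zeros(3)}, batch_size=[])
step1 = TensorDict({"obs": torch.zeros(5)}, batch_size=[])

# a dense stack is impossible here, so maybe_dense_stack returns a lazy stack
rollout = TensorDict.maybe_dense_stack([step0, step1], dim=0)
assert isinstance(rollout, LazyStackedTensorDict)

# name the stacking dimension, as rollout() does after stacking
rollout.refine_names(..., "time")
assert rollout.names[-1] == "time"
```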

@@ -2213,7 +2220,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
("collector", "traj_ids"), torch.stack(traj_ids_list), inplace=True
)
else:
if not self._use_buffers:
if self._use_buffers is None:
torchrl_logger.warning(
"use_buffer not specified and not yet inferred from data, assuming `True`."
)
elif not self._use_buffers:
raise RuntimeError(
"Cannot concatenate results with use_buffers=False"
)
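With the warning added above, `_use_buffers` behaves as a tri-state flag at this point of `iterator()`: `None` means it was neither specified nor inferred from data yet (warn and assume buffers), `True` selects the dense, preallocated-buffer path, and `False` (the dynamic-spec case) means results cannot be concatenated. A compact sketch of that dispatch, using a hypothetical helper name rather than the collector's internals:

```python
import warnings


def _check_cat_allowed(use_buffers) -> None:
    """Hypothetical mirror of the guard above, not a torchrl API."""
    if use_buffers is None:
        # not specified and not yet inferred from the first collected batch
        warnings.warn(
            "use_buffer not specified and not yet inferred from data, assuming `True`."
        )
    elif not use_buffers:
        # lazily stacked (dynamic-spec) results cannot be concatenated
        raise RuntimeError("Cannot concatenate results with use_buffers=False")
```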
@@ -2455,7 +2466,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
_check_for_faulty_process(self.procs)
self._iter += 1
idx, j, out = self._get_from_queue()

worker_frames = out.numel()
if self.split_trajs:
out = split_trajectories(out, prefix="collector")
@@ -2854,6 +2864,7 @@ def _main_async_collector(
else x
)
data = (collected_tensordict, idx)
assert collected_tensordict.names[-1] == "time"
else:
if next_data is not collected_tensordict:
raise RuntimeError(