Skip to content

Commit

Permalink
[BugFix] Fix unwanted lazy stacks (pytorch#2102)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 23, 2024
1 parent 6c2e141 commit 09c934d
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 44 deletions.
10 changes: 2 additions & 8 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
replay_buffer.size=120 \
env.name=Pendulum-v1 \
logger.backend=
# record_video=True \
# record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/a2c/a2c_mujoco.py \
env.env_name=HalfCheetah-v4 \
collector.total_frames=40 \
Expand Down Expand Up @@ -125,7 +123,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.env_per_collector=2 \
buffer.batch_size=10 \
optim.steps_per_batch=1 \
logger.record_video=True \
logger.video=True \
logger.record_frames=4 \
buffer.size=120 \
logger.backend=
Expand All @@ -151,8 +149,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
replay_buffer.size=120 \
env.name=CartPole-v1 \
logger.backend=
# logger.record_video=True \
# logger.record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
collector.init_random_frames=10 \
Expand Down Expand Up @@ -220,8 +216,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
replay_buffer.size=120 \
env.name=Pendulum-v1 \
logger.backend=
# record_video=True \
# record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dqn/dqn_atari.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
Expand All @@ -238,7 +232,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.env_per_collector=1 \
buffer.batch_size=10 \
optim.steps_per_batch=1 \
logger.record_video=True \
logger.video=True \
logger.record_frames=4 \
buffer.size=120 \
logger.backend=
Expand Down
18 changes: 9 additions & 9 deletions sota-implementations/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821
},
)
else:
logger = ""
logger = None

key, init_env_steps, stats = None, None, None
if not cfg.env.vecnorm and cfg.env.norm_stats:
Expand Down Expand Up @@ -174,14 +174,14 @@ def main(cfg: "DictConfig"): # noqa: F821
t.loc.fill_(0.0)

trainer = make_trainer(
collector,
loss_module,
recorder,
target_net_updater,
actor_model_explore,
replay_buffer,
logger,
cfg,
collector=collector,
loss_module=loss_module,
recorder=recorder,
target_net_updater=target_net_updater,
policy_exploration=actor_model_explore,
replay_buffer=replay_buffer,
logger=logger,
cfg=cfg,
)

trainer.train()
Expand Down
38 changes: 28 additions & 10 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,16 @@ def test_info_dict_reader(self, device, seed=0):
import gym

env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device)
env.set_info_dict_reader(default_info_dict_reader(["x_position"]))
env.set_info_dict_reader(
default_info_dict_reader(
["x_position"],
spec=CompositeSpec(
x_position=UnboundedContinuousTensorSpec(
dtype=torch.float64, shape=()
)
),
)
)

assert "x_position" in env.observation_spec.keys()
assert isinstance(
Expand All @@ -1842,15 +1851,21 @@ def test_info_dict_reader(self, device, seed=0):
tensordict = env.reset()
tensordict = env.rand_step(tensordict)

assert env.observation_spec["x_position"].is_in(
tensordict[("next", "x_position")]
x_position_data = tensordict["next", "x_position"]
assert env.observation_spec["x_position"].is_in(x_position_data), (
x_position_data.shape,
x_position_data.dtype,
env.observation_spec["x_position"],
)

for spec in (
{"x_position": UnboundedContinuousTensorSpec(10)},
None,
CompositeSpec(x_position=UnboundedContinuousTensorSpec(10), shape=[]),
[UnboundedContinuousTensorSpec(10)],
{"x_position": UnboundedContinuousTensorSpec((), dtype=torch.float64)},
# None,
CompositeSpec(
x_position=UnboundedContinuousTensorSpec((), dtype=torch.float64),
shape=[],
),
[UnboundedContinuousTensorSpec((), dtype=torch.float64)],
):
env2 = GymWrapper(gym.make("HalfCheetah-v4"))
env2.set_info_dict_reader(
Expand All @@ -1859,9 +1874,12 @@ def test_info_dict_reader(self, device, seed=0):

tensordict2 = env2.reset()
tensordict2 = env2.rand_step(tensordict2)

assert env2.observation_spec["x_position"].is_in(
tensordict2[("next", "x_position")]
data = tensordict2[("next", "x_position")]
assert env2.observation_spec["x_position"].is_in(data), (
data.dtype,
data.device,
data.shape,
env2.observation_spec["x_position"],
)

@pytest.mark.skipif(not _has_gym, reason="no gym")
Expand Down
6 changes: 5 additions & 1 deletion test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,11 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=
out_noexp = []
out = []
for i in range(n_steps):
tensordict_noexp = policy(tensordict.clone())
tensordict_noexp = policy(
tensordict.clone().exclude(
*(key for key in tensordict.keys() if key.startswith("_"))
)
)
tensordict = exploratory_policy(tensordict.clone())
if i == 0:
assert (tensordict[exploratory_policy.ou.steps_key] == 1).all()
Expand Down
4 changes: 2 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch
from mocking_classes import CountingEnv, DiscreteActionVecMockEnv
from tensordict import pad, TensorDict, unravel_key_list
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.data.tensor_specs import (
Expand Down Expand Up @@ -515,7 +515,7 @@ def test_sequential_partial(self, stack):
)

if stack:
td = torch.stack(
td = LazyStackedTensorDict.maybe_dense_stack(
[
TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []),
TensorDict({"a": torch.randn(3), "c": torch.randn(4)}, []),
Expand Down
3 changes: 2 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10583,6 +10583,7 @@ def test_multistep_transform(self):

outs_2 = []
td = env.reset().contiguous()
assert "reward" not in td
for _ in range(1):
rollout = env.rollout(
250, auto_reset=False, tensordict=td, break_when_any_done=False
Expand Down Expand Up @@ -10626,7 +10627,7 @@ def test_multistep_transform(self):
).contiguous()
assert "reward" not in rollout.keys()
out = t._inv_call(rollout)
td = rollout[..., -1]["next"]
td = rollout[..., -1]["next"].exclude("reward")
if out is not None:
outs_3.append(out)

Expand Down
5 changes: 4 additions & 1 deletion torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from torchrl.data.replay_buffers.storages import (
_get_default_collate,
_stack_anything,
ListStorage,
Storage,
StorageEnsemble,
Expand Down Expand Up @@ -1541,8 +1542,10 @@ def __init__(
num_buffer_sampled: int | None = None,
**kwargs,
):

if collate_fn is None:
collate_fn = torch.stack
collate_fn = _stack_anything

if rbs:
if storages is not None or samplers is not None or writers is not None:
raise RuntimeError
Expand Down
13 changes: 12 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import numpy as np
import tensordict
import torch
from tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict import (
is_tensor_collection,
LazyStackedTensorDict,
TensorDict,
TensorDictBase,
)
from tensordict.memmap import MemmapTensor, MemoryMappedTensor
from tensordict.utils import _STRDTYPE2DTYPE
from torch import multiprocessing as mp
Expand Down Expand Up @@ -1322,6 +1327,12 @@ def _collate_list_tensordict(x):
return out


def _stack_anything(x):
if is_tensor_collection(x[0]):
return LazyStackedTensorDict.maybe_dense_stack(x)
return torch.stack(x)


def _collate_id(x):
return x

Expand Down
10 changes: 7 additions & 3 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,14 +974,18 @@ def zero(self, shape=None) -> TensorDictBase:
dim = self.dim + len(shape)
else:
dim = self.dim
return torch.stack([spec.zero(shape) for spec in self._specs], dim)
return LazyStackedTensorDict.maybe_dense_stack(
[spec.zero(shape) for spec in self._specs], dim
)

def rand(self, shape=None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
dim = self.dim
return torch.stack([spec.rand(shape) for spec in self._specs], dim)
return LazyStackedTensorDict.maybe_dense_stack(
[spec.rand(shape) for spec in self._specs], dim
)

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T:
if dest is None:
Expand Down Expand Up @@ -4344,7 +4348,7 @@ def project(self, val: TensorDictBase) -> TensorDictBase:
vals.append(spec.project(subval))
else:
vals.append(subval)
res = torch.stack(vals, dim=self.dim)
res = LazyStackedTensorDict.maybe_dense_stack(vals, dim=self.dim)
if not isinstance(val, LazyStackedTensorDict):
res = res.to_tensordict()
return res
Expand Down
20 changes: 13 additions & 7 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import abc
import re
import warnings
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -506,13 +507,18 @@ def auto_register_info_dict(self):
try:
check_env_specs(self)
return self
except AssertionError as err:
if "The keys of the specs and data do not match" in str(err):
result = TransformedEnv(
self, TensorDictPrimer(self.info_dict_reader[0].info_spec)
)
check_env_specs(result)
return result
except (AssertionError, RuntimeError) as err:
patterns = [
"The keys of the specs and data do not match",
"The sets of keys in the tensordicts to stack are exclusive",
]
for pattern in patterns:
if re.search(pattern, str(err)):
result = TransformedEnv(
self, TensorDictPrimer(self.info_dict_reader[0].info_spec)
)
check_env_specs(result)
return result
raise err

def __repr__(self) -> str:
Expand Down
1 change: 0 additions & 1 deletion torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def _is_reset(key: NestedKey):
expected = set(expected)
self.validated = expected.intersection(actual) == expected
if not self.validated:
raise RuntimeError
warnings.warn(
"The expected key set and actual key set differ. "
"This will work but with a slower throughput than "
Expand Down

0 comments on commit 09c934d

Please sign in to comment.