[BugFix] Fix collectors with non tensors (pytorch#2232)
vmoens authored Jun 19, 2024 · 1 parent 45ab9de · commit 47a2627
Showing 7 changed files with 142 additions and 24 deletions.
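
In short, this commit makes the collectors tolerate non-tensor entries (e.g. Python strings wrapped in NonTensorData) flowing through rollouts, and lets collectors be constructed without an explicit policy. A minimal sketch of the non-tensor use case, assuming gym and a recent torchrl are installed; the AddString transform is a hypothetical helper that mirrors the new test added below:

from tensordict import NonTensorData, TensorDictBase
from torchrl.collectors import SyncDataCollector
from torchrl.data import NonTensorSpec, TensorSpec
from torchrl.envs import GymEnv, StepCounter, Transform

class AddString(Transform):
    # Write a plain Python string into the tensordict at every step.
    def _call(self, td: TensorDictBase) -> TensorDictBase:
        td["nt"] = f"step {td['step_count'].item()}"
        return td

    def _reset(self, td: TensorDictBase, td_reset: TensorDictBase) -> TensorDictBase:
        return td_reset.set("nt", NonTensorData("reset"))

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        # Declare the key so the collector's specs account for it.
        observation_spec["nt"] = NonTensorSpec(shape=())
        return observation_spec

env = (
    GymEnv("CartPole-v1")
    .append_transform(StepCounter())
    .append_transform(AddString())
)
collector = SyncDataCollector(env, frames_per_batch=10, total_frames=20)
for batch in collector:
    print(batch["nt"])  # strings survive collection after this fix
collector.shutdown()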
1 change: 1 addition & 0 deletions sota-implementations/cql/cql_offline.py
@@ -162,6 +162,7 @@ def main(cfg: "DictConfig"): # noqa: F821

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
eval_env.close()


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions sota-implementations/cql/cql_online.py
@@ -227,6 +227,8 @@ def main(cfg: "DictConfig"): # noqa: F821
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")

collector.shutdown()
eval_env.close()
train_env.close()


if __name__ == "__main__":
117 changes: 116 additions & 1 deletion test/test_collector.py
@@ -15,6 +15,7 @@
import torch

from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
@@ -42,7 +43,13 @@
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict
from tensordict import (
assert_allclose_td,
LazyStackedTensorDict,
NonTensorData,
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential

from torch import nn
@@ -57,7 +64,9 @@
from torchrl.data import (
CompositeSpec,
LazyTensorStorage,
NonTensorSpec,
ReplayBuffer,
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
@@ -67,6 +76,7 @@
ParallelEnv,
SerialEnv,
StepCounter,
Transform,
)
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, set_gym_backend
from torchrl.envs.transforms import TransformedEnv, VecNorm
@@ -2641,6 +2651,111 @@ def test_dynamic_multiasync_collector(self):
assert data.names[-1] == "time"


@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
class TestCollectorsNonTensor:
class AddNonTensorData(Transform):
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict["nt"] = f"a string! - {tensordict.get('step_count').item()}"
return tensordict

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
return tensordict_reset.set("nt", NonTensorData("reset!"))

def transform_observation_spec(
self, observation_spec: TensorSpec
) -> TensorSpec:
observation_spec["nt"] = NonTensorSpec(shape=())
return observation_spec

@classmethod
def make_env(cls):
return (
GymEnv(CARTPOLE_VERSIONED())
.append_transform(StepCounter())
.append_transform(cls.AddNonTensorData())
)

def test_simple(self):
torch.manual_seed(0)
env = self.make_env()
env.set_seed(0)
collector = SyncDataCollector(env, frames_per_batch=10, total_frames=200)
result = []
for data in collector:
result.append(data)
result = torch.cat(result)
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i

@pytest.mark.parametrize("use_buffers", [True, False])
def test_sync(self, use_buffers):
torch.manual_seed(0)
collector = MultiSyncDataCollector(
[self.make_env, self.make_env],
frames_per_batch=10,
total_frames=200,
cat_results="stack",
use_buffers=use_buffers,
)
try:
result = []
for data in collector:
result.append(data)
results = torch.cat(result)
for result in results.unbind(0):
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i
finally:
collector.shutdown()
del collector

@pytest.mark.parametrize("use_buffers", [True, False])
def test_async(self, use_buffers):
torch.manual_seed(0)
collector = MultiaSyncDataCollector(
[self.make_env, self.make_env],
frames_per_batch=10,
total_frames=200,
use_buffers=use_buffers,
)
try:
results = []
for data in collector:
results.append(data)
for result in results:
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
37 changes: 21 additions & 16 deletions torchrl/collectors/collectors.py
@@ -424,7 +424,7 @@ def __init__(
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
],
] = None,
*,
frames_per_batch: int,
total_frames: int = -1,
@@ -449,7 +449,6 @@
from torchrl.envs.batched_envs import BatchedEnvBase

self.closed = True

exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
@@ -467,6 +466,11 @@
)
env.update_kwargs(create_env_kwargs)

if policy is None:
from torchrl.collectors import RandomPolicy

policy = RandomPolicy(env.action_spec)

##########################
# Setting devices:
# The rule is the following:
@@ -796,8 +800,6 @@ def filter_policy(value_output, value_input, value_input_clone):
)
self._final_rollout.refine_names(..., "time")

assert self._final_rollout.names[-1] == "time"

def _set_truncated_keys(self):
self._truncated_keys = []
if self.set_truncated:
@@ -1082,7 +1084,6 @@ def rollout(self) -> TensorDictBase:
)
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
assert result.names[-1] == "time"
break
else:
if self._use_buffers:
@@ -1093,7 +1094,6 @@
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
assert result.names[-1] == "time"

except RuntimeError:
with self._final_rollout.unlock_():
@@ -1102,7 +1102,6 @@
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
assert result.names[-1] == "time"
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
result.refine_names(..., "time")
@@ -1237,7 +1236,7 @@ class _MultiDataCollector(DataCollectorBase):
instance of :class:`~torchrl.envs.EnvBase`.
policy (Callable): Policy to be executed in the environment.
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
If ``None`` is provided, the policy used will be a
If ``None`` is provided (default), the policy used will be a
:class:`~torchrl.collectors.RandomPolicy` instance with the environment
``action_spec``.
Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
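
For instance, a minimal sketch of the new default (assuming gym is installed; the env choice is illustrative), where omitting the policy yields random actions drawn from the env's action_spec:

from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv

def make_env():
    return GymEnv("CartPole-v1")

collector = MultiSyncDataCollector(
    [make_env, make_env],  # no policy passed: falls back to RandomPolicy
    frames_per_batch=10,
    total_frames=20,
    cat_results="stack",
)
for batch in collector:
    print(batch["action"].shape)  # actions sampled by the implicit RandomPolicy
collector.shutdown()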
@@ -1382,7 +1381,7 @@ def __init__(
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
],
] = None,
*,
frames_per_batch: int,
total_frames: Optional[int] = -1,
@@ -1453,13 +1452,18 @@ def __init__(
_policy_weights_dict = {}
_get_weights_fn_dict = {}

policy = _NonParametricPolicyWrapper(policy)
policy_weights = TensorDict.from_module(policy, as_module=True)
if policy is not None:
policy = _NonParametricPolicyWrapper(policy)
policy_weights = TensorDict.from_module(policy, as_module=True)

# store a stateless policy
# store a stateless policy
with policy_weights.apply(_make_meta_params).to_module(policy):
# TODO:
self.policy = deepcopy(policy)

with policy_weights.apply(_make_meta_params).to_module(policy):
self.policy = deepcopy(policy)
else:
policy_weights = TensorDict()
self.policy = None

for policy_device in policy_devices:
# if we have already mapped onto that device, get that value
@@ -1694,7 +1698,9 @@ def _run_processes(self) -> None:
storing_device = self.storing_device[i]
env_device = self.env_device[i]
policy = self.policy
with self._policy_weights_dict[policy_device].to_module(policy):
with self._policy_weights_dict[policy_device].to_module(
policy
) if policy is not None else contextlib.nullcontext():
kwargs = {
"pipe_parent": pipe_parent,
"pipe_child": pipe_child,
@@ -2864,7 +2870,6 @@ def _main_async_collector(
else x
)
data = (collected_tensordict, idx)
assert collected_tensordict.names[-1] == "time"
else:
if next_data is not collected_tensordict:
raise RuntimeError(
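The conditional context manager introduced in _run_processes above, isolated as a generic Python sketch (standard library only, not torchrl code): contextlib.nullcontext() is a no-op stand-in, so a single with statement covers both the with-policy and the policy-less paths:

import contextlib
import threading

def work(lock=None):
    # Enter the real context when one is given, otherwise a do-nothing context.
    with lock if lock is not None else contextlib.nullcontext():
        return "done"

work()                  # no lock: nullcontext is entered and exited
work(threading.Lock())  # lock: acquired around the body, then released
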
4 changes: 2 additions & 2 deletions torchrl/envs/batched_envs.py
@@ -1355,8 +1355,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
self._workers.append(process)

for parent_pipe in self.parent_channels:
msg = parent_pipe.recv()
assert msg == "started"
# use msg as sync point
parent_pipe.recv()

# send shared tensordict to workers
for channel in self.parent_channels:
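The worker handshake touched above, isolated as a generic multiprocessing sketch (not torchrl code): the parent's recv() serves purely as a synchronization point proving the worker finished its setup, which is why asserting on the payload could be dropped:

import multiprocessing as mp

def worker(pipe):
    # ... worker-side setup would happen here ...
    pipe.send("started")  # payload is irrelevant; the recv() below is the sync point

if __name__ == "__main__":
    parent, child = mp.Pipe()
    proc = mp.Process(target=worker, args=(child,))
    proc.start()
    parent.recv()  # block until the worker signals readiness
    proc.join()
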
4 changes: 0 additions & 4 deletions torchrl/modules/models/model_based.py
@@ -6,7 +6,6 @@

import torch
from packaging import version
from tensordict import LazyStackedTensorDict
from tensordict.nn import (
NormalParamExtractor,
TensorDictModule,
@@ -263,9 +262,6 @@ def forward(self, tensordict):
_tensordict = update_values[t + 1].update(_tensordict)

out = torch.stack(tensordict_out, tensordict.ndim - 1)
assert not any(
isinstance(val, LazyStackedTensorDict) for val in out.values(True)
), out
return out


1 change: 0 additions & 1 deletion torchrl/objectives/a2c.py
@@ -473,7 +473,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
target_params=self.target_critic_network_params,
)
advantage = tensordict.get(self.tensor_keys.advantage)
assert not advantage.requires_grad
log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss}, batch_size=[])
