[BugFix] Fix collectors with non tensors (#2232)
vmoens authored Jun 19, 2024
1 parent 45ab9de commit 47a2627
Showing 7 changed files with 142 additions and 24 deletions.
1 change: 1 addition & 0 deletions sota-implementations/cql/cql_offline.py
@@ -162,6 +162,7 @@ def main(cfg: "DictConfig"): # noqa: F821

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
eval_env.close()


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions sota-implementations/cql/cql_online.py
@@ -227,6 +227,8 @@ def main(cfg: "DictConfig"): # noqa: F821
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")

collector.shutdown()
eval_env.close()
train_env.close()


if __name__ == "__main__":
117 changes: 116 additions & 1 deletion test/test_collector.py
@@ -15,6 +15,7 @@
import torch

from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
@@ -42,7 +43,13 @@
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict
from tensordict import (
assert_allclose_td,
LazyStackedTensorDict,
NonTensorData,
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential

from torch import nn
@@ -57,7 +64,9 @@
from torchrl.data import (
CompositeSpec,
LazyTensorStorage,
NonTensorSpec,
ReplayBuffer,
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
@@ -67,6 +76,7 @@
ParallelEnv,
SerialEnv,
StepCounter,
Transform,
)
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, set_gym_backend
from torchrl.envs.transforms import TransformedEnv, VecNorm
@@ -2641,6 +2651,111 @@ def test_dynamic_multiasync_collector(self):
assert data.names[-1] == "time"


@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
class TestCollectorsNonTensor:
class AddNontTensorData(Transform):
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict["nt"] = f"a string! - {tensordict.get('step_count').item()}"
return tensordict

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
return tensordict_reset.set("nt", NonTensorData("reset!"))

def transform_observation_spec(
self, observation_spec: TensorSpec
) -> TensorSpec:
observation_spec["nt"] = NonTensorSpec(shape=())
return observation_spec

@classmethod
def make_env(cls):
return (
GymEnv(CARTPOLE_VERSIONED())
.append_transform(StepCounter())
.append_transform(cls.AddNontTensorData())
)

def test_simple(self):
torch.manual_seed(0)
env = self.make_env()
env.set_seed(0)
collector = SyncDataCollector(env, frames_per_batch=10, total_frames=200)
result = []
for data in collector:
result.append(data)
result = torch.cat(result)
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i

@pytest.mark.parametrize("use_buffers", [True, False])
def test_sync(self, use_buffers):
torch.manual_seed(0)
collector = MultiSyncDataCollector(
[self.make_env, self.make_env],
frames_per_batch=10,
total_frames=200,
cat_results="stack",
use_buffers=use_buffers,
)
try:
result = []
for data in collector:
result.append(data)
results = torch.cat(result)
for result in results.unbind(0):
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i
finally:
collector.shutdown()
del collector

@pytest.mark.parametrize("use_buffers", [True, False])
def test_async(self, use_buffers):
torch.manual_seed(0)
collector = MultiaSyncDataCollector(
[self.make_env, self.make_env],
frames_per_batch=10,
total_frames=200,
use_buffers=use_buffers,
)
try:
results = []
for data in collector:
results.append(data)
for result in results:
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
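
A condensed, standalone sketch of the pattern these tests exercise: a Transform writes a plain Python string into the rollout through NonTensorData, declares it with NonTensorSpec, and a SyncDataCollector carries it through each batch. It assumes gym's CartPole is installed; the identifiers below are illustrative and not part of the commit.

    # Sketch only: mirrors TestCollectorsNonTensor with illustrative names.
    from tensordict import NonTensorData, TensorDictBase
    from torchrl.collectors import SyncDataCollector
    from torchrl.data import NonTensorSpec, TensorSpec
    from torchrl.envs import GymEnv, StepCounter, Transform

    class AddStringData(Transform):
        """Write a string entry into the rollout at every step."""

        def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
            tensordict["nt"] = f"step {tensordict.get('step_count').item()}"
            return tensordict

        def _reset(
            self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
        ) -> TensorDictBase:
            return tensordict_reset.set("nt", NonTensorData("reset!"))

        def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
            # Declare the non-tensor entry so the collector's specs know about it.
            observation_spec["nt"] = NonTensorSpec(shape=())
            return observation_spec

    env = (
        GymEnv("CartPole-v1")
        .append_transform(StepCounter())
        .append_transform(AddStringData())
    )
    collector = SyncDataCollector(env, frames_per_batch=10, total_frames=50)
    for batch in collector:
        print(batch["nt"])  # non-tensor entries travel alongside the tensor data
    collector.shutdown()
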
37 changes: 21 additions & 16 deletions torchrl/collectors/collectors.py
@@ -424,7 +424,7 @@ def __init__(
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
],
] = None,
*,
frames_per_batch: int,
total_frames: int = -1,
@@ -449,7 +449,6 @@ def __init__(
from torchrl.envs.batched_envs import BatchedEnvBase

self.closed = True

exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
@@ -467,6 +466,11 @@
)
env.update_kwargs(create_env_kwargs)

if policy is None:
from torchrl.collectors import RandomPolicy

policy = RandomPolicy(env.action_spec)

##########################
# Setting devices:
# The rule is the following:
@@ -796,8 +800,6 @@ def filter_policy(value_output, value_input, value_input_clone):
)
self._final_rollout.refine_names(..., "time")

assert self._final_rollout.names[-1] == "time"

def _set_truncated_keys(self):
self._truncated_keys = []
if self.set_truncated:
@@ -1082,7 +1084,6 @@ def rollout(self) -> TensorDictBase:
)
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
assert result.names[-1] == "time"
break
else:
if self._use_buffers:
@@ -1093,7 +1094,6 @@
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
assert result.names[-1] == "time"

except RuntimeError:
with self._final_rollout.unlock_():
@@ -1102,7 +1102,6 @@
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
assert result.names[-1] == "time"
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
result.refine_names(..., "time")
@@ -1237,7 +1236,7 @@ class _MultiDataCollector(DataCollectorBase):
instance of :class:`~torchrl.envs.EnvBase`.
policy (Callable): Policy to be executed in the environment.
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
If ``None`` is provided, the policy used will be a
If ``None`` is provided (default), the policy used will be a
:class:`~torchrl.collectors.RandomPolicy` instance with the environment
``action_spec``.
Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
@@ -1382,7 +1381,7 @@ def __init__(
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
],
] = None,
*,
frames_per_batch: int,
total_frames: Optional[int] = -1,
@@ -1453,13 +1452,18 @@ def __init__(
_policy_weights_dict = {}
_get_weights_fn_dict = {}

policy = _NonParametricPolicyWrapper(policy)
policy_weights = TensorDict.from_module(policy, as_module=True)
if policy is not None:
policy = _NonParametricPolicyWrapper(policy)
policy_weights = TensorDict.from_module(policy, as_module=True)

# store a stateless policy
# store a stateless policy
with policy_weights.apply(_make_meta_params).to_module(policy):
# TODO:
self.policy = deepcopy(policy)

with policy_weights.apply(_make_meta_params).to_module(policy):
self.policy = deepcopy(policy)
else:
policy_weights = TensorDict()
self.policy = None

for policy_device in policy_devices:
# if we have already mapped onto that device, get that value
@@ -1694,7 +1698,9 @@ def _run_processes(self) -> None:
storing_device = self.storing_device[i]
env_device = self.env_device[i]
policy = self.policy
with self._policy_weights_dict[policy_device].to_module(policy):
with self._policy_weights_dict[policy_device].to_module(
policy
) if policy is not None else contextlib.nullcontext():
kwargs = {
"pipe_parent": pipe_parent,
"pipe_child": pipe_child,
@@ -2864,7 +2870,6 @@ def _main_async_collector(
else x
)
data = (collected_tensordict, idx)
assert collected_tensordict.names[-1] == "time"
else:
if next_data is not collected_tensordict:
raise RuntimeError(
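
As the docstring change above spells out, passing policy=None (now the default) makes the collector fall back to a RandomPolicy built from the environment's action_spec, while the weight-syncing machinery skips the policy when there is none. A minimal sketch of that fallback, assuming gym's CartPole is installed (not part of the commit):

    # Sketch only: no policy is passed, so the collector samples random actions.
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv

    env = GymEnv("CartPole-v1")
    collector = SyncDataCollector(env, frames_per_batch=10, total_frames=50)
    for batch in collector:
        print(batch["action"].shape)  # actions drawn by the implicit RandomPolicy
    collector.shutdown()
    env.close()
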
4 changes: 2 additions & 2 deletions torchrl/envs/batched_envs.py
@@ -1355,8 +1355,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
self._workers.append(process)

for parent_pipe in self.parent_channels:
msg = parent_pipe.recv()
assert msg == "started"
# use msg as sync point
parent_pipe.recv()

# send shared tensordict to workers
for channel in self.parent_channels:
4 changes: 0 additions & 4 deletions torchrl/modules/models/model_based.py
@@ -6,7 +6,6 @@

import torch
from packaging import version
from tensordict import LazyStackedTensorDict
from tensordict.nn import (
NormalParamExtractor,
TensorDictModule,
@@ -263,9 +262,6 @@ def forward(self, tensordict):
_tensordict = update_values[t + 1].update(_tensordict)

out = torch.stack(tensordict_out, tensordict.ndim - 1)
assert not any(
isinstance(val, LazyStackedTensorDict) for val in out.values(True)
), out
return out


1 change: 0 additions & 1 deletion torchrl/objectives/a2c.py
@@ -473,7 +473,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
target_params=self.target_critic_network_params,
)
advantage = tensordict.get(self.tensor_keys.advantage)
assert not advantage.requires_grad


jkrude commented on Jun 28, 2024:

@vmoens May I ask why you removed this assertion? Isn't a gradient-free advantage still necessary for the policy loss? Could we ensure that by adding advantage.detach() or emitting a warning?

vmoens (author, contributor) replied on Jun 28, 2024:

Just because we don't want asserts in the main code base but proper error messages. You're right that this has its value and there should be a check here. I think detaching is dangerous because someone could be trying to backprop through that advantage at some point somewhere, and silently detaching may cause that person some headache (whilst an error will tell them that this isn't supposed to be differentiable).

log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss}, batch_size=[])
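
The reply above argues for a descriptive error in place of the removed assert rather than a silent detach. A hedged sketch of what such a check could look like; this helper is hypothetical and not part of the commit:

    # Sketch only: raise a clear error instead of asserting or silently detaching.
    import torch
    from torch import Tensor

    def check_advantage_detached(advantage: Tensor) -> Tensor:
        if advantage.requires_grad:
            raise RuntimeError(
                "The advantage is expected to be detached from the graph for the "
                "A2C policy loss; detach it explicitly or check the value estimator."
            )
        return advantage

    adv = torch.randn(4, 1, requires_grad=True)
    try:
        check_advantage_detached(adv)          # fails loudly
    except RuntimeError as err:
        print(err)
    check_advantage_detached(adv.detach())     # passes
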
