[Refactor] Use filter_empty=True in apply (#1879)
vmoens authored Feb 6, 2024
1 parent ff3a350 commit 62d977b
Showing 6 changed files with 46 additions and 41 deletions.
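
The refactor is mechanical: every apply / named_apply / _fast_apply call whose callable either works purely by side effect or returns None for entries it wants to drop now passes filter_empty=True, so empty sub-tensordicts are not kept in the output. Below is a minimal sketch of the side-effect pattern, assuming current tensordict semantics; the record_device helper and the example TensorDict are illustrative, not part of the commit.

import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(3), "nested": {"b": torch.ones(3)}}, batch_size=[3]
)

seen = []

def record_device(tensor):
    # Side-effect-only callable: it returns None, so no output entry is produced.
    seen.append(tensor.device)

# With filter_empty=True the call is used purely for its side effect and no
# structure of empty nodes is kept around as a result.
td.apply(record_device, filter_empty=True)
print(seen)  # one device per leaf: "a" and ("nested", "b")
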
2 changes: 1 addition & 1 deletion test/test_cost.py
@@ -6100,7 +6100,7 @@ def zero_param(p):
if isinstance(p, nn.Parameter):
p.data.zero_()

- params.apply(zero_param)
+ params.apply(zero_param, filter_empty=True)

# assert len(list(floss_fn.parameters())) == 0
with params.to_module(loss_fn):
22 changes: 15 additions & 7 deletions torchrl/collectors/collectors.py
@@ -801,22 +801,30 @@ def check_exclusive(val):
"Consider using a placeholder for missing keys."
)

- policy_output._fast_apply(check_exclusive, call_on_nested=True)
+ policy_output._fast_apply(
+     check_exclusive, call_on_nested=True, filter_empty=True
+ )

# Use apply, because it works well with lazy stacks
# Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
# or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
# changed them here).
# This will cause a failure to update entries when policy and env device mismatch and
# casting is necessary.
+ def filter_policy(value_output, value_input, value_input_clone):
+     if (
+         (value_input is None)
+         or (value_output is not value_input)
+         or ~torch.isclose(value_output, value_input_clone).any()
+     ):
+         return value_output

filtered_policy_output = policy_output.apply(
- lambda value_output, value_input, value_input_clone: value_output
- if (value_input is None)
- or (value_output is not value_input)
- or ~torch.isclose(value_output, value_input_clone).any()
- else None,
+ filter_policy,
policy_input_copy,
policy_input_clone,
default=None,
+ filter_empty=True,
)
self._policy_output_keys = list(
self._policy_output_keys.union(
@@ -933,7 +941,7 @@ def cuda_check(tensor: torch.Tensor):
if tensor.is_cuda:
cuda_devices.add(tensor.device)

- self._final_rollout.apply(cuda_check)
+ self._final_rollout.apply(cuda_check, filter_empty=True)
for device in cuda_devices:
streams.append(torch.cuda.Stream(device, priority=-1))
events.append(streams[-1].record_event())
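
The collectors.py change above also shows the multi-tensordict form of apply: extra tensordicts are passed positionally, default=None stands in for keys missing from them, and returning None drops an entry from the filtered output. A hedged sketch of that pattern follows, with illustrative names (out, ref, keep_if_changed) and a simplified equality check in place of the isclose test used in the commit.

import torch
from tensordict import TensorDict

out = TensorDict({"a": torch.ones(2), "b": torch.zeros(2)}, batch_size=[2])
ref = TensorDict({"a": torch.ones(2)}, batch_size=[2])

def keep_if_changed(value_output, value_input):
    # Keep entries that are missing from ref (value_input is None thanks to
    # default=None) or that differ from it; returning None drops the rest.
    if value_input is None or not torch.equal(value_output, value_input):
        return value_output

changed = out.apply(keep_if_changed, ref, default=None, filter_empty=True)
print(list(changed.keys()))  # expected: ["b"]
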
31 changes: 9 additions & 22 deletions torchrl/envs/batched_envs.py
@@ -419,12 +419,8 @@ def _check_for_empty_spec(specs: CompositeSpec):
def map_device(key, value, device_map=device_map):
return value.to(device_map[key])

- # self._env_tensordict.named_apply(
- #     map_device, nested_keys=True, filter_empty=True
- # )
self._env_tensordict.named_apply(
- map_device,
- nested_keys=True,
+ map_device, nested_keys=True, filter_empty=True
)

self._batch_locked = meta_data.batch_locked
@@ -792,16 +788,11 @@ def select_and_clone(name, tensor):
if name in selected_output_keys:
return tensor.clone()

- # out = self.shared_tensordict_parent.named_apply(
- #     select_and_clone,
- #     nested_keys=True,
- #     filter_empty=True,
- # )
out = self.shared_tensordict_parent.named_apply(
select_and_clone,
nested_keys=True,
+ filter_empty=True,
)
del out["next"]

if out.device != device:
if device is None:
@@ -842,8 +833,7 @@ def select_and_clone(name, tensor):
if name in self._selected_step_keys:
return tensor.clone()

- # out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
- out = next_td.named_apply(select_and_clone, nested_keys=True)
+ out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)

if out.device != device:
if device is None:
@@ -1059,8 +1049,7 @@ def _start_workers(self) -> None:
def look_for_cuda(tensor, has_cuda=has_cuda):
has_cuda[0] = has_cuda[0] or tensor.is_cuda

- # self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
- self.shared_tensordict_parent.apply(look_for_cuda)
+ self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
has_cuda = has_cuda[0]
if has_cuda:
self.event = torch.cuda.Event()
@@ -1182,14 +1171,14 @@ def step_and_maybe_reset(
if x.device != device
else x.clone(),
device=device,
- # filter_empty=True,
+ filter_empty=True,
)
tensordict_ = tensordict_._fast_apply(
lambda x: x.to(device, non_blocking=True)
if x.device != device
else x.clone(),
device=device,
- # filter_empty=True,
+ filter_empty=True,
)
else:
next_td = next_td.clone().clear_device_()
@@ -1244,7 +1233,7 @@ def select_and_clone(name, tensor):
out = next_td.named_apply(
select_and_clone,
nested_keys=True,
- # filter_empty=True,
+ filter_empty=True,
)
if out.device != device:
if device is None:
@@ -1314,9 +1303,8 @@ def select_and_clone(name, tensor):
out = self.shared_tensordict_parent.named_apply(
select_and_clone,
nested_keys=True,
- # filter_empty=True,
+ filter_empty=True,
)
del out["next"]

if out.device != device:
if device is None:
@@ -1452,8 +1440,7 @@ def _run_worker_pipe_shared_mem(
def look_for_cuda(tensor, has_cuda=has_cuda):
has_cuda[0] = has_cuda[0] or tensor.is_cuda

- # shared_tensordict.apply(look_for_cuda, filter_empty=True)
- shared_tensordict.apply(look_for_cuda)
+ shared_tensordict.apply(look_for_cuda, filter_empty=True)
has_cuda = has_cuda[0]
else:
has_cuda = device.type == "cuda"
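
A side effect of the batched_envs.py change worth noting: with filter_empty=True, keys for which select_and_clone returns None leave no empty sub-tensordict behind, which appears to be why the del out["next"] lines above could be dropped. A small illustrative sketch, not repository code; shapes and key names are made up.

import torch
from tensordict import TensorDict

shared = TensorDict(
    {"done": torch.zeros(2, 1), "next": {"reward": torch.zeros(2, 1)}},
    batch_size=[2],
)
selected_output_keys = {"done"}

def select_and_clone(name, tensor):
    # Only selected keys produce an output; everything under "next" returns None.
    if name in selected_output_keys:
        return tensor.clone()

out = shared.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
print(list(out.keys()))  # expected: ["done"] -- the empty "next" node is pruned
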
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
@@ -107,7 +107,7 @@ def metadata_from_env(env) -> EnvMetaData:
def fill_device_map(name, val, device_map=device_map):
device_map[name] = val.device

- tensordict.named_apply(fill_device_map, nested_keys=True)
+ tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
return EnvMetaData(
tensordict, specs, batch_size, env_str, device, batch_locked, device_map
)
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
@@ -3063,7 +3063,7 @@ def __init__(self):
super().__init__(in_keys=[])

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
- tensordict.apply(check_finite)
+ tensordict.apply(check_finite, filter_empty=True)
return tensordict

def _reset(
28 changes: 19 additions & 9 deletions torchrl/objectives/cql.py
@@ -577,9 +577,14 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
def _get_policy_actions(self, data, actor_params, num_actions=10):
batch_size = data.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * num_actions]
- tensordict = data.select(*self.actor_network.in_keys).apply(
-     lambda x: x.repeat_interleave(num_actions, dim=data.ndim - 1),
-     batch_size=batch_size,
+ in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
+
+ def filter_and_repeat(name, x):
+     if name in in_keys:
+         return x.repeat_interleave(num_actions, dim=data.ndim - 1)
+
+ tensordict = data.named_apply(
+     filter_and_repeat, batch_size=batch_size, filter_empty=True
)
with torch.no_grad():
with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
@@ -731,13 +736,18 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor:

batch_size = tensordict_q_random.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]
- tensordict_q_random = tensordict_q_random.select(
-     *self.actor_network.in_keys
- ).apply(
-     lambda x: x.repeat_interleave(
-         self.num_random, dim=tensordict_q_random.ndim - 1
-     ),
+ in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
+
+ def filter_and_repeat(name, x):
+     if name in in_keys:
+         return x.repeat_interleave(
+             self.num_random, dim=tensordict_q_random.ndim - 1
+         )
+
+ tensordict_q_random = tensordict_q_random.named_apply(
+     filter_and_repeat,
batch_size=batch_size,
+ filter_empty=True,
)
tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor)
cql_tensordict = torch.cat(
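
The cql.py hunks replace select(...).apply(lambda ...) with a named_apply plus a name filter, so only the actor's in_keys are repeated and everything else is pruned by filter_empty=True. A hedged, self-contained sketch of that pattern follows; filter_and_repeat mirrors the commit, while the example TensorDict, keys and shapes are made up.

import torch
from tensordict import TensorDict
from tensordict.utils import unravel_key

td = TensorDict(
    {"observation": torch.randn(4, 3), "action": torch.randn(4, 2)},
    batch_size=[4],
)
in_keys = [unravel_key(k) for k in ("observation",)]  # plain strings pass through
num_actions = 10

def filter_and_repeat(name, x):
    # Repeat only the policy input keys; other entries return None and are pruned.
    if name in in_keys:
        return x.repeat_interleave(num_actions, dim=0)

repeated = td.named_apply(
    filter_and_repeat, batch_size=[4 * num_actions], filter_empty=True
)
print(repeated["observation"].shape)  # torch.Size([40, 3])
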
