Skip to content

Commit

Permalink
[Refactor] Remove _run_checks from TensorDict.__init__ (pytorch#2256
Browse files Browse the repository at this point in the history
)
  • Loading branch information
vmoens authored Jun 30, 2024
1 parent 443620f commit 39462f0
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 19 deletions.
16 changes: 14 additions & 2 deletions torchrl/collectors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,20 @@ def nest(x, splits=splits):

layout = as_nested if as_nested is not bool else None

def nest(*x):
return torch.nested.nested_tensor(list(x), layout=layout)
if torch.__version__ < "2.4":
# Layout must be True, there is no other layout available
if layout not in (True,):
raise RuntimeError(
f"layout={layout} is only available for torch>=v2.4"
)

def nest(*x):
return torch.nested.nested_tensor(list(x))

else:

def nest(*x):
return torch.nested.nested_tensor(list(x), layout=layout)

return out_splits[0]._fast_apply(
nest,
Expand Down
9 changes: 4 additions & 5 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3982,7 +3982,7 @@ def encode(
if isinstance(vals, TensorDict):
out = vals.empty() # create and empty tensordict similar to vals
else:
out = TensorDict({}, torch.Size([]), _run_checks=False)
out = TensorDict._new_unsafe({}, torch.Size([]))
for key, item in vals.items():
if item is None:
raise RuntimeError(
Expand Down Expand Up @@ -4047,13 +4047,12 @@ def rand(self, shape=None) -> TensorDictBase:
for key, item in self.items():
if item is not None:
_dict[key] = item.rand(shape)
return TensorDict(
# No need to run checks since we know Composite is compliant with
# TensorDict requirements
return TensorDict._new_unsafe(
_dict,
batch_size=torch.Size([*shape, *self.shape]),
device=self._device,
# No need to run checks since we know Composite is compliant with
# TensorDict requirements
_run_checks=False,
)

def keys(
Expand Down
8 changes: 4 additions & 4 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, val in TensorDict(obs_dict, []).items(True, True)
)
else:
tensordict_out = TensorDict(
obs_dict, batch_size=tensordict.batch_size, _run_checks=False
tensordict_out = TensorDict._new_unsafe(
obs_dict,
batch_size=tensordict.batch_size,
)
if self.device is not None:
tensordict_out = tensordict_out.to(self.device, non_blocking=True)
Expand Down Expand Up @@ -377,10 +378,9 @@ def _reset(

source = self.read_obs(obs)

tensordict_out = TensorDict(
tensordict_out = TensorDict._new_unsafe(
source=source,
batch_size=self.batch_size,
_run_checks=not self.validated,
)
if self.info_dict_reader and info is not None:
for info_dict_reader in self.info_dict_reader:
Expand Down
9 changes: 3 additions & 6 deletions torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
state["reward"] = state.get("reward").view(*self.reward_spec.shape)
state["done"] = state.get("done").view(*self.reward_spec.shape)
done = state["done"].bool()
tensordict_out = TensorDict(
tensordict_out = TensorDict._new_unsafe(
source={
"observation": state.get("obs"),
# "reward": reward,
Expand All @@ -331,7 +331,6 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
},
batch_size=self.batch_size,
device=self.device,
_run_checks=False,
)
return tensordict_out

Expand All @@ -357,7 +356,7 @@ def _step_without_grad(self, tensordict: TensorDictBase):
next_state.set("done", next_state.get("done").view(self.reward_spec.shape))
done = next_state["done"].bool()
reward = next_state["reward"]
tensordict_out = TensorDict(
tensordict_out = TensorDict._new_unsafe(
source={
"observation": next_state.get("obs"),
"reward": reward,
Expand All @@ -367,7 +366,6 @@ def _step_without_grad(self, tensordict: TensorDictBase):
},
batch_size=self.batch_size,
device=self.device,
_run_checks=False,
)
return tensordict_out

Expand Down Expand Up @@ -396,7 +394,7 @@ def _step_with_grad(self, tensordict: TensorDictBase):
next_state.get("pipeline_state").update(dict(zip(qp_keys, next_qp_values)))

# build result
tensordict_out = TensorDict(
tensordict_out = TensorDict._new_unsafe(
source={
"observation": next_obs,
"reward": next_reward,
Expand All @@ -406,7 +404,6 @@ def _step_with_grad(self, tensordict: TensorDictBase):
},
batch_size=self.batch_size,
device=self.device,
_run_checks=False,
)
return tensordict_out

Expand Down
3 changes: 1 addition & 2 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,15 +632,14 @@ def _actor_loss(
@property
@_cache_values
def _cached_target_params_actor_value(self):
return TensorDict(
return TensorDict._new_unsafe(
{
"module": {
"0": self.target_actor_network_params,
"1": self.target_value_network_params,
}
},
torch.Size([]),
_run_checks=False,
)

def _qvalue_v1_loss(
Expand Down

0 comments on commit 39462f0

Please sign in to comment.