[Feature] Extend TensorDictPrimer default_value options #2071

Merged: 21 commits, merged Apr 18, 2024
Changes from 1 commit
minor fix
albertbou92 committed Apr 14, 2024
commit a1cb9a1a4cdc13f4dd295b550d92dfd7fbd23caa
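
For context, a minimal usage sketch of what this PR enables, assuming the keyword-spec form of the TensorDictPrimer constructor; the key names, spec types and shapes below are illustrative assumptions, not taken from this PR:

import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs.transforms import TensorDictPrimer

# default_value can now be a dict mapping each primer key to either a
# fixed number or a callable evaluated whenever the key is primed
primer = TensorDictPrimer(
    mykey1=UnboundedContinuousTensorSpec(shape=torch.Size([3])),
    mykey2=UnboundedContinuousTensorSpec(shape=torch.Size([1])),
    default_value={
        "mykey1": lambda: torch.ones(3),  # callable default
        "mykey2": 0.0,                    # plain float default
    },
)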
50 changes: 18 additions & 32 deletions torchrl/envs/transforms/transforms.py
@@ -4518,15 +4518,19 @@ def __init__(
)
self.random = random
if isinstance(default_value, dict):
if len(default_value) != len(self.primers) and set(dict.keys()) != set(
self.primers.keys(True, True)
):
primer_keys = {unravel_key(key) for key in self.primers.keys(True, True)}
default_value_keys = {unravel_key(key) for key in default_value.keys()}
vmoens (Contributor) commented:
What about passing through a tensordict to represent this?
For instance, this format will get messy with nested keys:

default_values = {("a", "b"): 1, ("c", "d"): lambda: torch.randn(()), "e": {"f": lambda: torch.zeros(())}}

but if you use tensordict nightly you get a nice representation:

default_values = TensorDict(default_values, []).to_dict()
default_values

which prints

{'a': {'b': tensor(1)},
 'c': {'d': <function __main__.<lambda>()>},
 'e': {'f': <function __main__.<lambda>()>}}

vmoens (Contributor):
Since tensordict now accepts arbitrary values, we could even avoid transforming it back to a dict

vmoens (Contributor):
That way the default value structure will be 100% identical to the CompositeSpec that we use to represent the specs
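
A hedged sketch of that structural parallel; the key names and specs here are invented for illustration and are not part of this PR:

import torch
from tensordict import TensorDict
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

# the primer specs and the default values share the same nested layout
primers = CompositeSpec(a=CompositeSpec(b=UnboundedContinuousTensorSpec(shape=torch.Size([3]))))
defaults = TensorDict({"a": {"b": torch.zeros(3)}}, [])  # a callable could sit at ("a", "b") too, as discussed above
# both structures are addressed by the same nested key ("a", "b")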

albertbou92 (Contributor, Author) commented on Apr 16, 2024:
Ah, very cool option @vmoens!

atm I need to transform it back to a dict, because I get the following behaviour for non-tensor data:

import torch
from tensordict import TensorDict
default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
keys = default_value.keys(True, True)
print(keys)

output:

_TensorDictKeysView([],
    include_nested=True,
    leaves_only=True)

So non-tensor data are not considered leaves

vmoens (Contributor) commented on Apr 16, 2024:

Are you using the latest nightly?
You can always define your own is_leaf for keys:

import tensordict
import torch
from tensordict import TensorDict
default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
print(default_value)
keys = list(default_value.keys(True, True, is_leaf=lambda x: issubclass(x, (tensordict.NonTensorData, torch.Tensor))))
print(keys)

albertbou92 (Contributor, Author):

I was using the latest commit, yes. But this solution works fine :)

if primer_keys != default_value_keys:
raise ValueError(
"If a default_value dictionary is provided, it must match the primers keys."
)
default_value = {
key: default_value[key] for key in self.primers.keys(True, True)
}
else:
default_value = {
key: default_value for key in self.primers.keys(True, True)
}
self.default_value = default_value
self._validated = False
self.reset_key = reset_key
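
The merged check above compares key sets after unravel_key, which normalizes string and nested-tuple keys to one canonical form. A small sketch of that normalization, assuming tensordict's unravel_key utility:

from tensordict.utils import unravel_key

unravel_key("mykey1")              # -> "mykey1"
unravel_key(("nested", "key"))     # -> ("nested", "key")
unravel_key(("nested", ("key",)))  # -> ("nested", "key")
# comparing sets of unravelled keys lets a ("a", "b") tuple and a nested {"a": {"b": ...}} entry match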
@@ -4620,18 +4624,6 @@ def _batch_size(self):
return self.parent.batch_size

def _validate_value_tensor(self, value, spec):
if value.shape != spec.shape:
raise RuntimeError(
f"Value shape ({value.shape}) does not match the spec shape ({spec.shape})."
)
if value.dtype != spec.dtype:
raise RuntimeError(
f"Value dtype ({value.dtype}) does not match the spec dtype ({spec.dtype})."
)
if value.device != spec.device:
raise RuntimeError(
f"Value device ({value.device}) does not match the spec device ({spec.device})."
)
if not spec.is_in(value):
raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).")
return True
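
The hunk above drops the explicit shape/dtype/device comparisons and leaves spec.is_in as the remaining validation. A hedged illustration of that membership check, using a made-up spec rather than anything from this PR:

import torch
from torchrl.data import DiscreteTensorSpec

spec = DiscreteTensorSpec(3, dtype=torch.int64)  # values in {0, 1, 2}
spec.is_in(torch.tensor(1))                      # True: inside the spec domain
spec.is_in(torch.tensor(5))                      # False: outside the spec domain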
@@ -4648,19 +4640,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.random:
value = spec.rand()
else:
if isinstance(self.default_value, dict):
value = self.default_value[key]
else:
value = self.default_value
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full_like(
spec.zero(),
value = torch.full(
spec.shape,
value,
)

tensordict.set(key, value)
if not self._validated:
self._validated = True
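
A small standalone sketch of the two paths the new forward code takes; shape stands in for spec.shape and the key names are illustrative only:

import torch

shape = torch.Size([3])                              # stands in for spec.shape
default_value = {"mykey1": lambda: torch.randn(3), "mykey2": 1.0}

value = default_value["mykey1"]
if callable(value):
    value = value()                                  # callable defaults are re-evaluated on every call
filled = torch.full(shape, default_value["mykey2"])  # non-callable defaults are broadcast to the spec shape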
@@ -4696,17 +4686,14 @@ def _reset(
if self.random:
value = spec.rand(shape)
else:
if isinstance(self.default_value, dict):
value = self.default_value[key]
else:
value = self.default_value
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full_like(
spec.zero(shape),
value = torch.full(
spec.shape,
value,
)
prev_val = tensordict.get(key, 0.0)
@@ -4719,11 +4706,10 @@ def _reset(

def __repr__(self) -> str:
class_name = self.__class__.__name__
default_value = (
self.default_value
if isinstance(self.default_value, float)
else self.default_value.__class__.__name__
)
default_value = {
key: value if isinstance(value, float) else "Callable"
for key, value in self.default_value.items()
}
return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})"

