[Refactor] Refactor 'next_' into nested tensordicts (pytorch#649)
* init

* [Feature] Nested composite spec (pytorch#654)

* [Feature] Move `transform.forward` to `transform.step` (pytorch#660)

* transform step function

* amend

* amend

* amend

* amend

* amend

* fixing key names

* fixing key names

* [Refactor] Transform next remove (pytorch#661)

* Refactor "next_" into ("next", ) (pytorch#673)

* amend

* amend

* bugfix

* init

* strict=False

* strict=False

* minor

* amend

* [BugFix] Use GitHub for flake8 pre-commit hook (pytorch#679)

* amend

* [BugFix] Update to strict select (pytorch#675)

* init

* strict=False

* amend

* amend

* [Feature] Auto-compute stats for ObservationNorm (pytorch#669)

* Add auto-compute stats feature for ObservationNorm

* Fix issue in ObservationNorm init function

* Quick refactor of ObservationNorm init method

* Minor refactoring and adding more tests for ObservationNorm

* lint

* docstring

* docstring

Co-authored-by: vmoens <vincentmoens@gmail.com>

* amend

* amend

* lint

* bf

* bf

* amend

Co-authored-by: Romain Julien <romainjulien@fb.com>
vmoens and romainjln authored Nov 16, 2022
1 parent 354c198 commit 0e3f066
Showing 41 changed files with 893 additions and 720 deletions.
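
Before the per-file diffs: the heart of this refactor is a key-convention change. Step outputs that used to be stored under flat, prefixed keys such as "next_observation" now live in a nested tensordict under the "next" entry and are addressed with tuple keys. A minimal sketch of the two layouts (illustrative, not taken from the diff; it assumes the tensordict API of this era):

import torch
from tensordict.tensordict import TensorDict

# Old convention: flat "next_"-prefixed keys at the root.
old_td = TensorDict(
    {"observation": torch.zeros(4), "next_observation": torch.ones(4)}, []
)

# New convention: a nested tensordict under "next", addressed with tuple keys.
new_td = TensorDict(
    {
        "observation": torch.zeros(4),
        "next": TensorDict({"observation": torch.ones(4)}, []),
    },
    [],
)
assert (new_td["next", "observation"] == old_td["next_observation"]).all()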
examples/dreamer/dreamer.py (4 changes: 1 addition & 3 deletions)
@@ -288,9 +288,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         ):
             sampled_tensordict_save = (
                 sampled_tensordict.select(
-                    "next_pixels",
-                    "next_reco_pixels",
-                    "state",
+                    "next", "state",
                     "belief",
                 )[:4]
                 .detach()
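
One consequence, visible in the hunk above: with nested keys, selecting "next" carries every next-step entry along in one call, where the flat layout needed one "next_*" key per field. A short sketch under the same assumptions (key names and shapes are illustrative):

import torch
from tensordict.tensordict import TensorDict

td = TensorDict(
    {
        "state": torch.zeros(3),
        "belief": torch.zeros(3),
        "next": TensorDict(
            {"pixels": torch.zeros(64, 64, 3), "reco_pixels": torch.zeros(64, 64, 3)},
            [],
        ),
    },
    [],
)
# Selecting "next" keeps the whole sub-tensordict, pixels and reco_pixels included.
saved = td.select("next", "state", "belief")
assert "reco_pixels" in saved["next"].keys()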
examples/dreamer/dreamer_utils.py (12 changes: 6 additions & 6 deletions)
@@ -93,13 +93,13 @@ def make_env_transforms(
     if cfg.grayscale:
         env.append_transform(GrayScale())
     env.append_transform(FlattenObservation())
-    env.append_transform(CatFrames(N=cfg.catframes, in_keys=["next_pixels"]))
+    env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"]))
     if stats is None:
         obs_stats = {"loc": 0.0, "scale": 1.0}
     else:
         obs_stats = stats
     obs_stats["standard_normal"] = True
-    env.append_transform(ObservationNorm(**obs_stats, in_keys=["next_pixels"]))
+    env.append_transform(ObservationNorm(**obs_stats, in_keys=["pixels"]))
     if norm_rewards:
         reward_scaling = 1.0
         reward_loc = 0.0
@@ -122,8 +122,8 @@
     )

     default_dict = {
-        "next_state": NdUnboundedContinuousTensorSpec(cfg.state_dim),
-        "next_belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
+        "state": NdUnboundedContinuousTensorSpec(cfg.state_dim),
+        "belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
     }
     env.append_transform(
         TensorDictPrimer(random=False, default_value=0, **default_dict)
@@ -309,7 +309,7 @@ def call_record(

     true_pixels = recover_pixels(world_model_td["next_pixels"], stats)

-    reco_pixels = recover_pixels(world_model_td["next_reco_pixels"], stats)
+    reco_pixels = recover_pixels(world_model_td["next", "reco_pixels"], stats)
     with autocast(dtype=torch.float16):
         world_model_td = world_model_td.select("state", "belief", "reward")
         world_model_td = model_based_env.rollout(
@@ -319,7 +319,7 @@
         tensordict=world_model_td[:, 0],
     )
     imagine_pxls = recover_pixels(
-        model_based_env.decode_obs(world_model_td)["next_reco_pixels"],
+        model_based_env.decode_obs(world_model_td)["next", "reco_pixels"],
         stats,
     )

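The in_keys change above pairs with the auto-computed stats from pytorch#669, folded into this commit: when loc and scale are not passed to ObservationNorm, they can be estimated from rollouts before training. A hedged sketch of the intended usage; the init_stats call, its num_iter value, and the environment name are assumptions based on the pytorch#669 description, not lines from this diff:

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import ObservationNorm, TransformedEnv

env = TransformedEnv(
    GymEnv("Pendulum-v1"),
    ObservationNorm(in_keys=["observation"], standard_normal=True),
)
# loc/scale were omitted above; estimate them from random steps before training.
env.transform.init_stats(num_iter=1000)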
test/_utils_internal.py (32 changes: 19 additions & 13 deletions)
@@ -13,7 +13,6 @@
 import torch.cuda
 from tensordict.tensordict import TensorDictBase
 from torchrl._utils import seed_generator
-from torchrl.data import CompositeSpec
 from torchrl.envs import EnvBase


@@ -62,21 +61,20 @@ def _test_fake_tensordict(env: EnvBase):


 def _check_dtype(key, value, obs_spec, input_spec):
-    if key.startswith("next_"):
-        return
-    if isinstance(value, TensorDictBase):
+    if isinstance(value, TensorDictBase) and key == "next":
         for _key, _value in value.items():
-            if isinstance(obs_spec, CompositeSpec) and "next_" + key in obs_spec.keys():
-                _check_dtype(_key, _value, obs_spec["next_" + key], input_spec=None)
-            elif isinstance(input_spec, CompositeSpec) and key in input_spec.keys():
-                _check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
-            else:
-                raise KeyError(f"key '{_key}' is unknown.")
+            _check_dtype(_key, _value, obs_spec, input_spec=None)
+    elif isinstance(value, TensorDictBase) and key in obs_spec.keys():
+        for _key, _value in value.items():
+            _check_dtype(_key, _value, obs_spec=obs_spec[key], input_spec=None)
+    elif isinstance(value, TensorDictBase) and key in input_spec.keys():
+        for _key, _value in value.items():
+            _check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
     else:
-        if obs_spec is not None and "next_" + key in obs_spec.keys():
+        if obs_spec is not None and key in obs_spec.keys():
             assert (
-                obs_spec["next_" + key].dtype is value.dtype
-            ), f"{obs_spec['next_' + key].dtype} vs {value.dtype} for {key}"
+                obs_spec[key].dtype is value.dtype
+            ), f"{obs_spec[key].dtype} vs {value.dtype} for {key}"
         elif input_spec is not None and key in input_spec.keys():
             assert (
                 input_spec[key].dtype is value.dtype
@@ -112,3 +110,11 @@ def f_retry(*args, **kwargs):
         return f_retry  # true decorator

     return deco_retry
+
+
+@pytest.fixture
+def dtype_fixture():
+    dtype = torch.get_default_dtype()
+    torch.set_default_dtype(torch.double)
+    yield dtype
+    torch.set_default_dtype(dtype)
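
The dtype_fixture added above is a plain pytest yield fixture: it switches the global default dtype to double for the duration of a test and restores the previous value on teardown, even if the test fails midway. A hypothetical test using it (the test name is illustrative):

import torch

def test_runs_in_double_precision(dtype_fixture):
    # While the fixture is active, freshly created tensors default to float64.
    x = torch.zeros(3)
    assert x.dtype is torch.float64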
test/mocking_classes.py (83 changes: 32 additions & 51 deletions)
@@ -121,7 +121,7 @@ def __new__(
             action_spec = NdUnboundedContinuousTensorSpec((1,))
         if observation_spec is None:
             observation_spec = CompositeSpec(
-                next_observation=NdUnboundedContinuousTensorSpec((1,))
+                observation=NdUnboundedContinuousTensorSpec((1,))
             )
         if reward_spec is None:
             reward_spec = NdUnboundedContinuousTensorSpec((1,))
@@ -152,19 +152,17 @@ def _step(self, tensordict):
         )
         done = self.counter >= self.max_val
         done = torch.tensor([done], dtype=torch.bool, device=self.device)
-        return TensorDict(
-            {"reward": n, "done": done, "next_observation": n.clone()}, []
-        )
+        return TensorDict({"reward": n, "done": done, "observation": n.clone()}, [])

-    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+    def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
         self.max_val = max(self.counter + 100, self.counter * 2)

         n = torch.tensor(
             [self.counter], device=self.device, dtype=torch.get_default_dtype()
         )
         done = self.counter >= self.max_val
         done = torch.tensor([done], dtype=torch.bool, device=self.device)
-        return TensorDict({"done": done, "next_observation": n}, [])
+        return TensorDict({"done": done, "observation": n}, [])

     def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
         return self.step(tensordict)
@@ -192,7 +190,7 @@ def __new__(
             )
         if observation_spec is None:
             observation_spec = CompositeSpec(
-                next_observation=NdUnboundedContinuousTensorSpec((1,))
+                observation=NdUnboundedContinuousTensorSpec((1,))
             )
         if reward_spec is None:
             reward_spec = NdUnboundedContinuousTensorSpec((1,))
@@ -226,7 +224,7 @@ def _step(self, tensordict):
         )

         return TensorDict(
-            {"reward": n, "done": done, "next_observation": n},
+            {"reward": n, "done": done, "observation": n},
             tensordict.batch_size,
             device=self.device,
         )
@@ -247,7 +245,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         done = torch.full(batch_size, done, dtype=torch.bool, device=self.device)

         return TensorDict(
-            {"reward": n, "done": done, "next_observation": n},
+            {"reward": n, "done": done, "observation": n},
             batch_size,
             device=self.device,
         )
@@ -287,10 +285,8 @@ def __new__(
         if observation_spec is None:
             cls.out_key = "observation"
             observation_spec = CompositeSpec(
-                next_observation=NdUnboundedContinuousTensorSpec(
-                    shape=torch.Size([size])
-                ),
-                next_observation_orig=NdUnboundedContinuousTensorSpec(
+                observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])),
+                observation_orig=NdUnboundedContinuousTensorSpec(
                     shape=torch.Size([size])
                 ),
             )
@@ -308,7 +304,7 @@ def __new__(
cls._out_key = "observation_orig"
input_spec = CompositeSpec(
**{
cls._out_key: observation_spec["next_observation"],
cls._out_key: observation_spec["observation"],
"action": action_spec,
}
)
@@ -325,15 +321,13 @@ def _get_in_obs(self, obs):
     def _get_out_obs(self, obs):
         return obs

-    def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase:
         self.counter += 1
         state = torch.zeros(self.size) + self.counter
         if tensordict is None:
             tensordict = TensorDict({}, self.batch_size, device=self.device)
-        tensordict = tensordict.select().set(
-            "next_" + self.out_key, self._get_out_obs(state)
-        )
-        tensordict = tensordict.set("next_" + self._out_key, self._get_out_obs(state))
+        tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state))
+        tensordict = tensordict.set(self._out_key, self._get_out_obs(state))
         tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
         return tensordict
@@ -351,8 +345,8 @@ def _step(
         obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
         tensordict = tensordict.select()  # empty tensordict

-        tensordict.set("next_" + self.out_key, self._get_out_obs(obs))
-        tensordict.set("next_" + self._out_key, self._get_out_obs(obs))
+        tensordict.set(self.out_key, self._get_out_obs(obs))
+        tensordict.set(self._out_key, self._get_out_obs(obs))

         done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
         reward = done.any(-1).unsqueeze(-1)
@@ -379,10 +373,8 @@ def __new__(
         if observation_spec is None:
             cls.out_key = "observation"
             observation_spec = CompositeSpec(
-                next_observation=NdUnboundedContinuousTensorSpec(
-                    shape=torch.Size([size])
-                ),
-                next_observation_orig=NdUnboundedContinuousTensorSpec(
+                observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])),
+                observation_orig=NdUnboundedContinuousTensorSpec(
                     shape=torch.Size([size])
                 ),
             )
@@ -395,7 +387,7 @@ def __new__(
cls._out_key = "observation_orig"
input_spec = CompositeSpec(
**{
cls._out_key: observation_spec["next_observation"],
cls._out_key: observation_spec["observation"],
"action": action_spec,
}
)
@@ -436,8 +428,8 @@ def _step(
         obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
         tensordict = tensordict.select()  # empty tensordict

-        tensordict.set("next_" + self.out_key, self._get_out_obs(obs))
-        tensordict.set("next_" + self._out_key, self._get_out_obs(obs))
+        tensordict.set(self.out_key, self._get_out_obs(obs))
+        tensordict.set(self._out_key, self._get_out_obs(obs))

        done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
        reward = done.any(-1).unsqueeze(-1)
@@ -483,10 +475,8 @@ def __new__(
         if observation_spec is None:
             cls.out_key = "pixels"
             observation_spec = CompositeSpec(
-                next_pixels=NdUnboundedContinuousTensorSpec(
-                    shape=torch.Size([1, 7, 7])
-                ),
-                next_pixels_orig=NdUnboundedContinuousTensorSpec(
+                pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])),
+                pixels_orig=NdUnboundedContinuousTensorSpec(
                     shape=torch.Size([1, 7, 7])
                 ),
             )
@@ -499,7 +489,7 @@ def __new__(
cls._out_key = "pixels_orig"
input_spec = CompositeSpec(
**{
cls._out_key: observation_spec["next_pixels_orig"],
cls._out_key: observation_spec["pixels_orig"],
"action": action_spec,
}
)
@@ -537,10 +527,8 @@ def __new__(
         if observation_spec is None:
             cls.out_key = "pixels"
             observation_spec = CompositeSpec(
-                next_pixels=NdUnboundedContinuousTensorSpec(
-                    shape=torch.Size([7, 7, 3])
-                ),
-                next_pixels_orig=NdUnboundedContinuousTensorSpec(
+                pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
+                pixels_orig=NdUnboundedContinuousTensorSpec(
                     shape=torch.Size([7, 7, 3])
                 ),
             )
@@ -555,7 +543,7 @@ def __new__(
cls._out_key = "pixels_orig"
input_spec = CompositeSpec(
**{
cls._out_key: observation_spec["next_pixels_orig"],
cls._out_key: observation_spec["pixels_orig"],
"action": action_spec,
}
)
@@ -599,10 +587,8 @@ def __new__(
         if observation_spec is None:
             cls.out_key = "pixels"
             observation_spec = CompositeSpec(
-                next_pixels=NdUnboundedContinuousTensorSpec(
-                    shape=torch.Size(pixel_shape)
-                ),
-                next_pixels_orig=NdUnboundedContinuousTensorSpec(
+                pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)),
+                pixels_orig=NdUnboundedContinuousTensorSpec(
                     shape=torch.Size(pixel_shape)
                 ),
             )
@@ -615,7 +601,7 @@
         if input_spec is None:
             cls._out_key = "pixels_orig"
             input_spec = CompositeSpec(
-                **{cls._out_key: observation_spec["next_pixels"], "action": action_spec}
+                **{cls._out_key: observation_spec["pixels"], "action": action_spec}
             )
         return super().__new__(
             *args,
@@ -650,10 +636,8 @@ def __new__(
         if observation_spec is None:
             cls.out_key = "pixels"
             observation_spec = CompositeSpec(
-                next_pixels=NdUnboundedContinuousTensorSpec(
-                    shape=torch.Size([7, 7, 3])
-                ),
-                next_pixels_orig=NdUnboundedContinuousTensorSpec(
+                pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
+                pixels_orig=NdUnboundedContinuousTensorSpec(
                     shape=torch.Size([7, 7, 3])
                 ),
             )
@@ -714,7 +698,7 @@ def __init__(
             batch_size=batch_size,
         )
         self.observation_spec = CompositeSpec(
-            next_hidden_observation=NdUnboundedContinuousTensorSpec((4,))
+            hidden_observation=NdUnboundedContinuousTensorSpec((4,))
         )
         self.input_spec = CompositeSpec(
             hidden_observation=NdUnboundedContinuousTensorSpec((4,)),
@@ -728,9 +712,6 @@ def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
"hidden_observation": self.input_spec["hidden_observation"].rand(
self.batch_size
),
"next_hidden_observation": self.observation_spec[
"next_hidden_observation"
].rand(self.batch_size),
},
batch_size=self.batch_size,
device=self.device,
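Across the mock classes above the pattern is uniform: observation specs drop the "next_" prefix and describe entries under their plain names, with the nesting under "next" handled by the environment's step machinery. A minimal sketch of a new-style spec, using the spec classes this file already imports (shapes are illustrative):

from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec

observation_spec = CompositeSpec(
    observation=NdUnboundedContinuousTensorSpec((1,)),
    observation_orig=NdUnboundedContinuousTensorSpec((1,)),
)
# Keys are plain names now; "next_observation" is gone.
assert "observation" in observation_spec.keys()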