diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 13f87600eef..218fc3305af 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -29,6 +29,17 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20 # With batched environments +python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ + env.num_envs=1 \ + collector.total_frames=48 \ + collector.frames_per_batch=16 \ + collector.collector_device=cuda:0 \ + optim.device=cuda:0 \ + loss.mini_batch_size=10 \ + loss.ppo_epochs=1 \ + logger.backend= \ + logger.log_interval=4 \ + optim.lr_scheduler=False python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ total_frames=48 \ init_random_frames=10 \ @@ -86,17 +97,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ record_video=True \ record_frames=4 \ buffer_size=120 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ - env.num_envs=1 \ - collector.total_frames=48 \ - collector.frames_per_batch=16 \ - collector.collector_device=cuda:0 \ - optim.device=cuda:0 \ - loss.mini_batch_size=10 \ - loss.ppo_epochs=1 \ - logger.backend= \ - logger.log_interval=4 \ - optim.lr_scheduler=False python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ total_frames=200 \ init_random_frames=10 \ diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index d6d8cc07087..04a3085dda6 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -119,6 +119,9 @@ jobs: - name: Install test dependencies run: | python3 -mpip install numpy pytest --no-cache-dir + - name: Install tensordict + run: | + 
python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git - name: Download built wheels uses: actions/download-artifact@v2 with: @@ -324,6 +327,9 @@ jobs: shell: bash run: | python3 -mpip install numpy pytest --no-cache-dir + - name: Install tensordict + run: | + python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git - name: Download built wheels uses: actions/download-artifact@v2 with: diff --git a/examples/a2c/utils.py b/examples/a2c/utils.py index ff8d0f3ddab..ef094645c8f 100644 --- a/examples/a2c/utils.py +++ b/examples/a2c/utils.py @@ -107,13 +107,7 @@ def make_transformed_env_pixels(base_env, env_cfg): double_to_float_list += [ "reward", ] - double_to_float_list += [ - "action", - ] double_to_float_inv_list += ["action"] # DMControl requires double-precision - double_to_float_list += ["observation_vector"] - else: - double_to_float_list += ["observation_vector"] env.append_transform( DoubleToFloat( in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list @@ -152,9 +146,6 @@ def make_transformed_env_states(base_env, env_cfg): double_to_float_list += [ "reward", ] - double_to_float_list += [ - "action", - ] double_to_float_inv_list += ["action"] # DMControl requires double-precision double_to_float_list += ["observation_vector"] else: diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index eba26e89b86..a87259480a5 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -119,7 +119,6 @@ def make_env_transforms( if env_library is DMControlEnv: double_to_float_list += [ "reward", - "action", ] float_to_double_list += ["action"] # DMControl requires double-precision env.append_transform( diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index 01ee041d8a5..e1cab5bfb0d 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -72,10 +72,13 @@ def main(cfg: "DictConfig"): # noqa: F821 # Main loop r0 = None l0 = None + frame_skip = cfg.env.frame_skip + 
mini_batch_size = cfg.loss.mini_batch_size + ppo_epochs = cfg.loss.ppo_epochs for data in collector: frames_in_batch = data.numel() - collected_frames += frames_in_batch * cfg.env.frame_skip + collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) data_view = data.reshape(-1) @@ -93,8 +96,8 @@ def main(cfg: "DictConfig"): # noqa: F821 "reward_training", episode_rewards.mean().item(), collected_frames ) - for _ in range(cfg.loss.ppo_epochs): - for _ in range(frames_in_batch // cfg.loss.mini_batch_size): + for _ in range(ppo_epochs): + for _ in range(frames_in_batch // mini_batch_size): # Get a data batch batch = data_buffer.sample().to(model_device) diff --git a/examples/ppo/utils.py b/examples/ppo/utils.py index 49a648f480f..101adf874a8 100644 --- a/examples/ppo/utils.py +++ b/examples/ppo/utils.py @@ -108,13 +108,7 @@ def make_transformed_env_pixels(base_env, env_cfg): double_to_float_list += [ "reward", ] - double_to_float_list += [ - "action", - ] double_to_float_inv_list += ["action"] # DMControl requires double-precision - double_to_float_list += ["observation_vector"] - else: - double_to_float_list += ["observation_vector"] env.append_transform( DoubleToFloat( in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list @@ -153,9 +147,6 @@ def make_transformed_env_states(base_env, env_cfg): double_to_float_list += [ "reward", ] - double_to_float_list += [ - "action", - ] double_to_float_inv_list += ["action"] # DMControl requires double-precision double_to_float_list += ["observation_vector"] else: diff --git a/test/_utils_internal.py b/test/_utils_internal.py index a780be803d6..df3d54bfdf7 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -156,21 +156,27 @@ def create_env_fn(): return GymEnv(env_name, frame_skip=frame_skip, device=device) else: - if env_name == "ALE/Pong-v5": + if env_name == PONG_VERSIONED: def create_env_fn(): + base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + in_keys = 
list(base_env.observation_spec.keys(True, True))[:1] return TransformedEnv( - GymEnv(env_name, frame_skip=frame_skip, device=device), - Compose(*[ToTensorImage(), RewardClipping(0, 0.1)]), + base_env, + Compose(*[ToTensorImage(in_keys=in_keys), RewardClipping(0, 0.1)]), ) else: def create_env_fn(): + + base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + in_keys = list(base_env.observation_spec.keys(True, True))[:1] + return TransformedEnv( - GymEnv(env_name, frame_skip=frame_skip, device=device), + base_env, Compose( - ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1), + ObservationNorm(in_keys=in_keys, loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) @@ -179,8 +185,14 @@ def create_env_fn(): env_parallel = ParallelEnv(N, create_env_fn, create_env_kwargs=kwargs) env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs) + for key in env0.observation_spec.keys(True, True): + obs_key = key + break + else: + obs_key = None + if transformed_out: - t_out = get_transform_out(env_name, transformed_in) + t_out = get_transform_out(env_name, transformed_in, obs_key=obs_key) env0 = TransformedEnv( env0, @@ -223,7 +235,7 @@ def _make_multithreaded_env( torch.manual_seed(0) multithreaded_kwargs = ( - {"frame_skip": frame_skip} if env_name == "ALE/Pong-v5" else {} + {"frame_skip": frame_skip} if env_name == PONG_VERSIONED else {} ) env_multithread = MultiThreadedEnv( N, @@ -233,46 +245,53 @@ def _make_multithreaded_env( ) if transformed_out: + for key in env_multithread.observation_spec.keys(True, True): + obs_key = key + break + else: + obs_key = None env_multithread = TransformedEnv( env_multithread, - get_transform_out(env_name, transformed_in=False)(), + get_transform_out(env_name, transformed_in=False, obs_key=obs_key)(), ) return env_multithread -def get_transform_out(env_name, transformed_in): +def get_transform_out(env_name, transformed_in, obs_key=None): - if env_name == "ALE/Pong-v5": + if env_name == PONG_VERSIONED: + if obs_key is 
None: + obs_key = "pixels" def t_out(): return ( - Compose(*[ToTensorImage(), RewardClipping(0, 0.1)]) + Compose(*[ToTensorImage(in_keys=[obs_key]), RewardClipping(0, 0.1)]) if not transformed_in - else Compose(*[ObservationNorm(in_keys=["pixels"], loc=0, scale=1)]) + else Compose(*[ObservationNorm(in_keys=[obs_key], loc=0, scale=1)]) ) - elif env_name == "CheetahRun-v1": + elif env_name == HALFCHEETAH_VERSIONED: + if obs_key is None: + obs_key = ("observation", "velocity") def t_out(): return Compose( - ObservationNorm( - in_keys=[("observation", "velocity")], loc=0.5, scale=1.1 - ), + ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ) else: + if obs_key is None: + obs_key = "observation" def t_out(): return ( Compose( - ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1), + ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ) if not transformed_in - else Compose( - ObservationNorm(in_keys=["observation"], loc=1.0, scale=1.0) - ) + else Compose(ObservationNorm(in_keys=[obs_key], loc=1.0, scale=1.0)) ) return t_out diff --git a/test/test_libs.py b/test/test_libs.py index d69b6979f67..03578aeb82d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -518,7 +518,7 @@ def test_jumanji_consistency(self, envname, batch_size): "Acrobot-v1", CARTPOLE_VERSIONED, ] -ENVPOOL_ATARI_ENVS = [PONG_VERSIONED] +ENVPOOL_ATARI_ENVS = [] # PONG_VERSIONED] ENVPOOL_GYM_ENVS = ENVPOOL_CLASSIC_CONTROL_ENVS + ENVPOOL_ATARI_ENVS ENVPOOL_DM_ENVS = ["CheetahRun-v1"] ENVPOOL_ALL_ENVS = ENVPOOL_GYM_ENVS + ENVPOOL_DM_ENVS @@ -558,6 +558,7 @@ def test_specs(self, env_name, frame_skip, transformed_out, T=10, N=3): def test_env_basic_operation( self, env_name, frame_skip, transformed_out, T=10, N=3 ): + torch.manual_seed(0) env_multithreaded = _make_multithreaded_env( env_name, frame_skip, @@ -737,7 +738,7 @@ def test_multithreaded_env_seed( # Check that results are different if seed is different # Skip Pong, since there 
different actions can lead to the same result - if env_name != "ALE/Pong-v5": + if env_name != PONG_VERSIONED: env.set_seed( seed=seed + 10, ) diff --git a/test/test_rb.py b/test/test_rb.py index 28890caa38d..437ff233114 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -840,14 +840,10 @@ def test_insert_transform(): def test_smoke_replay_buffer_transform(transform): rb = ReplayBuffer(transform=transform(in_keys="observation"), batch_size=1) + # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, []) td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, []) rb.add(td) - if not isinstance(rb._transform[0], (CatFrames,)): - rb.sample() - else: - with pytest.raises(NotImplementedError): - rb.sample() - return + rb.sample() rb._transform = mock.MagicMock() rb._transform.__len__ = lambda *args: 3 @@ -856,7 +852,7 @@ def test_smoke_replay_buffer_transform(transform): transforms = [ - partial(DiscreteActionProjection, num_actions_effective=1, max_actions=1), + partial(DiscreteActionProjection, num_actions_effective=1, max_actions=3), FiniteTensorDictCheck, gSDENoise, PinMemoryTransform, @@ -865,13 +861,15 @@ def test_smoke_replay_buffer_transform(transform): @pytest.mark.parametrize("transform", transforms) def test_smoke_replay_buffer_transform_no_inkeys(transform): - if PinMemoryTransform is PinMemoryTransform and not torch.cuda.is_available(): + if transform == PinMemoryTransform and not torch.cuda.is_available(): raise pytest.skip("No CUDA device detected, skipping PinMemory") rb = ReplayBuffer( collate_fn=lambda x: torch.stack(x, 0), transform=transform(), batch_size=1 ) - td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, []) + action = torch.zeros(3) + action[..., 0] = 1 + td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": action}, []) rb.add(td) rb.sample() diff --git a/test/test_trainer.py b/test/test_trainer.py index a253fcc2f8b..236ff3a3a4b 100644 --- a/test/test_trainer.py +++ 
b/test/test_trainer.py @@ -22,6 +22,7 @@ except ImportError: _has_tb = False +from _utils_internal import PONG_VERSIONED from tensordict import TensorDict from torchrl.data import ( LazyMemmapStorage, @@ -836,7 +837,7 @@ def test_subsampler_state_dict(self): class TestRecorder: def _get_args(self): args = Namespace() - args.env_name = "ALE/Pong-v5" + args.env_name = PONG_VERSIONED args.env_task = "" args.grayscale = True args.env_library = "gym" @@ -894,7 +895,7 @@ def test_recorder(self, N=8): }, ) ea.Reload() - img = ea.Images("tmp_ALE/Pong-v5_video") + img = ea.Images(f"tmp_{PONG_VERSIONED}_video") try: assert len(img) == N // args.record_interval break diff --git a/test/test_transforms.py b/test/test_transforms.py index 2d0ca754713..6d0e2a6648c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -384,50 +384,77 @@ def test_transform_env_clone(self): ).all() assert cloned is not env.transform - def test_transform_model(self): - key1 = "first key" - key2 = "second key" - keys = [key1, key2] - dim = -2 - d = 4 - N = 3 - batch_size = (5,) - extra_d = (3,) * (-dim - 1) - device = "cpu" - key1_tensor = torch.ones(*batch_size, d, *extra_d, device=device) * 2 - key2_tensor = torch.ones(*batch_size, d, *extra_d, device=device) - key_tensors = [key1_tensor, key2_tensor] - td = TensorDict(dict(zip(keys, key_tensors)), batch_size, device=device) - cat_frames = CatFrames(N=N, in_keys=keys, dim=dim) + @pytest.mark.parametrize("dim", [-2, -1]) + @pytest.mark.parametrize("N", [3, 4]) + @pytest.mark.parametrize("padding", ["same", "zeros"]) + def test_transform_model(self, dim, N, padding): + # test equivalence between transforms within an env and within a rb + key1 = "observation" + keys = [key1] + out_keys = ["out_" + key1] + cat_frames = CatFrames( + N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding + ) + cat_frames2 = CatFrames( + N=N, + in_keys=keys + [("next", keys[0])], + out_keys=out_keys + [("next", out_keys[0])], + dim=dim, + 
padding=padding, + ) + envbase = ContinuousActionVecMockEnv() + env = TransformedEnv( + envbase, + Compose( + UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames + ), + ) + torch.manual_seed(10) + env.set_seed(10) + td = env.rollout(10) + torch.manual_seed(10) + envbase.set_seed(10) + tdbase = envbase.rollout(10) + + model = nn.Sequential(cat_frames2, nn.Identity()) + model(tdbase) + assert (td == tdbase).all() + + @pytest.mark.parametrize("dim", [-2, -1]) + @pytest.mark.parametrize("N", [3, 4]) + @pytest.mark.parametrize("padding", ["same", "zeros"]) + def test_transform_rb(self, dim, N, padding): + # test equivalence between transforms within an env and within a rb + key1 = "observation" + keys = [key1] + out_keys = ["out_" + key1] + cat_frames = CatFrames( + N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding + ) + cat_frames2 = CatFrames( + N=N, + in_keys=keys + [("next", keys[0])], + out_keys=out_keys + [("next", out_keys[0])], + dim=dim, + padding=padding, + ) - model = nn.Sequential(cat_frames, nn.Identity()) - with pytest.raises( - NotImplementedError, match="CatFrames cannot be called independently" - ): - model(td) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + Compose( + UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames + ), + ) + td = env.rollout(10) - def test_transform_rb(self): - key1 = "first key" - key2 = "second key" - keys = [key1, key2] - dim = -2 - d = 4 - N = 3 - batch_size = (5,) - extra_d = (3,) * (-dim - 1) - device = "cpu" - key1_tensor = torch.ones(*batch_size, d, *extra_d, device=device) * 2 - key2_tensor = torch.ones(*batch_size, d, *extra_d, device=device) - key_tensors = [key1_tensor, key2_tensor] - td = TensorDict(dict(zip(keys, key_tensors)), batch_size, device=device) - cat_frames = CatFrames(N=N, in_keys=keys, dim=dim) rb = ReplayBuffer(storage=LazyTensorStorage(20)) - rb.append_transform(cat_frames) - rb.extend(td) - with pytest.raises( - NotImplementedError, 
match="CatFrames cannot be called independently" - ): - _ = rb.sample(10) + rb.append_transform(cat_frames2) + rb.add(td.exclude(*out_keys, ("next", out_keys[0]))) + tdsample = rb.sample(1).squeeze(0).exclude("index") + for key in td.keys(True, True): + assert (tdsample[key] == td[key]).all(), key + assert (tdsample["out_" + key1] == td["out_" + key1]).all() + assert (tdsample["next", "out_" + key1] == td["next", "out_" + key1]).all() def test_catframes_transform_observation_spec(self): N = 4 @@ -2823,6 +2850,7 @@ def make_env(): ContinuousActionVecMockEnv(), ObservationNorm( loc=torch.zeros(7), + in_keys=["observation"], scale=1.0, ), ) @@ -2838,6 +2866,7 @@ def make_env(): ContinuousActionVecMockEnv(), ObservationNorm( loc=torch.zeros(7), + in_keys=["observation"], scale=1.0, ), ) @@ -2852,6 +2881,7 @@ def test_trans_serial_env_check( SerialEnv(2, ContinuousActionVecMockEnv), ObservationNorm( loc=torch.zeros(7), + in_keys=["observation"], scale=1.0, ), ) @@ -2864,6 +2894,7 @@ def test_trans_parallel_env_check( ParallelEnv(2, ContinuousActionVecMockEnv), ObservationNorm( loc=torch.zeros(7), + in_keys=["observation"], scale=1.0, ), ) @@ -4341,11 +4372,12 @@ def test_transform_rb(self, out_keys, unsqueeze_dim): def test_transform_inverse(self): env = TransformedEnv( GymEnv(HALFCHEETAH_VERSIONED), + # the order is inverted Compose( - SqueezeTransform(-1, in_keys_inv=["action"], out_keys_inv=["action_t"]), UnsqueezeTransform( -1, in_keys_inv=["action_t"], out_keys_inv=["action"] ), + SqueezeTransform(-1, in_keys_inv=["action"], out_keys_inv=["action_t"]), ), ) td = env.rollout(3) @@ -4452,8 +4484,8 @@ def _circular_transform(self): @property def _inv_circular_transform(self): return Compose( - SqueezeTransform(-1, in_keys_inv=["action"], out_keys_inv=["action_un"]), UnsqueezeTransform(-1, in_keys_inv=["action_un"], out_keys_inv=["action"]), + SqueezeTransform(-1, in_keys_inv=["action"], out_keys_inv=["action_un"]), ) def test_single_trans_env_check(self): @@ -6210,7 
+6242,8 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): dim=-3, ) t2 = FiniteTensorDictCheck() - compose = Compose(t1, t2) + t3 = StepCounter() + compose = Compose(t1, t2, t3) dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) td = TensorDict( { @@ -6221,10 +6254,15 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): device=device, ) td.set("dont touch", dont_touch.clone()) + if not batch: + with pytest.raises( + ValueError, match="The last dimension of the tensordict" + ): + compose(td.clone(False)) with pytest.raises( - NotImplementedError, match="CatFrames cannot be called independently" + NotImplementedError, match="StepCounter cannot be called independently" ): - compose(td.clone(False)) + compose[1:](td.clone(False)) compose._call(td) for key in keys: assert td.get(key).shape[-3] == nchannels * N @@ -6232,7 +6270,8 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): if len(keys) == 1: observation_spec = BoundedTensorSpec(0, 255, (nchannels, 16, 16)) - observation_spec = compose.transform_observation_spec(observation_spec) + # StepCounter does not want non composite specs + observation_spec = compose[:2].transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels * N, 16, 16]) else: observation_spec = CompositeSpec( @@ -6466,7 +6505,7 @@ def test_batch_locked_transformed(device): env = TransformedEnv( MockBatchedLockedEnv(device), Compose( - ObservationNorm(in_keys=[("next", "observation")], loc=0.5, scale=1.1), + ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) @@ -6490,7 +6529,7 @@ def test_batch_unlocked_transformed(device): env = TransformedEnv( MockBatchedUnLockedEnv(device), Compose( - ObservationNorm(in_keys=[("next", "observation")], loc=0.5, scale=1.1), + ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) @@ -6510,7 +6549,7 @@ def 
test_batch_unlocked_with_batch_size_transformed(device): env = TransformedEnv( MockBatchedUnLockedEnv(device, batch_size=torch.Size([2])), Compose( - ObservationNorm(in_keys=[("next", "observation")], loc=0.5, scale=1.1), + ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1), RewardClipping(0, 0.1), ), ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ea0aabd0a12..5eece6b11dd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -14,6 +14,7 @@ import torch from tensordict.nn import dispatch +from tensordict.nn.common import _seq_of_nested_key_check from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import expand_as_right from torch import nn, Tensor @@ -148,6 +149,7 @@ def __init__( if out_keys_inv is None: out_keys_inv = copy(self.in_keys_inv) self.out_keys_inv = out_keys_inv + self._missing_tolerance = False self.__dict__["_container"] = None self.__dict__["_parent"] = None @@ -184,25 +186,30 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: """ for in_key, out_key in zip(self.in_keys, self.out_keys): - is_tuple = isinstance(in_key, tuple) - if in_key in tensordict.keys(include_nested=is_tuple): + if in_key in tensordict.keys(include_nested=True): observation = self._apply_transform(tensordict.get(in_key)) tensordict.set( out_key, observation, ) + elif not self.missing_tolerance: + raise KeyError( + f"{self}: '{in_key}' not found in tensordict {tensordict}" + ) return tensordict @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform.""" for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key in tensordict.keys(isinstance(in_key, tuple)): + if in_key in tensordict.keys(include_nested=True): observation = self._apply_transform(tensordict.get(in_key)) tensordict.set( 
out_key, observation, ) + elif not self.missing_tolerance: + raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") return tensordict def _step(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -234,12 +241,15 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: # # exposed to the user: we'd like that the input keys remain unchanged # # in the originating script if they're being transformed. for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - if in_key in tensordict.keys(include_nested=isinstance(in_key, tuple)): + if in_key in tensordict.keys(include_nested=True): item = self._inv_apply_transform(tensordict.get(in_key)) tensordict.set( out_key, item, ) + elif not self.missing_tolerance: + raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") + return tensordict @dispatch(source="in_keys_inv", dest="out_keys_inv") @@ -388,6 +398,13 @@ def parent(self) -> Optional[EnvBase]: def empty_cache(self): self.__dict__["_parent"] = None + def set_missing_tolerance(self, mode=False): + self._missing_tolerance = mode + + @property + def missing_tolerance(self): + return self._missing_tolerance + class TransformedEnv(EnvBase): """A transformed_in environment. 
@@ -584,7 +601,11 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): tensordict = tensordict.clone(recurse=False) out_tensordict = self.base_env.reset(tensordict=tensordict, **kwargs) out_tensordict = self.transform.reset(out_tensordict) + + mt_mode = self.transform.missing_tolerance + self.set_missing_tolerance(True) out_tensordict = self.transform._call(out_tensordict) + self.set_missing_tolerance(mt_mode) return out_tensordict def state_dict(self, *args, **kwargs) -> OrderedDict: @@ -710,6 +731,10 @@ def __del__(self): # transformed env and that we don't want to close pass + def set_missing_tolerance(self, mode=False): + """Indicates if a KeyError should be raised whenever an in_key is missing from the input tensordict.""" + self.transform.set_missing_tolerance(mode) + class ObservationTransform(Transform): """Abstract class for transformations of the observations.""" @@ -877,6 +902,11 @@ def clone(self): transforms.append(t.clone()) return Compose(*transforms) + def set_missing_tolerance(self, mode=False): + for t in self.transforms: + t.set_missing_tolerance(mode) + super().set_missing_tolerance(mode) + class ToTensorImage(ObservationTransform): """Transforms a numpy-like image (3 x W x H) to a pytorch image (3 x W x H). 
@@ -1052,12 +1082,13 @@ def reset(self, tensordict: TensorDict): def _call(self, tensordict: TensorDict) -> TensorDict: for in_key, out_key in zip(self.in_keys, self.out_keys): - is_tuple = isinstance(in_key, tuple) - if in_key in tensordict.keys(include_nested=is_tuple): + if in_key in tensordict.keys(include_nested=True): target_return = self._apply_transform( tensordict.get(in_key), tensordict.get(out_key) ) tensordict.set(out_key, target_return) + elif not self.missing_tolerance: + raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") return tensordict def _step(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -1421,7 +1452,7 @@ def __init__( out_keys_inv: Optional[Sequence[str]] = None, ): if in_keys is None: - in_keys = IMAGE_KEYS # default + in_keys = [] # default super().__init__( in_keys=in_keys, out_keys=out_keys, @@ -1836,8 +1867,10 @@ class CatFrames(ObservationTransform): feature. Proposed in "Playing Atari with Deep Reinforcement Learning" ( https://arxiv.org/pdf/1312.5602.pdf). - CatFrames is a stateful class and it can be reset to its native state by - calling the `reset()` method. + When used within a transformed environment, + :class:`CatFrames` is a stateful class, and it can be reset to its native state by + calling the :meth:`~.reset` method. This method accepts tensordicts with a + ``"_reset"`` entry that indicates which buffer to reset. Args: N (int): number of observation to concatenate. @@ -1848,6 +1881,63 @@ class CatFrames(ObservationTransform): to be concatenated. Defaults to ["pixels"]. out_keys (list of int, optional): keys pointing to where the output has to be written. Defaults to the value of `in_keys`. + padding (str, optional): the padding method. One of ``"same"`` or ``"zeros"``. + Defaults to ``"same"``, i.e. the first value is used for padding. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv('Pendulum-v1'), + ... Compose( + ... 
UnsqueezeTransform(-1, in_keys=["observation"]), + ... CatFrames(N=4, dim=-1, in_keys=["observation"]), + ... ) + ... ) + >>> print(env.rollout(3)) + + The :class:`CatFrames` transform can also be used offline to reproduce the + effect of the online frame concatenation at a different scale (or for the + purpose of limiting the memory consumption). The following example + gives the complete picture, together with the usage of a :class:`torchrl.data.ReplayBuffer`: + + Examples: + >>> from torchrl.envs import UnsqueezeTransform, CatFrames + >>> from torchrl.collectors import SyncDataCollector, RandomPolicy + >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension + >>> env = TransformedEnv( + ... GymEnv("CartPole-v1", from_pixels=True), + ... Compose( + ... ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]), + ... Resize(in_keys=["pixels_trsf"], w=64, h=64), + ... GrayScale(in_keys=["pixels_trsf"]), + ... UnsqueezeTransform(-4, in_keys=["pixels_trsf"]), + ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]), + ... ) + ... ) + >>> # we design a collector + >>> collector = SyncDataCollector( + ... env, + ... RandomPolicy(env.action_spec), + ... frames_per_batch=10, + ... total_frames=1000, + ... ) + >>> for data in collector: + ... print(data) + ... break + >>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here. + >>> # however, we need to point to both the pixel entry at the root and at the next levels: + >>> t = Compose( + ... ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]), + ... Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), + ... GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), + ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), + ... 
) + >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage + >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) + >>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) + >>> rb.add(data_exclude) + >>> s = rb.sample(1) # the buffer has only one element + >>> # let's check that our sample is the same as the batch collected during inference + >>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all() """ @@ -1856,6 +1946,7 @@ class CatFrames(ObservationTransform): "dim must be > 0 to accomodate for tensordict of " "different batch-sizes (since negative dims are batch invariant)." ) + ACCEPTED_PADDING = {"same", "zeros"} def __init__( self, @@ -1863,6 +1954,7 @@ def __init__( dim: int, in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, + padding="same", ): if in_keys is None: in_keys = IMAGE_KEYS @@ -1871,6 +1963,9 @@ def __init__( if dim > 0: raise ValueError(self._CAT_DIM_ERR) self.dim = dim + if padding not in self.ACCEPTED_PADDING: + raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}") + self.padding = padding for in_key in self.in_keys: buffer_name = f"_cat_buffers_{in_key}" setattr( @@ -1939,9 +2034,15 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: data_in = buffer[_reset] shape = [1 for _ in data_in.shape] shape[self.dim] = self.N - buffer[_reset] = buffer[_reset].copy_( - data[_reset].repeat(shape).clone() - ) + if self.padding == "same": + buffer[_reset] = buffer[_reset].copy_( + data[_reset].repeat(shape).clone() + ) + elif self.padding == "zeros": + buffer[_reset] = 0 + else: + # make linter happy. 
An exception has already been raised + raise NotImplementedError buffer.copy_(torch.roll(buffer, shifts=-d, dims=self.dim)) # add new obs idx = self.dim @@ -1970,14 +2071,76 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec return observation_spec def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - raise NotImplementedError( - "CatFrames cannot be called independently, only its step and reset methods " - "are functional. The reason for this is that it is hard to consider using " - "CatFrames with non-sequential data, such as those collected by a replay buffer " - "or a dataset. If you need CatFrames to work on a batch of sequential data " - "(ie as LSTM would work over a whole sequence of data), file an issue on " - "TorchRL requesting that feature." + # it is assumed that the last dimension of the tensordict is the time dimension + if not tensordict.ndim or ( + tensordict.names[-1] is not None and tensordict.names[-1] != "time" + ): + raise ValueError( + "The last dimension of the tensordict must be marked as 'time'." + ) + # first sort the in_keys with strings and non-strings + in_keys = list( + zip( + (in_key, out_key) + for in_key, out_key in zip(self.in_keys, self.out_keys) + if isinstance(in_key, str) or len(in_key) == 1 + ) ) + in_keys += list( + zip( + (in_key, out_key) + for in_key, out_key in zip(self.in_keys, self.out_keys) + if not isinstance(in_key, str) and not len(in_key) == 1 + ) + ) + for in_key, out_key in zip(self.in_keys, self.out_keys): + # check if we have an obs in "next" that has already been processed. 
+ # If so, we must add an offset + data = tensordict.get(in_key) + if isinstance(in_key, tuple) and in_key[0] == "next": + + # let's get the out_key we have already processed + prev_out_key = dict(zip(self.in_keys, self.out_keys))[in_key[1]] + prev_val = tensordict.get(prev_out_key) + # the first item is located along `dim+1` at the last index of the + # first time index + idx = ( + [slice(None)] * (tensordict.ndim - 1) + + [0] + + [..., -1] + + [slice(None)] * (abs(self.dim) - 1) + ) + first_val = prev_val[tuple(idx)].unsqueeze(tensordict.ndim - 1) + data0 = [first_val] * (self.N - 1) + if self.padding == "zeros": + data0 = [torch.zeros_like(elt) for elt in data0[:-1]] + data0[-1:] + elif self.padding == "same": + pass + else: + # make linter happy. An exception has already been raised + raise NotImplementedError + elif self.padding == "same": + idx = [slice(None)] * (tensordict.ndim - 1) + [0] + data0 = [data[tuple(idx)].unsqueeze(tensordict.ndim - 1)] * (self.N - 1) + elif self.padding == "zeros": + idx = [slice(None)] * (tensordict.ndim - 1) + [0] + data0 = [ + torch.zeros_like(data[tuple(idx)]).unsqueeze(tensordict.ndim - 1) + ] * (self.N - 1) + else: + # make linter happy. An exception has already been raised + raise NotImplementedError + + data = torch.cat(data0 + [data], tensordict.ndim - 1) + + data = data.unfold(tensordict.ndim - 1, self.N, 1) + data = data.permute( + *range(0, data.ndim + self.dim), + -1, + *range(data.ndim + self.dim, data.ndim - 1), + ) + tensordict.set(out_key, data) + return tensordict def __repr__(self) -> str: return ( @@ -3101,8 +3264,8 @@ def __init__( ): """Initialises the transform. 
Filters out non-reward input keys and defines output keys.""" if in_keys is None: - in_keys = [("next", "reward")] - if out_keys is None and in_keys == [("next", "reward")]: + in_keys = ["reward"] + if out_keys is None and in_keys == ["reward"]: out_keys = ["episode_reward"] elif out_keys is None: raise RuntimeError( @@ -3129,7 +3292,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict[out_key] = value.masked_fill( expand_as_right(_reset, value), 0.0 ) - elif in_key == ("next", "reward"): + elif in_key in ("reward", ("reward",)): # Since the episode reward is not in the tensordict, we need to allocate it # with zeros entirely (regardless of the _reset mask) tensordict[out_key] = self.parent.reward_spec.zero() @@ -3144,18 +3307,21 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: f"observation_spec with keys " f"{list(self.parent.observation_spec.keys(True))}. " ) from err - return tensordict def _step(self, tensordict: TensorDictBase) -> TensorDictBase: """Updates the episode rewards with the step rewards.""" # Update episode rewards + next_tensordict = tensordict.get("next") for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key in tensordict.keys(isinstance(in_key, tuple)): - reward = tensordict.get(in_key) - if out_key not in tensordict.keys(): - tensordict.set(("next", out_key), torch.zeros_like(reward)) - tensordict["next", out_key] = tensordict[out_key] + reward + if in_key in next_tensordict.keys(include_nested=True): + reward = next_tensordict.get(in_key) + if out_key not in tensordict.keys(True): + tensordict.set(out_key, torch.zeros_like(reward)) + next_tensordict.set(out_key, tensordict.get(out_key) + reward) + elif not self.missing_tolerance: + raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") + tensordict.set("next", next_tensordict) return tensordict def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: @@ -3180,7 +3346,7 @@ def 
transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec else: # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´ - if set(self.in_keys) != {("next", "reward")}: + if set(self.in_keys) != {"reward"}: raise KeyError( "reward_spec is not a CompositeSpec class, in_keys should only include ´reward´" ) @@ -3334,8 +3500,12 @@ class ExcludeTransform(Transform): def __init__(self, *excluded_keys): super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[]) - if not all(isinstance(item, str) for item in excluded_keys): - raise ValueError("excluded_keys must be a list or tuple of strings.") + try: + _seq_of_nested_key_check(excluded_keys) + except ValueError: + raise ValueError( + "excluded keys must be a list or tuple of strings or tuples of strings." + ) self.excluded_keys = excluded_keys if "reward" in excluded_keys: raise RuntimeError("'reward' cannot be excluded from the keys.") @@ -3376,8 +3546,12 @@ class SelectTransform(Transform): def __init__(self, *selected_keys): super().__init__(in_keys=[], in_keys_inv=[], out_keys=[], out_keys_inv=[]) - if not all(isinstance(item, str) for item in selected_keys): - raise ValueError("excluded_keys must be a list or tuple of strings.") + try: + _seq_of_nested_key_check(selected_keys) + except ValueError: + raise ValueError( + "selected keys must be a list or tuple of strings or tuples of strings." + ) self.selected_keys = selected_keys def _call(self, tensordict: TensorDictBase) -> TensorDictBase: