diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index a175f05662a..5ba834d35d8 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -88,9 +88,7 @@ conda deactivate conda activate "${env_dir}" echo "installing gymnasium" -pip3 install "gymnasium" -pip3 install ale_py -pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py +pip3 install "gymnasium[atari,accept-rom-license,mujoco]<1.0" mo-gymnasium[mujoco] pip3 install "mujoco" -U # sanity check: remove? diff --git a/.github/unittest/linux_distributed/scripts/setup_env.sh b/.github/unittest/linux_distributed/scripts/setup_env.sh index 501dbe1c914..2a48ab21459 100755 --- a/.github/unittest/linux_distributed/scripts/setup_env.sh +++ b/.github/unittest/linux_distributed/scripts/setup_env.sh @@ -119,7 +119,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi echo "installing gymnasium" - pip install "gymnasium[atari,accept-rom-license]" + pip install "gymnasium[atari,accept-rom-license]<1.0" else - pip install "gymnasium[atari,accept-rom-license]" + pip install "gymnasium[atari,accept-rom-license]<1.0" fi diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh index 1a713ce6870..073ef59ed3f 100755 --- a/.github/unittest/linux_examples/scripts/run_all.sh +++ b/.github/unittest/linux_examples/scripts/run_all.sh @@ -130,7 +130,7 @@ elif [[ $PY_VERSION == *"3.11"* ]]; then pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi -pip install "gymnasium[atari,accept-rom-license]" +pip install "gymnasium[atari,accept-rom-license]<1.0" # ============================================================================================ # # ================================ PyTorch & TorchRL ========================================= # diff --git a/.github/unittest/linux_libs/scripts_envpool/setup_env.sh b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh index bb5c09079ea..aabc153bde3 100755 --- a/.github/unittest/linux_libs/scripts_envpool/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh @@ -82,9 +82,9 @@ if [[ $OSTYPE != 'darwin'* ]]; then fi echo "installing gym" # envpool does not currently work with gymnasium - pip install "gym[atari,accept-rom-license]" + pip install "gym[atari,accept-rom-license]<1.0" else - pip install "gym[atari,accept-rom-license]" + pip install "gym[atari,accept-rom-license]<1.0" fi pip install envpool treevalue diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh index 9622984a421..dc264e07b2d 100755 --- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -140,7 +140,7 @@ conda deactivate conda create --prefix ./cloned_env --clone ./env -y conda activate ./cloned_env -pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U +pip3 install 'gymnasium[accept-rom-license,ale-py,atari]<1.0' mo-gymnasium gymnasium-robotics -U $DIR/run_test.sh diff --git a/.github/unittest/linux_libs/scripts_robohive/environment.yml b/.github/unittest/linux_libs/scripts_robohive/environment.yml index cff88245d1e..4b6e4ef4f0e 100644 --- a/.github/unittest/linux_libs/scripts_robohive/environment.yml +++ b/.github/unittest/linux_libs/scripts_robohive/environment.yml @@ -6,7 +6,7 @@ dependencies: - protobuf - pip: # Initial version is required to install Atari ROMS in setup_env.sh - - gymnasium + - gymnasium<1.0 - hypothesis - future - cloudpickle diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f5fa29ab7ca..e153641e775 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,6 +3,7 @@ name: Generate documentation on: push: branches: + - nightly - main - release/* tags: @@ -21,7 +22,7 @@ jobs: build-docs: strategy: matrix: - python_version: ["3.9"] + python_version: ["3.10"] cuda_arch_version: ["12.1"] uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: @@ -33,7 +34,7 @@ jobs: script: | set -e set -v - apt-get update && apt-get install -y git wget gcc g++ + apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils root_dir="$(pwd)" conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" @@ -45,14 +46,14 @@ jobs: bash ./miniconda.sh -b -f -p "${conda_dir}" eval "$(${conda_dir}/bin/conda shell.bash hook)" printf "* Creating a test environment\n" - conda create --prefix "${env_dir}" -y python=3.8 + conda create --prefix "${env_dir}" -y python=3.10 printf "* Activating\n" conda activate "${env_dir}" - + # 2. upgrade pip, ninja and packaging - # apt-get install python3.9 python3-pip -y + apt-get install python3-pip unzip -y -f python3 -m pip install --upgrade pip - python3 -m pip install setuptools ninja packaging -U + python3 -m pip install setuptools ninja packaging cmake -U # 3. check python version python3 --version diff --git a/docs/requirements.txt b/docs/requirements.txt index 258cff086ed..702a2884421 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -16,9 +16,7 @@ sphinx_design torchvision dm_control mujoco -atari-py -ale-py -gym[classic_control,accept-rom-license] +gym[classic_control,accept-rom-license,ale-py,atari] pygame tqdm ipython diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index afef09aa312..3578cbfd79f 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -996,11 +996,9 @@ Helpers RandomPolicy check_env_specs - exploration_mode #deprecated exploration_type get_available_libraries make_composite_from_td - set_exploration_mode #deprecated set_exploration_type step_mdp terminated_or_truncated diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 2d6a6344970..e1642868228 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -62,13 +62,13 @@ Exploration wrappers and modules To efficiently explore the environment, TorchRL proposes a series of modules that will override the action sampled by the policy by a noisier version. -Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`: -if the exploration is set to ``"random"``, the exploration is active. In all +Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`: +if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all other cases, the action written in the tensordict is simply the network output. .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. - The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on this module. .. currentmodule:: torchrl.modules diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index b05e92619fa..5697d88dc61 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -26,7 +26,7 @@ TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.utils import check_env_specs, set_exploration_mode +from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator from torchrl.objectives import ClipPPOLoss from torchrl.objectives.value import GAE @@ -85,8 +85,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "min": env.action_spec.space.low, - "max": env.action_spec.space.high, + "low": env.action_spec.space.low, + "high": env.action_spec.space.high, }, return_log_prob=True, ) @@ -201,7 +201,7 @@ stepcount_str = f"step count (max): {logs['step_count'][-1]}" logs["lr"].append(optim.param_groups[0]["lr"]) lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}" - with set_exploration_mode("mean"), torch.no_grad(): + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): # execute a rollout with the trained policy eval_rollout = env.rollout(1000, policy_module) logs["eval reward"].append(eval_rollout["next", "reward"].mean().item()) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 409833c75fa..ee2cc6e424c 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -38,7 +38,7 @@ ) from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.libs.gym import set_gym_backend -from torchrl.envs.utils import set_exploration_mode +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( DTActor, OnlineDTActor, @@ -374,13 +374,12 @@ def make_odt_model(cfg): module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, - default_interaction_mode="random", cache_dist=False, return_log_prob=False, ) # init the lazy layers - with torch.no_grad(), set_exploration_mode("random"): + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] actor(td) @@ -428,13 +427,12 @@ def make_dt_model(cfg): module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, - default_interaction_mode="random", cache_dist=False, return_log_prob=False, ) # init the lazy layers - with torch.no_grad(), set_exploration_mode("random"): + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] actor(td) diff --git a/sota-implementations/redq/config.yaml b/sota-implementations/redq/config.yaml index e60191c0f93..818f3386fda 100644 --- a/sota-implementations/redq/config.yaml +++ b/sota-implementations/redq/config.yaml @@ -36,7 +36,6 @@ collector: multi_step: 1 n_steps_return: 3 max_frames_per_traj: -1 - exploration_mode: random logger: backend: wandb diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index dd922372cbb..8312d359366 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -1021,7 +1021,6 @@ def make_collector_offpolicy( "init_random_frames": cfg.collector.init_random_frames, "split_trajs": True, # trajectories must be separated if multi-step is used - "exploration_type": ExplorationType.from_str(cfg.collector.exploration_mode), } collector = collector_helper(**collector_helper_kwargs) diff --git a/test/test_actors.py b/test/test_actors.py index 439094e922a..b81f322b708 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -54,8 +54,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.low, - "max": action_spec.space.high, + "low": action_spec.space.low, + "high": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -77,8 +77,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.low, - "max": action_spec.space.high, + "low": action_spec.space.low, + "high": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, diff --git a/test/test_distributions.py b/test/test_distributions.py index 53bfda343a2..8a5b651531e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -190,8 +190,8 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device): d = TruncatedNormal( *vecs, upscale=upscale, - min=min, - max=max, + low=min, + high=max, ) assert d.device == device for _ in range(100): @@ -218,7 +218,7 @@ def test_truncnormal_against_scipy(self): high = 2 low = -1 log_pi_x = TruncatedNormal( - mu, sigma, min=low, max=high, tanh_loc=False + mu, sigma, low=low, high=high, tanh_loc=False ).log_prob(x) pi_x = torch.exp(log_pi_x) log_pi_x.backward(torch.ones_like(log_pi_x)) @@ -264,8 +264,8 @@ def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device): d = TruncatedNormal( *vecs, upscale=upscale, - min=min, - max=max, + low=min, + high=max, ) assert d.mode is not None assert d.entropy() is not None diff --git a/test/test_libs.py b/test/test_libs.py index 87c69bf000c..6fc2979607d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3065,7 +3065,7 @@ def test_atari_preproc(self, dataset_id, tmpdir): t = Compose( UnsqueezeTransform( - unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")] + dim=-3, in_keys=["observation", ("next", "observation")] ), Resize(32, in_keys=["observation", ("next", "observation")]), RenameTransform(in_keys=["action"], out_keys=["other_action"]), diff --git a/test/test_rb.py b/test/test_rb.py index 34b34b5b486..24b33f89795 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1776,10 +1776,8 @@ def test_insert_transform(self): not _has_tv, reason="needs torchvision dependency" ), ), - pytest.param( - partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" - ), - pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), + pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), GrayScale, pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"), pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), diff --git a/test/test_transforms.py b/test/test_transforms.py index 589c32809cc..55b9a73e054 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -5627,7 +5627,7 @@ def test_transform_model(self): class TestUnsqueezeTransform(TransformBase): - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5635,14 +5635,10 @@ class TestUnsqueezeTransform(TransformBase): "keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_default_devices()) - def test_transform_no_env( - self, keys, size, nchannels, batch, device, unsqueeze_dim - ): + def test_transform_no_env(self, keys, size, nchannels, batch, device, dim): torch.manual_seed(0) dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device) - unsqueeze = UnsqueezeTransform( - unsqueeze_dim, in_keys=keys, allow_positive_dim=True - ) + unsqueeze = UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True) td = TensorDict( { key: torch.randn(*batch, *size, nchannels, 16, 16, device=device) @@ -5652,16 +5648,16 @@ def test_transform_no_env( device=device, ) td.set("dont touch", dont_touch.clone()) - if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch): + if dim >= 0 and dim < len(batch): with pytest.raises(RuntimeError, match="batch dimension mismatch"): unsqueeze(td) return unsqueeze(td) expected_size = [*batch, *size, nchannels, 16, 16] - if unsqueeze_dim < 0: - expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(unsqueeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys: @@ -5669,7 +5665,7 @@ def test_transform_no_env( batch, size, nchannels, - unsqueeze_dim, + dim, ) assert (td.get("dont touch") == dont_touch).all() @@ -5688,7 +5684,7 @@ def test_transform_no_env( for key in keys: assert observation_spec[key].shape == expected_size - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5704,13 +5700,11 @@ def test_transform_no_env( [("next", "observation_pixels")], ], ) - def test_unsqueeze_inv( - self, keys, keys_inv, size, nchannels, batch, device, unsqueeze_dim - ): + def test_unsqueeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim): torch.manual_seed(0) keys_total = set(keys + keys_inv) unsqueeze = UnsqueezeTransform( - unsqueeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -5726,8 +5720,8 @@ def test_unsqueeze_inv( for key in keys_total.difference(keys_inv): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[unsqueeze_dim] == 1: - del expected_size[unsqueeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys_inv: assert td_modif.get(key).shape == torch.Size(expected_size) # for key in keys_inv: @@ -5787,7 +5781,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): except RuntimeError: pass - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5795,13 +5789,11 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_default_devices()) - def test_transform_compose( - self, keys, size, nchannels, batch, device, unsqueeze_dim - ): + def test_transform_compose(self, keys, size, nchannels, batch, device, dim): torch.manual_seed(0) dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device) unsqueeze = Compose( - UnsqueezeTransform(unsqueeze_dim, in_keys=keys, allow_positive_dim=True) + UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True) ) td = TensorDict( { @@ -5812,16 +5804,16 @@ def test_transform_compose( device=device, ) td.set("dont touch", dont_touch.clone()) - if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch): + if dim >= 0 and dim < len(batch): with pytest.raises(RuntimeError, match="batch dimension mismatch"): unsqueeze(td) return unsqueeze(td) expected_size = [*batch, *size, nchannels, 16, 16] - if unsqueeze_dim < 0: - expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(unsqueeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys: @@ -5829,7 +5821,7 @@ def test_transform_compose( batch, size, nchannels, - unsqueeze_dim, + dim, ) assert (td.get("dont touch") == dont_touch).all() @@ -5865,10 +5857,10 @@ def test_transform_env(self, out_keys): check_env_specs(env) @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) - @pytest.mark.parametrize("unsqueeze_dim", [-1, 1]) - def test_transform_model(self, out_keys, unsqueeze_dim): + @pytest.mark.parametrize("dim", [-1, 1]) + def test_transform_model(self, out_keys, dim): t = UnsqueezeTransform( - unsqueeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -5878,21 +5870,21 @@ def test_transform_model(self, out_keys, unsqueeze_dim): ) t(td) expected_shape = [3, 4] - if unsqueeze_dim >= 0: - expected_shape.insert(unsqueeze_dim, 1) + if dim >= 0: + expected_shape.insert(dim, 1) else: - expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1) + expected_shape.insert(len(expected_shape) + dim + 1, 1) if out_keys is None: assert td["observation"].shape == torch.Size(expected_shape) else: assert td[out_keys[0]].shape == torch.Size(expected_shape) @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) - @pytest.mark.parametrize("unsqueeze_dim", [-1, 1]) + @pytest.mark.parametrize("dim", [-1, 1]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) - def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim): + def test_transform_rb(self, rbclass, out_keys, dim): t = UnsqueezeTransform( - unsqueeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -5905,10 +5897,10 @@ def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim): rb.extend(td) td = rb.sample(2) expected_shape = [2, 3, 4] - if unsqueeze_dim >= 0: - expected_shape.insert(unsqueeze_dim, 1) + if dim >= 0: + expected_shape.insert(dim, 1) else: - expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1) + expected_shape.insert(len(expected_shape) + dim + 1, 1) if out_keys is None: assert td["observation"].shape == torch.Size(expected_shape) else: @@ -5932,7 +5924,7 @@ def test_transform_inverse(self): class TestSqueezeTransform(TransformBase): - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5953,12 +5945,12 @@ class TestSqueezeTransform(TransformBase): ], ) def test_transform_no_env( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim + self, keys, keys_inv, size, nchannels, batch, device, dim ): torch.manual_seed(0) keys_total = set(keys + keys_inv) squeeze = SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -5973,12 +5965,12 @@ def test_transform_no_env( for key in keys_total.difference(keys): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[squeeze_dim] == 1: - del expected_size[squeeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys: assert td.get(key).shape == torch.Size(expected_size) - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5998,15 +5990,13 @@ def test_transform_no_env( [("next", "observation_pixels")], ], ) - def test_squeeze_inv( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim - ): + def test_squeeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim): torch.manual_seed(0) - if squeeze_dim >= 0: - squeeze_dim = squeeze_dim + len(batch) + if dim >= 0: + dim = dim + len(batch) keys_total = set(keys + keys_inv) squeeze = SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -6021,14 +6011,14 @@ def test_squeeze_inv( for key in keys_total.difference(keys_inv): assert td.get(key).shape == torch.Size(expected_size) - if squeeze_dim < 0: - expected_size.insert(len(expected_size) + squeeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(squeeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys_inv: - assert td.get(key).shape == torch.Size(expected_size), squeeze_dim + assert td.get(key).shape == torch.Size(expected_size), dim @property def _circular_transform(self): @@ -6101,7 +6091,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): except RuntimeError: pass - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -6114,13 +6104,13 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): "keys_inv", [[], ["action", "some_other_key"], [("next", "observation_pixels")]] ) def test_transform_compose( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim + self, keys, keys_inv, size, nchannels, batch, device, dim ): torch.manual_seed(0) keys_total = set(keys + keys_inv) squeeze = Compose( SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) ) td = TensorDict( @@ -6136,8 +6126,8 @@ def test_transform_compose( for key in keys_total.difference(keys): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[squeeze_dim] == 1: - del expected_size[squeeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys: assert td.get(key).shape == torch.Size(expected_size) @@ -6154,9 +6144,9 @@ def test_transform_env(self, keys_inv): @pytest.mark.parametrize("out_keys", [None, ["obs_sq"]]) def test_transform_model(self, out_keys): - squeeze_dim = 1 + dim = 1 t = SqueezeTransform( - squeeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -6175,9 +6165,9 @@ def test_transform_model(self, out_keys): @pytest.mark.parametrize("out_keys", [None, ["obs_sq"]]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, out_keys, rbclass): - squeeze_dim = -2 + dim = -2 t = SqueezeTransform( - squeeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -8925,10 +8915,8 @@ def test_batch_unlocked_with_batch_size_transformed(device): pytest.param( partial(FlattenObservation, first_dim=-3, last_dim=-3), id="FlattenObservation" ), - pytest.param( - partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" - ), - pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), + pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), GrayScale, pytest.param( partial(ObservationNorm, in_keys=["observation"]), id="ObservationNorm" diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 9ccd2e2aa80..3acc4bd8300 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -58,7 +58,6 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( _aggregate_end_of_traj, - _convert_exploration_type, _make_compatible_policy, ExplorationType, RandomPolicy, @@ -489,7 +488,6 @@ def __init__( postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode: str | None = None, return_same_td: bool = False, reset_when_done: bool = True, interruptor=None, @@ -502,9 +500,6 @@ def __init__( from torchrl.envs.batched_envs import BatchedEnvBase self.closed = True - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if create_env_kwargs is None: create_env_kwargs = {} if not isinstance(create_env_fn, EnvBase): @@ -1472,7 +1467,7 @@ class _MultiDataCollector(DataCollectorBase): A ``cat_results`` value of ``-1`` will always concatenate results along the time dimension. This should be preferred over the default. Intermediate values are also accepted. - Defaults to ``0``. + Defaults to ``"stack"``. .. note:: From v0.5, this argument will default to ``"stack"`` for a better interoperability with the rest of the library. @@ -1516,7 +1511,6 @@ def __init__( postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, reset_when_done: bool = True, update_at_each_batch: bool = False, preemptive_threshold: float = None, @@ -1529,9 +1523,6 @@ def __init__( replay_buffer_chunk: bool = True, trust_policy: bool = None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) self.closed = True self.num_workers = len(create_env_fn) @@ -1675,10 +1666,12 @@ def __init__( self.cat_results = cat_results def _check_replay_buffer_init(self): + if self.replay_buffer is None: + return is_init = getattr(self.replay_buffer._storage, "initialized", True) if not is_init: if isinstance(self.create_env_fn[0], EnvCreator): - fake_td = self.create_env_fn[0].tensordict + fake_td = self.create_env_fn[0].meta_data.tensordict elif isinstance(self.create_env_fn[0], EnvBase): fake_td = self.create_env_fn[0].fake_tensordict() else: @@ -2173,19 +2166,6 @@ def iterator(self) -> Iterator[TensorDictBase]: cat_results = self.cat_results if cat_results is None: cat_results = "stack" - warnings.warn( - f"`cat_results` was not specified in the constructor of {type(self).__name__}. " - f"For MultiSyncDataCollector, `cat_results` indicates how the data should " - f"be packed: the preferred option and current default is `cat_results='stack'` " - f"which provides the best interoperability across torchrl components. " - f"Other accepted values are `cat_results=0` (previous behavior) and " - f"`cat_results=-1` (cat along time dimension). Among these two, the latter " - f"should be preferred for consistency across environment configurations. " - f"Currently, the default value is `'stack'`." - f"From v0.6 onward, this warning will be removed. " - f"To suppress this warning, set `cat_results` to the desired value.", - category=DeprecationWarning, - ) self.buffers = {} dones = [False for _ in range(self.num_workers)] @@ -2770,7 +2750,6 @@ def __init__( postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, reset_when_done: bool = True, update_at_each_batch: bool = False, preemptive_threshold: float = None, @@ -2795,7 +2774,6 @@ def __init__( env_device=env_device, storing_device=storing_device, exploration_type=exploration_type, - exploration_mode=exploration_mode, reset_when_done=reset_when_done, update_at_each_batch=update_at_each_batch, preemptive_threshold=preemptive_threshold, diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 65e6987b4aa..729b8a48171 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -34,7 +34,6 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -419,7 +418,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class: Type = SyncDataCollector, collector_kwargs: dict = None, num_workers_per_collector: int = 1, @@ -431,9 +429,6 @@ def __init__( launcher: str = "submitit", tcp_port: int = None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 816364cf84a..73247df4b0c 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -24,7 +24,6 @@ ) from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.data.utils import CloudpickleWrapper -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -275,7 +274,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -288,9 +286,6 @@ def __init__( visible_devices=None, tensorpipe_options=None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 744bce1446f..481fb70cc31 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -34,7 +34,6 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -285,7 +284,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -296,9 +294,6 @@ def __init__( launcher="submitit", tcp_port=None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index c8b7fd4aafb..d0d92251b69 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -102,12 +102,10 @@ from .utils import ( check_env_specs, check_marl_grouping, - exploration_mode, exploration_type, ExplorationType, make_composite_from_td, MarlGroupMapType, - set_exploration_mode, set_exploration_type, step_mdp, ) diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index d4505a4d240..bdc8af1eefa 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -315,7 +315,7 @@ def _init(self): unsqueeze = UnsqueezeTransform( in_keys=in_keys, out_keys=in_keys, - unsqueeze_dim=-4, + dim=-4, ) transforms.append(unsqueeze) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index b41a290d3f7..6228b0f22b7 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -142,8 +142,8 @@ def _make_detached_param(x): self.sample_log_prob_key = "sample_log_prob" def find_sample_log_prob(module): - if hasattr(module, "SAMPLE_LOG_PROB_KEY"): - self.sample_log_prob_key = module.SAMPLE_LOG_PROB_KEY + if hasattr(module, "log_prob_key"): + self.sample_log_prob_key = module.log_prob_key self.functional_actor.apply(find_sample_log_prob) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 216def16c42..f96a9407e97 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1348,11 +1348,11 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: observation_spec = self._pixel_observation(observation_spec) - unsqueeze_dim = [1] if self._should_unsqueeze(observation_spec) else [] + dim = [1] if self._should_unsqueeze(observation_spec) else [] if not self.shape_tolerant or observation_spec.shape[-1] == 3: observation_spec.shape = torch.Size( [ - *unsqueeze_dim, + *dim, *observation_spec.shape[:-3], observation_spec.shape[-1], observation_spec.shape[-3], @@ -2137,41 +2137,42 @@ class UnsqueezeTransform(Transform): """Inserts a dimension of size one at the specified position. Args: - unsqueeze_dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim + dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim must be turned on). + + Keyword Args: allow_positive_dim (bool, optional): if ``True``, positive dimensions are accepted. - :obj:`UnsqueezeTransform` will map these to the n^th feature dimension + `UnsqueezeTransform`` will map these to the n^th feature dimension (ie n^th dimension after batch size of parent env) of the input tensor, - independently from the tensordict batch size (ie positive dims may be + independently of the tensordict batch size (ie positive dims may be dangerous in contexts where tensordict of different batch dimension are passed). Defaults to False, ie. non-negative dimensions are not permitted. + in_keys (list of NestedKeys): input entries (read). + out_keys (list of NestedKeys): input entries (write). Defaults to ``in_keys`` if + not provided. + in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls. + out_keys_inv (list of NestedKeys): input entries (write) during :meth:`~.inv` calls. + Defaults to ``in_keys_in`` if not provided. """ invertible = True @classmethod def __new__(cls, *args, **kwargs): - cls._unsqueeze_dim = None + cls._dim = None return super().__new__(cls) def __init__( self, dim: int = None, + *, allow_positive_dim: bool = False, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, - **kwargs, ): - if "unsqueeze_dim" in kwargs: - warnings.warn( - "The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead." - ) - dim = kwargs["unsqueeze_dim"] - elif dim is None: - raise TypeError("dim must be provided.") if in_keys is None: in_keys = [] # default if out_keys is None: @@ -2191,22 +2192,26 @@ def __init__( raise RuntimeError( "dim should be smaller than 0 to accommodate for " "envs of different batch_sizes. Turn allow_positive_dim to accommodate " - "for positive unsqueeze_dim." + "for positive dim." ) self._dim = dim @property def unsqueeze_dim(self): + return self.dim + + @property + def dim(self): if self._dim >= 0 and self.parent is not None: return len(self.parent.batch_size) + self._dim return self._dim def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: - observation = observation.unsqueeze(self.unsqueeze_dim) + observation = observation.unsqueeze(self.dim) return observation def _inv_apply_transform(self, observation: torch.Tensor) -> torch.Tensor: - observation = observation.squeeze(self.unsqueeze_dim) + observation = observation.squeeze(self.dim) return observation def _transform_spec(self, spec: TensorSpec): @@ -2253,7 +2258,7 @@ def _reset( def __repr__(self) -> str: s = ( - f"{self.__class__.__name__}(unsqueeze_dim={self.unsqueeze_dim}, in_keys={self.in_keys}, out_keys={self.out_keys}," + f"{self.__class__.__name__}(dim={self.dim}, in_keys={self.in_keys}, out_keys={self.out_keys}," f" in_keys_inv={self.in_keys_inv}, out_keys_inv={self.out_keys_inv})" ) return s @@ -2263,14 +2268,14 @@ class SqueezeTransform(UnsqueezeTransform): """Removes a dimension of size one at the specified position. Args: - squeeze_dim (int): dimension to squeeze. + dim (int): dimension to squeeze. """ invertible = True def __init__( self, - squeeze_dim: int, + dim: int | None = None, *args, in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, @@ -2278,8 +2283,19 @@ def __init__( out_keys_inv: Optional[Sequence[str]] = None, **kwargs, ): + if dim is None: + if "squeeze_dim" in kwargs: + warnings.warn( + f"squeeze_dim will be deprecated in favor of dim arg in {type(self).__name__}." + ) + dim = kwargs.pop("squeeze_dim") + else: + raise TypeError( + f"dim must be passed to {type(self).__name__} constructor." + ) + super().__init__( - squeeze_dim, + dim, *args, in_keys=in_keys, out_keys=out_keys, @@ -2290,7 +2306,7 @@ def __init__( @property def squeeze_dim(self): - return super().unsqueeze_dim + return super().dim _apply_transform = UnsqueezeTransform._inv_apply_transform _inv_apply_transform = UnsqueezeTransform._apply_transform diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 556eacf579c..a28e490c4f1 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -285,7 +285,7 @@ def _init(self): unsqueeze = UnsqueezeTransform( in_keys=in_keys, out_keys=in_keys, - unsqueeze_dim=-4, + dim=-4, ) transforms.append(unsqueeze) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 9701e96ef62..f1724326d2a 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -32,13 +32,8 @@ from tensordict.base import _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa - # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! - # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. - # See more details: https://github.com/pytorch/rl/issues/1016 - interaction_mode as exploration_mode, interaction_type as exploration_type, InteractionType as ExplorationType, - set_interaction_mode as set_exploration_mode, set_interaction_type as set_exploration_type, ) from tensordict.utils import is_non_tensor, NestedKey @@ -55,9 +50,7 @@ from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper __all__ = [ - "exploration_mode", "exploration_type", - "set_exploration_mode", "set_exploration_type", "ExplorationType", "check_env_specs", @@ -79,12 +72,6 @@ ) -def _convert_exploration_type(*, exploration_mode, exploration_type): - if exploration_mode is not None: - return ExplorationType.from_str(exploration_mode) - return exploration_type - - class _classproperty(property): def __get__(self, cls, owner): return classmethod(self.fget).__get__(None, owner)() diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 33dfe6aa1df..debb836d6fa 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -212,13 +212,6 @@ class TruncatedNormal(D.Independent): "scale": constraints.greater_than(1e-6), } - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - DeprecationWarning, - ) - def __init__( self, loc: torch.Tensor, @@ -227,14 +220,7 @@ def __init__( low: Union[torch.Tensor, float] = -1.0, high: Union[torch.Tensor, float] = 1.0, tanh_loc: bool = False, - **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") err_msg = "TanhNormal high values must be strictly greater than low values" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): @@ -392,13 +378,6 @@ class TanhNormal(FasterTransformedDistribution): num_params = 2 - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - DeprecationWarning, - ) - def __init__( self, loc: torch.Tensor, @@ -411,13 +390,6 @@ def __init__( safe_tanh: bool = True, **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") - if not isinstance(loc, torch.Tensor): loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) if not isinstance(scale, torch.Tensor): @@ -530,15 +502,10 @@ def root_dist(self): @property def mode(self): - warnings.warn( - "This computation of the mode is based on an inaccurate estimation of the mode " - "given the base_dist mode. " - "To use a more stable implementation of the mode, use dist.get_mode() method instead. " - "To silence this warning, consider using the DETERMINISTIC exploration_type." - "This implementation will be removed in v0.6.", - category=DeprecationWarning, + raise RuntimeError( + f"The distribution {type(self).__name__} has not analytical mode. " + f"Use ExplorationMode.DETERMINISTIC to get a deterministic sample from it." ) - return self.deterministic_sample @property def deterministic_sample(self): @@ -702,13 +669,6 @@ class TanhDelta(FasterTransformedDistribution): "loc": constraints.real, } - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - category=DeprecationWarning, - ) - def __init__( self, param: torch.Tensor, @@ -717,15 +677,7 @@ def __init__( event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, - **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") - minmax_msg = "high value has been found to be equal or less than low value" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): if not (high > low).all(): @@ -767,7 +719,6 @@ def __init__( rtol=rtol, batch_shape=batch_shape, event_shape=event_shape, - **kwargs, ) super().__init__(base, t) diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 720934a6809..d69a85fd685 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -553,7 +553,7 @@ class ConsistentDropout(_DropoutNd): .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. - The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on this module. Args: diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 4b38b19c699..483d9b90eea 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -104,7 +104,6 @@ def __init__( out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None, spec: Optional[TensorSpec] = None, safe: bool = False, - default_interaction_mode: str = None, default_interaction_type: str = InteractionType.DETERMINISTIC, distribution_class: Type = Delta, distribution_kwargs: Optional[dict] = None, @@ -117,7 +116,6 @@ def __init__( in_keys=in_keys, out_keys=out_keys, default_interaction_type=default_interaction_type, - default_interaction_mode=default_interaction_mode, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=return_log_prob, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index cd4e47ef336..a1c70612484 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -97,8 +97,8 @@ class LossModule(TensorDictModuleBase, metaclass=_LossMeta): >>> loss.set_keys(action="action2") .. note:: When a policy that is wrapped or augmented with an exploration module is passed - to the loss, we want to deactivate the exploration through ``set_exploration_mode()`` where - ```` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or + to the loss, we want to deactivate the exploration through ``set_exploration_type()`` where + ```` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or ``ExplorationType.DETERMINISTIC``. The default value is ``DETERMINISTIC`` and it is set through the ``deterministic_sampling_mode`` loss attribute. If another exploration mode is required (or if ``DETERMINISTIC`` is not available), one can diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index b7db2e8242e..0be7d9cb437 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -16,13 +16,12 @@ from tensordict import TensorDictBase from tensordict.nn import ( dispatch, - is_functional, set_skip_existing, TensorDictModule, TensorDictModuleBase, ) from tensordict.utils import NestedKey -from torch import nn, Tensor +from torch import Tensor from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp @@ -412,18 +411,13 @@ def value_estimate( @property def is_functional(self): - if isinstance(self.value_network, nn.Module): - return is_functional(self.value_network) - elif self.value_network is None: - return None - else: - raise RuntimeError("Cannot determine if value network is functional.") + # legacy + return False @property def is_stateless(self): - if not self.is_functional: - return False - return self.value_network._is_stateless + # legacy + return False def _next_value(self, tensordict, target_params, kwargs): step_td = step_mdp(tensordict, keep_other=False) diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index b192d115a54..efdde1a1c63 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -19,7 +19,6 @@ from torchrl.data.postprocs import MultiStep from torchrl.envs.batched_envs import ParallelEnv from torchrl.envs.common import EnvBase -from torchrl.envs.utils import ExplorationType def sync_async_collector( @@ -304,7 +303,7 @@ def make_collector_offpolicy( "init_random_frames": cfg.init_random_frames, "split_trajs": True, # trajectories must be separated if multi-step is used - "exploration_type": ExplorationType.from_str(cfg.exploration_mode), + "exploration_type": cfg.exploration_type, } collector = collector_helper(**collector_helper_kwargs) @@ -358,7 +357,7 @@ def make_collector_onpolicy( "storing_device": cfg.collector_device, "split_trajs": True, # trajectories must be separated in online settings - "exploration_mode": cfg.exploration_mode, + "exploration_type": cfg.exploration_type, } collector = collector_helper(**collector_helper_kwargs) @@ -398,7 +397,7 @@ class OnPolicyCollectorConfig: # for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created seed: int = 42 # seed used for the environment, pytorch and numpy. - exploration_mode: str = "random" + exploration_type: str = "random" # exploration mode of the data collector. async_collection: bool = False # whether data collection should be done asynchrously. Asynchrounous data collection means diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 869f0f980b3..13721b715e3 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval): record_frames=1000, policy_exploration=actor_model_explore, environment=environment, - exploration_type=ExplorationType.MEAN, + exploration_type=ExplorationType.DETERMINISTIC, record_interval=record_interval, ) return recorder_obj diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index d1b094161f1..25e72dc40f4 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -651,7 +651,7 @@ # number of steps (1000, which is our ``env`` horizon). # The ``rollout`` method of the ``env`` can take a policy as argument: # it will then execute this policy at each step. - with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): # execute a rollout with the trained policy eval_rollout = env.rollout(1000, policy_module) logs["eval reward"].append(eval_rollout["next", "reward"].mean().item()) diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 437cae26c42..4e8a1b30930 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -172,7 +172,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type -with set_exploration_type(ExplorationType.MEAN): +with set_exploration_type(ExplorationType.DETERMINISTIC): # takes the mean as action rollout = env.rollout(max_steps=10, policy=policy) with set_exploration_type(ExplorationType.RANDOM): @@ -221,7 +221,7 @@ exploration_policy = TensorDictSequential(policy, exploration_module) -with set_exploration_type(ExplorationType.MEAN): +with set_exploration_type(ExplorationType.DETERMINISTIC): # Turns off exploration rollout = env.rollout(max_steps=10, policy=exploration_policy) with set_exploration_type(ExplorationType.RANDOM): diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 94bd8427e30..1593d42a0ec 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -609,7 +609,7 @@ def __init__(self, td_params=None, seed=None, device="cpu"): env, # ``Unsqueeze`` the observations that we will concatenate UnsqueezeTransform( - unsqueeze_dim=-1, + dim=-1, in_keys=["th", "thdot"], in_keys_inv=["th", "thdot"], ),