
[Deprecation] Deprecate in prep for release #1820

Merged
merged 23 commits into from Jan 31, 2024

init
vmoens committed Jan 19, 2024
commit 012f9c005d04508b47214a505a1aee1a4ca61eca
11 changes: 3 additions & 8 deletions torchrl/collectors/collectors.py
@@ -151,16 +151,11 @@ def _policy_is_tensordict_compatible(policy: nn.Module):
         and hasattr(policy, "in_keys")
         and hasattr(policy, "out_keys")
     ):
-        warnings.warn(
-            "Passing a policy that is not a TensorDictModuleBase subclass but has in_keys and out_keys "
-            "will soon be deprecated. We'd like to motivate our users to inherit from this class (which "
+        raise RuntimeError(
+            "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys "
+            "is deprecated. Users should inherit from this class (which "
             "has very few restrictions) to make the experience smoother.",
-            category=DeprecationWarning,
         )
-        # if the policy is a TensorDictModule or takes a single argument and defines
-        # in_keys and out_keys then we assume it can already deal with TensorDict input
-        # to forward and we return True
-        return True
     elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"):
         # if it's not a TensorDictModule, and in_keys and out_keys are not defined then
         # we assume no TensorDict compatibility and will try to wrap it.
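
Migration note: after this change, a policy that merely exposes in_keys/out_keys without subclassing tensordict.nn.TensorDictModuleBase is rejected with a RuntimeError. A minimal sketch of the supported pattern, assuming a hypothetical observation of size 4 and action of size 2:

import torch.nn as nn
from tensordict.nn import TensorDictModule

# TensorDictModule subclasses TensorDictModuleBase, so the collector's
# compatibility check accepts it without any wrapping heuristics.
policy = TensorDictModule(
    nn.Linear(4, 2),  # hypothetical policy network
    in_keys=["observation"],
    out_keys=["action"],
)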
31 changes: 9 additions & 22 deletions torchrl/data/datasets/d4rl.py
@@ -146,7 +146,7 @@ def __init__(
         prefetch: int | None = None,
         transform: "torchrl.envs.Transform" | None = None,  # noqa-F821
         split_trajs: bool = False,
-        from_env: bool = None,
+        from_env: bool = False,
         use_truncated_as_done: bool = True,
         direct_download: bool = None,
         terminate_on_end: bool = None,
@@ -165,29 +165,16 @@ def __init__(
             direct_download = not self._has_d4rl
 
         if not direct_download:
-            if from_env is None:
-                warnings.warn(
-                    "from_env will soon default to ``False``, ie the data will be "
-                    "downloaded without relying on d4rl by default. "
-                    "For now, ``True`` will still be the default. "
-                    "To disable this warning, explicitly pass the ``from_env`` argument "
-                    "during construction of the dataset.",
-                    category=DeprecationWarning,
-                )
-                from_env = True
-            else:
-                warnings.warn(
-                    "You are using the D4RL library for collecting data. "
-                    "We advise against this use, as D4RL formatting can be "
-                    "inconsistent. "
-                    "To download the D4RL data without the D4RL library, use "
-                    "direct_download=True in the dataset constructor. "
-                    "Recurring to `direct_download=False` will soon be deprecated."
-                )
+            warnings.warn(
+                "You are using the D4RL library for collecting data. "
+                "We advise against this use, as D4RL formatting can be "
+                "inconsistent. "
+                "To download the D4RL data without the D4RL library, use "
+                "direct_download=True in the dataset constructor. "
+                "Recurring to `direct_download=False` will soon be deprecated."
+            )
             self.from_env = from_env
         else:
-            if from_env is None:
-                from_env = False
             self.from_env = from_env
 
         if (download == "force") or (download and not self._is_downloaded()):
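
Migration note: from_env now defaults to False and downloading through the d4rl package is discouraged. A sketch of the recommended construction; the dataset id and batch size are illustrative, and the positional dataset-id argument is an assumption about the constructor:

from torchrl.data.datasets.d4rl import D4RLExperienceReplay

# Fetches the pre-converted dataset directly, without importing d4rl.
dataset = D4RLExperienceReplay(
    "hopper-medium-v2",  # illustrative dataset id
    batch_size=256,
    direct_download=True,
)
sample = dataset.sample()  # a TensorDict batch of transitions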
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -1043,7 +1043,7 @@ def _reset_batch_size(x):
         shape = x.get("_rb_batch_size", None)
         if shape is not None:
             warnings.warn(
-                "Reshaping nested tensordicts will be deprecated soon.",
+                "Reshaping nested tensordicts will be deprecated in v0.4.0.",
                 category=DeprecationWarning,
             )
             data = x.get("_data")
6 changes: 3 additions & 3 deletions torchrl/data/tensor_specs.py
@@ -376,15 +376,15 @@ def high(self, value):
     @property
     def minimum(self):
         warnings.warn(
-            f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low",
+            f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low in v0.4.0",
             category=DeprecationWarning,
         )
         return self._low.to(self.device)
 
     @property
     def maximum(self):
         warnings.warn(
-            f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high",
+            f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high in v0.4.0",
             category=DeprecationWarning,
         )
         return self._high.to(self.device)
@@ -1472,7 +1472,7 @@ class BoundedTensorSpec(TensorSpec):
     # SPEC_HANDLED_FUNCTIONS = {}
     DEPRECATED_KWARGS = (
         "The `minimum` and `maximum` keyword arguments are now "
-        "deprecated in favour of `low` and `high`."
+        "deprecated in favour of `low` and `high` in v0.4.0."
     )
     CONFLICTING_KWARGS = (
         "The keyword arguments {} and {} conflict. Only one of these can be passed."
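
Migration note: bounds should be read and written through low/high; minimum/maximum only survive as warning-emitting aliases. A minimal sketch:

import torch
from torchrl.data import BoundedTensorSpec

spec = BoundedTensorSpec(low=-1.0, high=1.0, shape=(3,), dtype=torch.float32)
assert (spec.low == -1).all() and (spec.high == 1).all()
# spec.minimum / spec.maximum still work but emit a DeprecationWarning.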
2 changes: 1 addition & 1 deletion torchrl/envs/gym_like.py
@@ -521,7 +521,7 @@ def info_dict_reader(self, value: callable):
         warnings.warn(
             f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. "
             f"This method will append a reader to the list of existing readers (if any). "
-            f"Setting info_dict_reader directly will be soon deprecated.",
+            f"Setting info_dict_reader directly will be deprecated in v0.4.0.",
             category=DeprecationWarning,
         )
         self._info_dict_reader.append(value)
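
Migration note: readers should be registered through set_info_dict_reader rather than assigned to info_dict_reader. A sketch, assuming a Gym backend is installed; the env id and info key are illustrative:

from torchrl.envs import default_info_dict_reader
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")  # illustrative env id
reader = default_info_dict_reader(["some_info_key"])  # hypothetical info key
env.set_info_dict_reader(reader)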
12 changes: 3 additions & 9 deletions torchrl/envs/transforms/transforms.py
@@ -2345,15 +2345,9 @@ def __init__(
         standard_normal: bool = False,
     ):
         if in_keys is None:
-            warnings.warn(
-                "Not passing in_keys to ObservationNorm will soon be deprecated. "
-                "Ensure you specify the entries to be normalized",
-                category=DeprecationWarning,
+            raise RuntimeError(
+                "Not passing in_keys to ObservationNorm is a deprecated behaviour."
             )
-            in_keys = [
-                "observation",
-                "pixels",
-            ]
 
         if out_keys is None:
             out_keys = copy(in_keys)
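
Migration note: ObservationNorm no longer falls back to ["observation", "pixels"]; in_keys must be spelled out. A minimal sketch:

from torchrl.envs.transforms import ObservationNorm

t = ObservationNorm(loc=0.0, scale=1.0, in_keys=["observation"])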
@@ -2692,7 +2686,7 @@ def __init__(
             raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}")
         if padding == "zeros":
             warnings.warn(
-                "Padding option 'zeros' will be deprecated in the future. "
+                "Padding option 'zeros' will be deprecated in v0.4.0. "
                 "Please use 'constant' padding with padding_value 0 instead.",
                 category=DeprecationWarning,
             )
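
Migration note: the hunk above does not show the owning class; assuming this is the padding option of a frame-stacking transform such as CatFrames, the non-deprecated spelling would be:

from torchrl.envs.transforms import CatFrames

# 'constant' with an explicit padding_value replaces the 'zeros' option.
t = CatFrames(N=4, dim=-3, padding="constant", padding_value=0)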
9 changes: 3 additions & 6 deletions torchrl/modules/models/models.py
@@ -872,9 +872,6 @@ class DistributionalDQNnet(TensorDictModuleBase):
"""Distributional Deep Q-Network.

Args:
DQNet (nn.Module): (deprecated) Q-Network with output length equal
to the number of atoms:
output.shape = [*batch, atoms, actions].
in_keys (list of str or tuples of str): input keys to the log-softmax
operation. Defaults to ``["action_value"]``.
out_keys (list of str or tuples of str): output keys to the log-softmax
@@ -888,11 +885,11 @@ class DistributionalDQNnet(TensorDictModuleBase):
"instead."
)

def __init__(self, DQNet: nn.Module = None, in_keys=None, out_keys=None):
def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None):
super().__init__()
if DQNet is not None:
warnings.warn(
f"Passing a network to {type(self)} is going to be deprecated.",
f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.",
category=DeprecationWarning,
)
if not (
@@ -1280,7 +1277,7 @@ def __init__(
         device: Optional[DEVICE_TYPING] = None,
     ) -> None:
         warnings.warn(
-            "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed soon.",
+            "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.",
             category=DeprecationWarning,
         )
         super().__init__()
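
Migration note: LSTMNet users are pointed to torchrl.modules.LSTMModule. A minimal sketch following the LSTMModule docstring; sizes and key names are illustrative:

from torchrl.modules import LSTMModule

lstm = LSTMModule(
    input_size=32,
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["embed", ("next", "rs_h"), ("next", "rs_c")],
)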
10 changes: 5 additions & 5 deletions torchrl/modules/tensordict_module/actors.py
@@ -445,7 +445,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -825,7 +825,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -922,7 +922,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -1043,7 +1043,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -1189,7 +1189,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
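
Migration note: action specs should go through the spec argument, keeping action_space for string identifiers. A sketch with an illustrative 3-action value network:

import torch.nn as nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor

actor = QValueActor(
    nn.Linear(4, 3),  # hypothetical action-value network
    in_keys=["observation"],
    spec=OneHotDiscreteTensorSpec(3),
)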
99 changes: 2 additions & 97 deletions torchrl/modules/tensordict_module/exploration.py
@@ -247,105 +247,10 @@ def __init__(
         action_mask_key: Optional[NestedKey] = None,
         spec: Optional[TensorSpec] = None,
     ):
-        warnings.warn(
-            "EGreedyWrapper is deprecated and it will be removed in v0.3. "
-            "Please use torchrl.modules.EGreedyModule instead.",
-            category=DeprecationWarning,
+        raise RuntimeError(
+            "This class is now removed in favour of torchrl.modules.EGreedyModule."
         )
-
-        super().__init__(policy)
-        self.register_buffer("eps_init", torch.tensor([eps_init]))
-        self.register_buffer("eps_end", torch.tensor([eps_end]))
-        if self.eps_end > self.eps_init:
-            raise RuntimeError("eps should decrease over time or be constant")
-        self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
-        self.action_key = action_key
-        self.action_mask_key = action_mask_key
-        if spec is not None:
-            if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
-                spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
-            self._spec = spec
-        elif hasattr(self.td_module, "_spec"):
-            self._spec = self.td_module._spec.clone()
-            if action_key not in self._spec.keys():
-                self._spec[action_key] = None
-        elif hasattr(self.td_module, "spec"):
-            self._spec = self.td_module.spec.clone()
-            if action_key not in self._spec.keys():
-                self._spec[action_key] = None
-        else:
-            self._spec = spec
-
-    @property
-    def spec(self):
-        return self._spec
-
-    def step(self, frames: int = 1) -> None:
-        """A step of epsilon decay.
-
-        After self.annealing_num_steps, this function is a no-op.
-
-        Args:
-            frames (int): number of frames since last step.
-
-        """
-        for _ in range(frames):
-            self.eps.data[0] = max(
-                self.eps_end.item(),
-                (
-                    self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps
-                ).item(),
-            )
-
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        tensordict = self.td_module.forward(tensordict)
-        if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
-            if isinstance(self.action_key, tuple) and len(self.action_key) > 1:
-                action_tensordict = tensordict.get(self.action_key[:-1])
-                action_key = self.action_key[-1]
-            else:
-                action_tensordict = tensordict
-                action_key = self.action_key
-
-            out = action_tensordict.get(action_key)
-            eps = self.eps.item()
-            cond = (
-                torch.rand(action_tensordict.shape, device=action_tensordict.device)
-                < eps
-            ).to(out.dtype)
-            cond = expand_as_right(cond, out)
-            spec = self.spec
-            if spec is not None:
-                if isinstance(spec, CompositeSpec):
-                    spec = spec[self.action_key]
-                if spec.shape != out.shape:
-                    # In batched envs if the spec is passed unbatched, the rand() will not
-                    # cover all batched dims
-                    if (
-                        not len(spec.shape)
-                        or out.shape[-len(spec.shape) :] == spec.shape
-                    ):
-                        spec = spec.expand(out.shape)
-                    else:
-                        raise ValueError(
-                            "Action spec shape does not match the action shape"
-                        )
-                if self.action_mask_key is not None:
-                    action_mask = tensordict.get(self.action_mask_key, None)
-                    if action_mask is None:
-                        raise KeyError(
-                            f"Action mask key {self.action_mask_key} not found in {tensordict}."
-                        )
-                    spec.update_mask(action_mask)
-                out = cond * spec.rand().to(out.device) + (1 - cond) * out
-            else:
-                raise RuntimeError(
-                    "spec must be provided by the policy or directly to the exploration wrapper."
-                )
-            action_tensordict.set(action_key, out)
-        return tensordict
 
 
 class AdditiveGaussianWrapper(TensorDictModuleWrapper):
     """Additive Gaussian PO wrapper.
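
Migration note: EGreedyWrapper is gone; exploration is now composed after the policy with EGreedyModule. A minimal sketch with illustrative sizes:

import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import EGreedyModule

policy = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
explorative_policy = TensorDictSequential(
    policy,
    EGreedyModule(
        spec=OneHotDiscreteTensorSpec(2),
        eps_init=1.0,
        eps_end=0.05,
        annealing_num_steps=1000,
    ),
)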
9 changes: 2 additions & 7 deletions torchrl/modules/tensordict_module/rnn.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import warnings
 from typing import Optional, Tuple
 
 import torch
@@ -555,11 +554,9 @@ def recurrent_mode(self, value):

     @property
     def temporal_mode(self):
-        warnings.warn(
+        raise RuntimeError(
             "temporal_mode is deprecated, use recurrent_mode instead.",
-            category=DeprecationWarning,
         )
-        return self.recurrent_mode
 
     def set_recurrent_mode(self, mode: bool = True):
         """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs).
@@ -1255,11 +1252,9 @@ def recurrent_mode(self, value):

     @property
     def temporal_mode(self):
-        warnings.warn(
+        raise RuntimeError(
             "temporal_mode is deprecated, use recurrent_mode instead.",
-            category=DeprecationWarning,
         )
-        return self.recurrent_mode
 
     def set_recurrent_mode(self, mode: bool = True):
         """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs).
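
Migration note: temporal_mode now raises; read recurrent_mode or use set_recurrent_mode instead. A sketch, reusing the illustrative LSTMModule keys from above:

from torchrl.modules import LSTMModule

lstm = LSTMModule(
    input_size=32,
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["embed", ("next", "rs_h"), ("next", "rs_c")],
)
assert not lstm.recurrent_mode              # lstm.temporal_mode now raises RuntimeError
train_lstm = lstm.set_recurrent_mode(True)  # copy sharing the same weights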
6 changes: 2 additions & 4 deletions torchrl/objectives/a2c.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import warnings
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Tuple
@@ -21,7 +20,7 @@
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
     _cache_values,
-    _GAMMA_LMBDA_DEPREC_WARNING,
+    _GAMMA_LMBDA_DEPREC_ERROR,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
@@ -261,8 +260,7 @@ def __init__(
         self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
         self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device))
         if gamma is not None:
-            warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
-            self.gamma = gamma
+            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self.loss_critic_type = loss_critic_type
 
     @property
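
Migration note: passing gamma to the A2CLoss constructor now raises a TypeError; the discount is configured on the value estimator instead. A sketch with a minimal, illustrative actor/critic pair:

import torch.nn as nn
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import A2CLoss, ValueEstimators

actor_net = nn.Sequential(nn.Linear(4, 4), NormalParamExtractor())
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
critic = ValueOperator(nn.Linear(4, 1), in_keys=["observation"])

loss = A2CLoss(actor, critic)  # gamma=... here now raises TypeError
loss.make_value_estimator(ValueEstimators.GAE, gamma=0.99, lmbda=0.95)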