[Feature] Logger #1858

Merged: 49 commits, Jan 31, 2024
Changes from 1 commit
Commits (49):
012f9c0  init (vmoens, Jan 19, 2024)
3485d6d  Merge remote-tracking branch 'origin/main' into remove-deprecs (vmoens, Jan 31, 2024)
b1344e6  Merge remote-tracking branch 'origin/main' into remove-deprecs (vmoens, Jan 31, 2024)
4839c09  amend (vmoens, Jan 31, 2024)
95ac71d  amend (vmoens, Jan 31, 2024)
067f4da  Merge remote-tracking branch 'origin/main' into remove-deprecs (vmoens, Jan 31, 2024)
e31e41b  amend (vmoens, Jan 31, 2024)
d7d2621  amend (vmoens, Jan 31, 2024)
f5b195d  amend (vmoens, Jan 31, 2024)
96d3a18  amend (vmoens, Jan 31, 2024)
4b41b5c  amend (vmoens, Jan 31, 2024)
755b7f4  amend (vmoens, Jan 31, 2024)
4b8c89b  amend (vmoens, Jan 31, 2024)
2fc17c2  amend (vmoens, Jan 31, 2024)
ba8dada  amend (vmoens, Jan 31, 2024)
3120f22  amend (vmoens, Jan 31, 2024)
ff27094  amend (vmoens, Jan 31, 2024)
6bdb2c4  amend (vmoens, Jan 31, 2024)
5dbc588  init (vmoens, Jan 31, 2024)
f30e02a  amend (vmoens, Jan 31, 2024)
b1c69b1  amend (vmoens, Jan 31, 2024)
ab07abe  amend (vmoens, Jan 31, 2024)
c7e8278  amend (vmoens, Jan 31, 2024)
f984105  amend (vmoens, Jan 31, 2024)
deb8b2e  amend (vmoens, Jan 31, 2024)
d0efa38  Merge remote-tracking branch 'origin/remove-deprecs' into logger (vmoens, Jan 31, 2024)
62b1dc8  amend (vmoens, Jan 31, 2024)
bd498ab  amend (vmoens, Jan 31, 2024)
b35c26a  Merge branch 'remove-deprecs' into logger (vmoens, Jan 31, 2024)
bf4a0d9  amend (vmoens, Jan 31, 2024)
1903d10  Merge branch 'remove-deprecs' into logger (vmoens, Jan 31, 2024)
e4bdde2  amend (vmoens, Jan 31, 2024)
ba63298  Merge remote-tracking branch 'origin/main' into logger (vmoens, Jan 31, 2024)
9c36712  amend (vmoens, Jan 31, 2024)
fdc4557  amend (vmoens, Jan 31, 2024)
c8d6441  amend (vmoens, Jan 31, 2024)
f32ce83  amend (vmoens, Jan 31, 2024)
f68dd4f  amend (vmoens, Jan 31, 2024)
2c2c9fb  amend (vmoens, Jan 31, 2024)
31b866a  amend (vmoens, Jan 31, 2024)
656e75b  amend (vmoens, Jan 31, 2024)
22cd51b  empty (vmoens, Jan 31, 2024)
23171f7  amend (vmoens, Jan 31, 2024)
707747e  amend (vmoens, Jan 31, 2024)
f44fe53  amend (vmoens, Jan 31, 2024)
2c737d6  amend (vmoens, Jan 31, 2024)
2797de8  amend (vmoens, Jan 31, 2024)
03c201c  amend (vmoens, Jan 31, 2024)
4b746f6  amend (vmoens, Jan 31, 2024)
Viewing commit 012f9c005d04508b47214a505a1aee1a4ca61eca ("init", vmoens, Jan 19, 2024)
11 changes: 3 additions & 8 deletions torchrl/collectors/collectors.py
@@ -151,16 +151,11 @@ def _policy_is_tensordict_compatible(policy: nn.Module):
         and hasattr(policy, "in_keys")
         and hasattr(policy, "out_keys")
     ):
-        warnings.warn(
-            "Passing a policy that is not a TensorDictModuleBase subclass but has in_keys and out_keys "
-            "will soon be deprecated. We'd like to motivate our users to inherit from this class (which "
+        raise RuntimeError(
+            "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys "
+            "is deprecated. Users should inherit from this class (which "
             "has very few restrictions) to make the experience smoother.",
-            category=DeprecationWarning,
         )
-        # if the policy is a TensorDictModule or takes a single argument and defines
-        # in_keys and out_keys then we assume it can already deal with TensorDict input
-        # to forward and we return True
-        return True
     elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"):
         # if it's not a TensorDictModule, and in_keys and out_keys are not defined then
         # we assume no TensorDict compatibility and will try to wrap it.
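
Migration sketch for this change: a plain nn.Module that merely defines in_keys/out_keys attributes now raises, so the policy should be wrapped explicitly (shapes below are illustrative, not from the diff):

    import torch.nn as nn
    from tensordict.nn import TensorDictModule

    # A bare module no longer counts as TensorDict-compatible just because
    # it carries in_keys/out_keys attributes.
    net = nn.Linear(4, 2)

    # Wrapping it yields a TensorDictModuleBase subclass, which is what the
    # collector now requires.
    policy = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])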
31 changes: 9 additions & 22 deletions torchrl/data/datasets/d4rl.py
@@ -146,7 +146,7 @@ def __init__(
         prefetch: int | None = None,
         transform: "torchrl.envs.Transform" | None = None,  # noqa-F821
         split_trajs: bool = False,
-        from_env: bool = None,
+        from_env: bool = False,
         use_truncated_as_done: bool = True,
         direct_download: bool = None,
         terminate_on_end: bool = None,
@@ -165,29 +165,16 @@ def __init__(
             direct_download = not self._has_d4rl

         if not direct_download:
-            if from_env is None:
-                warnings.warn(
-                    "from_env will soon default to ``False``, ie the data will be "
-                    "downloaded without relying on d4rl by default. "
-                    "For now, ``True`` will still be the default. "
-                    "To disable this warning, explicitly pass the ``from_env`` argument "
-                    "during construction of the dataset.",
-                    category=DeprecationWarning,
-                )
-                from_env = True
-            else:
-                warnings.warn(
-                    "You are using the D4RL library for collecting data. "
-                    "We advise against this use, as D4RL formatting can be "
-                    "inconsistent. "
-                    "To download the D4RL data without the D4RL library, use "
-                    "direct_download=True in the dataset constructor. "
-                    "Recurring to `direct_download=False` will soon be deprecated."
-                )
+            warnings.warn(
+                "You are using the D4RL library for collecting data. "
+                "We advise against this use, as D4RL formatting can be "
+                "inconsistent. "
+                "To download the D4RL data without the D4RL library, use "
+                "direct_download=True in the dataset constructor. "
+                "Recurring to `direct_download=False` will soon be deprecated."
+            )
             self.from_env = from_env
         else:
-            if from_env is None:
-                from_env = False
             self.from_env = from_env

         if (download == "force") or (download and not self._is_downloaded()):
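
For users hitting the warning above, a sketch of the advised path (the dataset id and batch size are illustrative):

    from torchrl.data.datasets.d4rl import D4RLExperienceReplay

    # Download the data directly rather than through the d4rl package.
    data = D4RLExperienceReplay(
        "hopper-medium-v2",
        batch_size=256,
        direct_download=True,
    )
    batch = data.sample()  # a TensorDict of transitions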
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -1043,7 +1043,7 @@ def _reset_batch_size(x):
     shape = x.get("_rb_batch_size", None)
     if shape is not None:
         warnings.warn(
-            "Reshaping nested tensordicts will be deprecated soon.",
+            "Reshaping nested tensordicts will be deprecated in v0.4.0.",
             category=DeprecationWarning,
         )
         data = x.get("_data")
6 changes: 3 additions & 3 deletions torchrl/data/tensor_specs.py
@@ -376,15 +376,15 @@ def high(self, value):
     @property
     def minimum(self):
         warnings.warn(
-            f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low",
+            f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low in v0.4.0",
             category=DeprecationWarning,
         )
         return self._low.to(self.device)

     @property
     def maximum(self):
         warnings.warn(
-            f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high",
+            f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high in v0.4.0",
             category=DeprecationWarning,
         )
         return self._high.to(self.device)
@@ -1472,7 +1472,7 @@ class BoundedTensorSpec(TensorSpec):
     # SPEC_HANDLED_FUNCTIONS = {}
     DEPRECATED_KWARGS = (
         "The `minimum` and `maximum` keyword arguments are now "
-        "deprecated in favour of `low` and `high`."
+        "deprecated in favour of `low` and `high` in v0.4.0."
     )
     CONFLICTING_KWARGS = (
         "The keyword arguments {} and {} conflict. Only one of these can be passed."
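
The corresponding call-site change, sketched (shape and dtype are illustrative):

    import torch
    from torchrl.data import BoundedTensorSpec

    # `low`/`high` replace the deprecated `minimum`/`maximum` kwargs and
    # accessors.
    spec = BoundedTensorSpec(low=-1.0, high=1.0, shape=(3,), dtype=torch.float32)
    print(spec.low, spec.high)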
2 changes: 1 addition & 1 deletion torchrl/envs/gym_like.py
@@ -521,7 +521,7 @@ def info_dict_reader(self, value: callable):
         warnings.warn(
             f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. "
             f"This method will append a reader to the list of existing readers (if any). "
-            f"Setting info_dict_reader directly will be soon deprecated.",
+            f"Setting info_dict_reader directly will be deprecated in v0.4.0.",
             category=DeprecationWarning,
         )
         self._info_dict_reader.append(value)
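
A sketch of the recommended pattern (the info key is hypothetical):

    from torchrl.envs.gym_like import default_info_dict_reader
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    reader = default_info_dict_reader(["my_info_key"])  # hypothetical key
    # Append a reader via the method instead of assigning to
    # env.info_dict_reader directly.
    env.set_info_dict_reader(reader)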
12 changes: 3 additions & 9 deletions torchrl/envs/transforms/transforms.py
@@ -2345,15 +2345,9 @@ def __init__(
         standard_normal: bool = False,
     ):
         if in_keys is None:
-            warnings.warn(
-                "Not passing in_keys to ObservationNorm will soon be deprecated. "
-                "Ensure you specify the entries to be normalized",
-                category=DeprecationWarning,
+            raise RuntimeError(
+                "Not passing in_keys to ObservationNorm is a deprecated behaviour."
             )
-            in_keys = [
-                "observation",
-                "pixels",
-            ]

         if out_keys is None:
             out_keys = copy(in_keys)
@@ -2692,7 +2686,7 @@ def __init__(
             raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}")
         if padding == "zeros":
             warnings.warn(
-                "Padding option 'zeros' will be deprecated in the future. "
+                "Padding option 'zeros' will be deprecated in v0.4.0. "
                 "Please use 'constant' padding with padding_value 0 instead.",
                 category=DeprecationWarning,
             )
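
Call sites would migrate roughly as follows (loc/scale values and the frame layout are illustrative):

    import torch
    from torchrl.envs.transforms import CatFrames, ObservationNorm

    # in_keys is now mandatory: name the entries to normalize.
    norm = ObservationNorm(
        loc=torch.zeros(3),
        scale=torch.ones(3),
        in_keys=["observation"],
    )

    # 'constant' with padding_value=0 reproduces the old 'zeros' padding.
    frames = CatFrames(N=4, dim=-1, padding="constant", padding_value=0)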
9 changes: 3 additions & 6 deletions torchrl/modules/models/models.py
@@ -872,9 +872,6 @@ class DistributionalDQNnet(TensorDictModuleBase):
     """Distributional Deep Q-Network.

     Args:
-        DQNet (nn.Module): (deprecated) Q-Network with output length equal
-            to the number of atoms:
-            output.shape = [*batch, atoms, actions].
         in_keys (list of str or tuples of str): input keys to the log-softmax
             operation. Defaults to ``["action_value"]``.
         out_keys (list of str or tuples of str): output keys to the log-softmax
@@ -888,11 +885,11 @@ class DistributionalDQNnet(TensorDictModuleBase):
         "instead."
     )

-    def __init__(self, DQNet: nn.Module = None, in_keys=None, out_keys=None):
+    def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None):
         super().__init__()
         if DQNet is not None:
             warnings.warn(
-                f"Passing a network to {type(self)} is going to be deprecated.",
+                f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.",
                 category=DeprecationWarning,
             )
         if not (
@@ -1280,7 +1277,7 @@ def __init__(
         device: Optional[DEVICE_TYPING] = None,
     ) -> None:
         warnings.warn(
-            "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed soon.",
+            "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.",
             category=DeprecationWarning,
         )
         super().__init__()
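
Sketches of both replacement patterns (sizes and keys are illustrative):

    from torchrl.modules import DistributionalDQNnet, LSTMModule

    # in_keys/out_keys are now keyword-only; passing DQNet still works but
    # warns ahead of v0.4.0.
    head = DistributionalDQNnet(in_keys=["action_value"], out_keys=["action_value"])

    # LSTMModule is the advertised replacement for LSTMNet.
    lstm = LSTMModule(input_size=32, hidden_size=64, in_key="embed", out_key="embed")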
10 changes: 5 additions & 5 deletions torchrl/modules/tensordict_module/actors.py
@@ -445,7 +445,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -825,7 +825,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -922,7 +922,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -1043,7 +1043,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
@@ -1189,7 +1189,7 @@ def __init__(
     ):
         if isinstance(action_space, TensorSpec):
             warnings.warn(
-                "Using specs in action_space will be deprecated soon,"
+                "Using specs in action_space will be deprecated in v0.4.0,"
                 " please use the 'spec' argument if you want to provide an action spec",
                 category=DeprecationWarning,
             )
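
The recommended call shape, sketched (the network and action count are illustrative):

    from torch import nn
    from torchrl.data import OneHotDiscreteTensorSpec
    from torchrl.modules import QValueActor

    net = nn.Linear(4, 3)  # 3 discrete actions
    # Pass the spec through `spec=`, not through `action_space=`.
    actor = QValueActor(net, spec=OneHotDiscreteTensorSpec(3))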
99 changes: 2 additions & 97 deletions torchrl/modules/tensordict_module/exploration.py
@@ -247,105 +247,10 @@ def __init__(
         action_mask_key: Optional[NestedKey] = None,
         spec: Optional[TensorSpec] = None,
     ):
-        warnings.warn(
-            "EGreedyWrapper is deprecated and it will be removed in v0.3. "
-            "Please use torchrl.modules.EGreedyModule instead.",
-            category=DeprecationWarning,
+        raise RuntimeError(
+            "This class is now removed in favour of torchrl.modules.EGreedyModule."
         )
-
-        super().__init__(policy)
-        self.register_buffer("eps_init", torch.tensor([eps_init]))
-        self.register_buffer("eps_end", torch.tensor([eps_end]))
-        if self.eps_end > self.eps_init:
-            raise RuntimeError("eps should decrease over time or be constant")
-        self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
-        self.action_key = action_key
-        self.action_mask_key = action_mask_key
-        if spec is not None:
-            if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
-                spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
-            self._spec = spec
-        elif hasattr(self.td_module, "_spec"):
-            self._spec = self.td_module._spec.clone()
-            if action_key not in self._spec.keys():
-                self._spec[action_key] = None
-        elif hasattr(self.td_module, "spec"):
-            self._spec = self.td_module.spec.clone()
-            if action_key not in self._spec.keys():
-                self._spec[action_key] = None
-        else:
-            self._spec = spec
-
-    @property
-    def spec(self):
-        return self._spec
-
-    def step(self, frames: int = 1) -> None:
-        """A step of epsilon decay.
-
-        After self.annealing_num_steps, this function is a no-op.
-
-        Args:
-            frames (int): number of frames since last step.
-
-        """
-        for _ in range(frames):
-            self.eps.data[0] = max(
-                self.eps_end.item(),
-                (
-                    self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps
-                ).item(),
-            )
-
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        tensordict = self.td_module.forward(tensordict)
-        if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
-            if isinstance(self.action_key, tuple) and len(self.action_key) > 1:
-                action_tensordict = tensordict.get(self.action_key[:-1])
-                action_key = self.action_key[-1]
-            else:
-                action_tensordict = tensordict
-                action_key = self.action_key
-
-            out = action_tensordict.get(action_key)
-            eps = self.eps.item()
-            cond = (
-                torch.rand(action_tensordict.shape, device=action_tensordict.device)
-                < eps
-            ).to(out.dtype)
-            cond = expand_as_right(cond, out)
-            spec = self.spec
-            if spec is not None:
-                if isinstance(spec, CompositeSpec):
-                    spec = spec[self.action_key]
-                if spec.shape != out.shape:
-                    # In batched envs if the spec is passed unbatched, the rand() will not
-                    # cover all batched dims
-                    if (
-                        not len(spec.shape)
-                        or out.shape[-len(spec.shape) :] == spec.shape
-                    ):
-                        spec = spec.expand(out.shape)
-                    else:
-                        raise ValueError(
-                            "Action spec shape does not match the action shape"
-                        )
-                if self.action_mask_key is not None:
-                    action_mask = tensordict.get(self.action_mask_key, None)
-                    if action_mask is None:
-                        raise KeyError(
-                            f"Action mask key {self.action_mask_key} not found in {tensordict}."
-                        )
-                    spec.update_mask(action_mask)
-                out = cond * spec.rand().to(out.device) + (1 - cond) * out
-            else:
-                raise RuntimeError(
-                    "spec must be provided by the policy or directly to the exploration wrapper."
-                )
-            action_tensordict.set(action_key, out)
-        return tensordict


 class AdditiveGaussianWrapper(TensorDictModuleWrapper):
     """Additive Gaussian PO wrapper.
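
A migration sketch from the removed wrapper to EGreedyModule (shapes, keys and the epsilon schedule are illustrative):

    import torch
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torchrl.data import OneHotDiscreteTensorSpec
    from torchrl.modules import EGreedyModule

    spec = OneHotDiscreteTensorSpec(4)
    policy = TensorDictModule(
        torch.nn.Linear(8, 4), in_keys=["observation"], out_keys=["action"]
    )
    # EGreedyModule is composed after the policy instead of wrapping it.
    exploration = EGreedyModule(
        spec, eps_init=1.0, eps_end=0.05, annealing_num_steps=1000
    )
    explorative_policy = TensorDictSequential(policy, exploration)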
9 changes: 2 additions & 7 deletions torchrl/modules/tensordict_module/rnn.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import warnings
 from typing import Optional, Tuple

 import torch
@@ -555,11 +554,9 @@ def recurrent_mode(self, value):

     @property
     def temporal_mode(self):
-        warnings.warn(
+        raise RuntimeError(
             "temporal_mode is deprecated, use recurrent_mode instead.",
-            category=DeprecationWarning,
         )
-        return self.recurrent_mode

     def set_recurrent_mode(self, mode: bool = True):
         """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs).
@@ -1255,11 +1252,9 @@ def recurrent_mode(self, value):

     @property
     def temporal_mode(self):
-        warnings.warn(
+        raise RuntimeError(
             "temporal_mode is deprecated, use recurrent_mode instead.",
-            category=DeprecationWarning,
         )
-        return self.recurrent_mode

     def set_recurrent_mode(self, mode: bool = True):
         """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs).
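
temporal_mode now raises; the surviving API, sketched (sizes and keys are illustrative):

    from torchrl.modules import LSTMModule

    lstm = LSTMModule(input_size=32, hidden_size=64, in_key="embed", out_key="embed")
    # set_recurrent_mode returns a copy sharing the same weights but with the
    # requested recurrent_mode.
    training_lstm = lstm.set_recurrent_mode(True)
    assert training_lstm.recurrent_mode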
6 changes: 2 additions & 4 deletions torchrl/objectives/a2c.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import warnings
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Tuple
@@ -21,7 +20,7 @@
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
     _cache_values,
-    _GAMMA_LMBDA_DEPREC_WARNING,
+    _GAMMA_LMBDA_DEPREC_ERROR,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
@@ -261,8 +260,7 @@ def __init__(
         self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
         self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device))
         if gamma is not None:
-            warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
-            self.gamma = gamma
+            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self.loss_critic_type = loss_critic_type

     @property
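
With the constructor path gone, discounting moves to the value estimator. A minimal sketch with toy shapes (the actor/critic construction is illustrative):

    import torch
    from tensordict.nn import NormalParamExtractor, TensorDictModule
    from torch.distributions import Normal
    from torchrl.modules import ProbabilisticActor, ValueOperator
    from torchrl.objectives import A2CLoss
    from torchrl.objectives.utils import ValueEstimators

    n_obs, n_act = 4, 2
    actor_net = TensorDictModule(
        torch.nn.Sequential(torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )
    actor = ProbabilisticActor(
        actor_net, in_keys=["loc", "scale"], distribution_class=Normal
    )
    critic = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

    loss = A2CLoss(actor, critic)
    # gamma (and lmbda) now go through the value estimator, not the loss ctor.
    loss.make_value_estimator(ValueEstimators.GAE, gamma=0.99, lmbda=0.95)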