[BugFix] Fix tutorials (#1382)
vmoens authored Jul 11, 2023
1 parent fd232c5 commit 5a3f9e0
Showing 7 changed files with 141 additions and 18 deletions.
66 changes: 66 additions & 0 deletions test/test_cost.py
@@ -6918,6 +6918,72 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke
assert loss_value == loss_val_td["loss_value"]


@pytest.mark.parametrize("create_target_params", [True, False])
def test_param_buffer_types(create_target_params):
class MyLoss(LossModule):
def __init__(self, actor_network):
super().__init__()
self.convert_to_functional(
actor_network,
"actor_network",
create_target_params=create_target_params,
)

def _forward_value_estimator_keys(self, **kwargs) -> None:
pass

actor_module = TensorDictModule(
nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4)),
in_keys=["obs"],
out_keys=["action"],
)
loss = MyLoss(actor_module)
assert isinstance(loss.actor_network_params["module", "0", "weight"], nn.Parameter)
assert isinstance(
loss.target_actor_network_params["module", "0", "weight"], nn.Parameter
)
assert loss.actor_network_params["module", "0", "weight"].requires_grad
assert not loss.target_actor_network_params["module", "0", "weight"].requires_grad
assert isinstance(loss.actor_network_params["module", "0", "bias"], nn.Parameter)
assert isinstance(
loss.target_actor_network_params["module", "0", "bias"], nn.Parameter
)

if create_target_params:
assert (
loss.actor_network_params["module", "0", "weight"].data.data_ptr()
!= loss.target_actor_network_params["module", "0", "weight"].data.data_ptr()
)
assert (
loss.actor_network_params["module", "0", "bias"].data.data_ptr()
!= loss.target_actor_network_params["module", "0", "bias"].data.data_ptr()
)
else:
assert (
loss.actor_network_params["module", "0", "weight"].data.data_ptr()
== loss.target_actor_network_params["module", "0", "weight"].data.data_ptr()
)
assert (
loss.actor_network_params["module", "0", "bias"].data.data_ptr()
== loss.target_actor_network_params["module", "0", "bias"].data.data_ptr()
)

assert loss.actor_network_params["module", "0", "bias"].requires_grad
assert not loss.target_actor_network_params["module", "0", "bias"].requires_grad
assert not isinstance(
loss.actor_network_params["module", "1", "running_mean"], nn.Parameter
)
assert not isinstance(
loss.target_actor_network_params["module", "1", "running_mean"], nn.Parameter
)
assert not isinstance(
loss.actor_network_params["module", "1", "running_var"], nn.Parameter
)
assert not isinstance(
loss.target_actor_network_params["module", "1", "running_var"], nn.Parameter
)


def test_hold_out():
net = torch.nn.Linear(3, 4)
x = torch.randn(1, 3)
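Note on the test above: it distinguishes shared from independent target parameters by comparing storage pointers. A minimal self-contained sketch of that idiom (plain PyTorch; the names are illustrative):

import torch
from torch import nn

# Two tensors alias the same storage exactly when their data_ptr() values match.
w = nn.Parameter(torch.randn(4, 3))
shared = nn.Parameter(w.data, requires_grad=False)               # same storage
independent = nn.Parameter(w.data.clone(), requires_grad=False)  # fresh copy

assert w.data.data_ptr() == shared.data.data_ptr()
assert w.data.data_ptr() != independent.data.data_ptr()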
12 changes: 11 additions & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -670,11 +670,21 @@ def _collate_contiguous(x):
return x


+def _collate_as_tensor(x):
+    return x.contiguous()


def _get_default_collate(storage, _is_tensordict=False):
if isinstance(storage, ListStorage):
if _is_tensordict:
return _collate_list_tensordict
else:
return torch.utils.data._utils.collate.default_collate
-    elif isinstance(storage, (LazyTensorStorage, LazyMemmapStorage)):
+    elif isinstance(storage, LazyMemmapStorage):
+        return _collate_as_tensor
+    elif isinstance(storage, (LazyTensorStorage,)):
return _collate_contiguous
+    else:
+        raise NotImplementedError(
+            f"Could not find a default collate_fn for storage {type(storage)}."
+        )
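Context for this change: at the time of this commit, LazyMemmapStorage leaves were MemmapTensor instances, and calling .contiguous() on a sampled tensordict materializes them as plain in-memory torch.Tensors, while LazyTensorStorage keeps the cheaper contiguous-collate path. A rough sketch of the behavior, assuming tensordict's memmap_()/contiguous() API:

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])
td.memmap_()                  # back each leaf with a memory-mapped file
sample = td[:2].contiguous()  # materialize the slice as regular tensors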
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
@@ -58,6 +58,7 @@
SafeSequential,
TanhModule,
ValueOperator,
+    VmapModule,
WorldModelWrapper,
)
from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip
40 changes: 32 additions & 8 deletions torchrl/objectives/common.py
@@ -146,13 +146,13 @@ def set_keys(self, **kwargs) -> None:

try:
self._forward_value_estimator_keys(**kwargs)
-        except AttributeError:
+        except AttributeError as err:
raise AttributeError(
"To utilize `.set_keys(...)` for tensordict key configuration, the subclassed loss module "
"must define an _AcceptedKeys dataclass containing all keys intended for configuration. "
"Moreover, the subclass needs to implement `._forward_value_estimator_keys()` method to "
"facilitate forwarding of any modified tensordict keys to the underlying value_estimator."
-            )
+            ) from err

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*".
@@ -373,14 +373,14 @@ def _compare_and_expand(param):

name_params_target = "_target_" + module_name
if create_target_params:
-            target_params = params_and_buffers.detach().clone()
+            target_params = params_and_buffers.apply(_make_target_param(clone=True))
target_params_items = target_params.items(True, True)
target_params_list = []
for (key, val) in target_params_items:
if not isinstance(key, tuple):
key = (key,)
name = sep.join([name_params_target, *key])
-                self.register_buffer(name, Buffer(val))
+                self.register_buffer(name, val)
target_params_list.append((name, key))
setattr(self, name_params_target + "_params", target_params)
else:
@@ -409,7 +409,18 @@ def _param_getter(self, network_name):
if isinstance(value_to_set, str):
if value_to_set.endswith("_detached"):
value_to_set = value_to_set[:-9]
-                    value_to_set = getattr(self, value_to_set).detach()
+                    value_to_set = getattr(self, value_to_set)
+                    is_param = isinstance(value_to_set, nn.Parameter)
+                    is_buffer = isinstance(value_to_set, Buffer)
+                    value_to_set = value_to_set.detach()
+                    if is_param:
+                        value_to_set = nn.Parameter(
+                            value_to_set, requires_grad=False
+                        )
+                    elif is_buffer:
+                        value_to_set = Buffer(
+                            value_to_set, requires_grad=False
+                        )
else:
value_to_set = getattr(self, value_to_set)
# params.set(key, value_to_set)
@@ -419,7 +430,7 @@ def _param_getter(self, network_name):
return params
else:
params = getattr(self, param_name)
-            return params.detach()
+            return params.apply(_make_target_param(clone=False))

else:
raise RuntimeError(
@@ -454,11 +465,12 @@ def _target_param_getter(self, network_name):
target_params._set_tuple(
key, value_to_set, inplace=False, validated=True
)
return target_params
else:
params = getattr(self, param_name)
-            # should we clone here?
-            return params.detach()  # .clone()
+            target_params = params.apply(_make_target_param(clone=False))
+
+            return target_params

else:
raise RuntimeError(
@@ -611,3 +623,15 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
)
else:
raise NotImplementedError(f"Unknown value type {value_type}")


+class _make_target_param:
+    def __init__(self, clone):
+        self.clone = clone
+
+    def __call__(self, x):
+        if isinstance(x, nn.Parameter):
+            return nn.Parameter(
+                x.data.clone() if self.clone else x.data, requires_grad=False
+            )
+        return Buffer(x.data.clone() if self.clone else x.data)
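The _make_target_param helper above centralizes one pattern: target entries keep their Parameter/Buffer identity but never track gradients, and clone controls whether they own their storage. A simplified stand-alone sketch (plain PyTorch; tensordict's Buffer is replaced by a plain tensor for illustration):

import torch
from torch import nn

def make_target(x: torch.Tensor, clone: bool) -> torch.Tensor:
    # clone=True -> independent copy; clone=False -> shares storage with x
    data = x.data.clone() if clone else x.data
    if isinstance(x, nn.Parameter):
        return nn.Parameter(data, requires_grad=False)
    return data

w = nn.Parameter(torch.randn(4, 3))
tgt = make_target(w, clone=True)
assert isinstance(tgt, nn.Parameter) and not tgt.requires_grad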
15 changes: 7 additions & 8 deletions torchrl/objectives/ppo.py
@@ -147,7 +147,7 @@ class PPOLoss(LossModule):
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... "sample_log_prob": torch.randn_like(action[..., 1]) / 10,
... "sample_log_prob": torch.randn_like(action[..., 1]),
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
@@ -266,12 +266,6 @@ def __init__(
self._in_keys = None
self._out_keys = None
super().__init__()
-        self._set_deprecated_ctor_keys(
-            advantage=advantage_key,
-            value_target=value_target_key,
-            value=value_key,
-        )
-
self.convert_to_functional(
actor, "actor", funs_to_decorate=["forward", "get_dist"]
)
@@ -296,6 +290,11 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
+        self._set_deprecated_ctor_keys(
+            advantage=advantage_key,
+            value_target=value_target_key,
+            value=value_key,
+        )

def _set_in_keys(self):
keys = [
@@ -335,7 +334,7 @@ def out_keys(self, values):
self._out_keys = values

def _forward_value_estimator_keys(self, **kwargs) -> None:
-        if self._value_estimator is not None:
+        if hasattr(self, "_value_estimator") and self._value_estimator is not None:
self._value_estimator.set_keys(
advantage=self.tensor_keys.advantage,
value_target=self.tensor_keys.value_target,
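Both ppo.py edits fix construction-time ordering: _set_deprecated_ctor_keys now runs after convert_to_functional, and _forward_value_estimator_keys tolerates being called before _value_estimator exists. A minimal sketch of the hasattr guard (illustrative names):

class Example:
    def __init__(self):
        self.set_keys(advantage="advantage")  # runs before _value_estimator exists
        self._value_estimator = None

    def set_keys(self, **kwargs):
        # guard: no-op until the estimator attribute has been created
        if hasattr(self, "_value_estimator") and self._value_estimator is not None:
            pass  # forward the keys to the estimator here

Example()  # constructing no longer raises AttributeError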
24 changes: 24 additions & 0 deletions torchrl/objectives/value/advantages.py
@@ -192,6 +192,30 @@ class _AcceptedKeys:
default_keys = _AcceptedKeys()
value_network: Union[TensorDictModule, Callable]

@property
def advantage_key(self):
return self.tensor_keys.advantage

@property
def value_key(self):
return self.tensor_keys.value

@property
def value_target_key(self):
return self.tensor_keys.value_target

@property
def reward_key(self):
return self.tensor_keys.reward

@property
def done_key(self):
return self.tensor_keys.done

@property
def steps_to_next_obs_key(self):
return self.tensor_keys.steps_to_next_obs

@abc.abstractmethod
def forward(
self,
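The new read-only properties keep the legacy *_key attribute access working while the single source of truth stays in tensor_keys. A sketch of the pattern (illustrative class names):

from dataclasses import dataclass

@dataclass
class _Keys:
    advantage: str = "advantage"

class Estimator:
    def __init__(self):
        self.tensor_keys = _Keys()

    @property
    def advantage_key(self):
        # the legacy accessor resolves to the centralized key container
        return self.tensor_keys.advantage

assert Estimator().advantage_key == "advantage"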
1 change: 0 additions & 1 deletion tutorials/sphinx-tutorials/coding_ppo.py
@@ -557,7 +557,6 @@
loss_module = ClipPPOLoss(
actor=policy_module,
critic=value_module,
advantage_key="advantage",
clip_epsilon=clip_epsilon,
entropy_bonus=bool(entropy_eps),
entropy_coef=entropy_eps,
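With the deprecated advantage_key constructor argument removed from the tutorial, ClipPPOLoss falls back to the default key defined in its _AcceptedKeys. If a non-default key were needed, the supported route would be set_keys after construction, e.g.:

loss_module.set_keys(advantage="my_advantage")  # hypothetical custom key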
