[BugFix] Patch SAC to allow state_dict manipulation before exec (#1607)
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
vmoens and matteobettini authored Oct 5, 2023
1 parent 37c01cc commit 6a3e9f8
Showing 2 changed files with 97 additions and 39 deletions.
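In short, the patch lets SACLoss.state_dict() and load_state_dict() be called on a freshly constructed loss, before any forward pass or access to target_entropy: the lazily computed target entropy is now stored in a "_target_entropy" buffer that is materialized on demand, and the two serialization methods are wrapped so the buffer exists before they run. A minimal sketch of the now-supported call pattern (import paths are assumed from torchrl's public API at the time of this commit; it mirrors the new test below):

import torch
from tensordict.nn import TensorDictModule
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.modules import ProbabilisticActor, TanhDelta, ValueOperator
from torchrl.objectives import SACLoss

model = torch.nn.Linear(3, 4)
policy = ProbabilisticActor(
    module=TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]),
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=TanhDelta,
)
value = ValueOperator(module=model, in_keys=["obs"], out_keys="value")
spec = UnboundedContinuousTensorSpec(shape=(2,))

loss = SACLoss(actor_network=policy, qvalue_network=value, action_spec=spec)
state = loss.state_dict()  # no forward pass or target_entropy access needed first

fresh = SACLoss(actor_network=policy, qvalue_network=value, action_spec=spec)
fresh.load_state_dict(state)  # round-trips on a freshly constructed loss as well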
43 changes: 43 additions & 0 deletions test/test_cost.py
@@ -3260,6 +3260,49 @@ def test_sac_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]

def test_state_dict(self, version):
if version == 1:
pytest.skip("Test not implemented for version 1.")
model = torch.nn.Linear(3, 4)
actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"])
policy = ProbabilisticActor(
module=actor_module,
in_keys=["logits"],
out_keys=["action"],
distribution_class=TanhDelta,
)
value = ValueOperator(module=model, in_keys=["obs"], out_keys="value")

loss = SACLoss(
actor_network=policy,
qvalue_network=value,
action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
state = loss.state_dict()

loss = SACLoss(
actor_network=policy,
qvalue_network=value,
action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
loss.load_state_dict(state)

# with an access in between
loss = SACLoss(
actor_network=policy,
qvalue_network=value,
action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
loss.target_entropy
state = loss.state_dict()

loss = SACLoss(
actor_network=policy,
qvalue_network=value,
action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
loss.load_state_dict(state)


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
93 changes: 54 additions & 39 deletions torchrl/objectives/sac.py
@@ -5,6 +5,7 @@
import math
import warnings
from dataclasses import dataclass
from functools import wraps
from numbers import Number
from typing import Dict, Optional, Tuple, Union

@@ -43,6 +44,15 @@
FUNCTORCH_ERROR = err


def _delezify(func):
@wraps(func)
def new_func(self, *args, **kwargs):
self.target_entropy
return func(self, *args, **kwargs)

return new_func


class SACLoss(LossModule):
"""TorchRL implementation of the SAC loss.
@@ -371,7 +381,6 @@ def __init__(

self._target_entropy = target_entropy
self._action_spec = action_spec
-        self.target_entropy_buffer = None
if self._version == 1:
self.actor_critic = ActorCriticWrapper(
self.actor_network, self.value_network
@@ -384,48 +393,54 @@
if self._version == 1:
self._vmap_qnetwork00 = vmap(qvalue_network)

    @property
+    def target_entropy_buffer(self):
+        return self.target_entropy
+
+    @property
    def target_entropy(self):
-        target_entropy = self.target_entropy_buffer
-        if target_entropy is None:
-            delattr(self, "target_entropy_buffer")
-            target_entropy = self._target_entropy
-            action_spec = self._action_spec
-            actor_network = self.actor_network
-            device = next(self.parameters()).device
-            if target_entropy == "auto":
-                action_spec = (
-                    action_spec
-                    if action_spec is not None
-                    else getattr(actor_network, "spec", None)
-                )
-                if action_spec is None:
-                    raise RuntimeError(
-                        "Cannot infer the dimensionality of the action. Consider providing "
-                        "the target entropy explicitely or provide the spec of the "
-                        "action tensor in the actor network."
-                    )
-                if not isinstance(action_spec, CompositeSpec):
-                    action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
-                if (
-                    isinstance(self.tensor_keys.action, tuple)
-                    and len(self.tensor_keys.action) > 1
-                ):
-                    action_container_shape = action_spec[
-                        self.tensor_keys.action[:-1]
-                    ].shape
-                else:
-                    action_container_shape = action_spec.shape
-                target_entropy = -float(
-                    action_spec[self.tensor_keys.action]
-                    .shape[len(action_container_shape) :]
-                    .numel()
-                )
-            self.register_buffer(
-                "target_entropy_buffer", torch.tensor(target_entropy, device=device)
-            )
-            return self.target_entropy_buffer
-        return target_entropy
+        target_entropy = self._buffers.get("_target_entropy", None)
+        if target_entropy is not None:
+            return target_entropy
+        target_entropy = self._target_entropy
+        action_spec = self._action_spec
+        actor_network = self.actor_network
+        device = next(self.parameters()).device
+        if target_entropy == "auto":
+            action_spec = (
+                action_spec
+                if action_spec is not None
+                else getattr(actor_network, "spec", None)
+            )
+            if action_spec is None:
+                raise RuntimeError(
+                    "Cannot infer the dimensionality of the action. Consider providing "
+                    "the target entropy explicitely or provide the spec of the "
+                    "action tensor in the actor network."
+                )
+            if not isinstance(action_spec, CompositeSpec):
+                action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+            if (
+                isinstance(self.tensor_keys.action, tuple)
+                and len(self.tensor_keys.action) > 1
+            ):
+                action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
+            else:
+                action_container_shape = action_spec.shape
+            target_entropy = -float(
+                action_spec[self.tensor_keys.action]
+                .shape[len(action_container_shape) :]
+                .numel()
+            )
+        delattr(self, "_target_entropy")
+        self.register_buffer(
+            "_target_entropy", torch.tensor(target_entropy, device=device)
+        )
+        return self._target_entropy
+
+    state_dict = _delezify(LossModule.state_dict)
+    load_state_dict = _delezify(LossModule.load_state_dict)

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
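The mechanism, stripped of SAC specifics: a property computes a value lazily and registers it as a buffer on first access, and a small decorator (_delezify above) touches that property before the inherited state_dict/load_state_dict execute, so checkpoints written or read before the module is ever used still contain the buffer. A generic, self-contained sketch of the same pattern on a plain nn.Module (names below are illustrative, not torchrl's):

from functools import wraps

import torch
from torch import nn


def _delazify(func):
    # run the lazy property before the wrapped method so its buffer exists
    @wraps(func)
    def new_func(self, *args, **kwargs):
        self.scale
        return func(self, *args, **kwargs)

    return new_func


class LazyScale(nn.Module):
    def __init__(self):
        super().__init__()
        self._pending_scale = 2.0  # plain attribute, not yet a buffer

    @property
    def scale(self):
        buf = self._buffers.get("_scale", None)
        if buf is not None:
            return buf
        # first access: materialize the buffer so it shows up in state_dict()
        self.register_buffer("_scale", torch.tensor(self._pending_scale))
        return self._scale

    # make sure the buffer exists before saving or loading a checkpoint
    state_dict = _delazify(nn.Module.state_dict)
    load_state_dict = _delazify(nn.Module.load_state_dict)


m = LazyScale()
sd = m.state_dict()  # contains "_scale" even though scale was never read
m2 = LazyScale()
m2.load_state_dict(sd)  # also works on a freshly built module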
