[Refactor] Refactor losses for generalization (#1286)

vmoens authored Jun 15, 2023
1 parent 2dbdec9 commit 53f4c52

Showing 9 changed files with 318 additions and 95 deletions.
37 changes: 26 additions & 11 deletions test/test_cost.py
@@ -1483,6 +1483,7 @@ def _create_mock_actor(
action_dim=4,
device="cpu",
observation_key="observation",
action_key="action",
):
# Actor
action_spec = BoundedTensorSpec(
@@ -1495,6 +1496,7 @@ def _create_mock_actor(
in_keys=["loc", "scale"],
spec=action_spec,
distribution_class=TanhNormal,
out_keys=[action_key],
)
return actor.to(device)

@@ -1981,7 +1983,9 @@ def test_sac_notensordict(
done_key=done_key,
)

actor = self._create_mock_actor(observation_key=observation_key)
actor = self._create_mock_actor(
observation_key=observation_key, action_key=action_key
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
action_key=action_key,
@@ -2521,6 +2525,7 @@ def _create_mock_actor(
action_dim=4,
device="cpu",
observation_key="observation",
action_key="action",
):
# Actor
action_spec = BoundedTensorSpec(
@@ -2534,6 +2539,7 @@ def _create_mock_actor(
distribution_class=TanhNormal,
return_log_prob=True,
spec=action_spec,
out_keys=[action_key],
)
return actor.to(device)

@@ -3051,7 +3057,10 @@ def test_redq_tensordict_keys(self, td_est):
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_redq_notensordict(self, action_key, observation_key, reward_key, done_key):
@pytest.mark.parametrize("deprec", [True, False])
def test_redq_notensordict(
self, action_key, observation_key, reward_key, done_key, deprec
):
torch.manual_seed(self.seed)
td = self._create_mock_data_redq(
action_key=action_key,
@@ -3062,14 +3071,19 @@ def test_redq_notensordict(self, action_key, observation_key, reward_key, done_k

actor = self._create_mock_actor(
observation_key=observation_key,
action_key=action_key,
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
action_key=action_key,
out_keys=["state_action_value"],
)

loss = REDQLoss(
if deprec:
cls = REDQLoss_deprecated
else:
cls = REDQLoss
loss = cls(
actor_network=actor,
qvalue_network=qvalue,
)
@@ -3094,14 +3108,15 @@ def test_redq_notensordict(self, action_key, observation_key, reward_key, done_k
torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2])
torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3])
torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
torch.testing.assert_close(
loss_val_td.get("state_action_value_actor"), loss_val[5]
)
torch.testing.assert_close(
loss_val_td.get("action_log_prob_actor"), loss_val[6]
)
torch.testing.assert_close(loss_val_td.get("next.state_value"), loss_val[7])
torch.testing.assert_close(loss_val_td.get("target_value"), loss_val[8])
if not deprec:
torch.testing.assert_close(
loss_val_td.get("state_action_value_actor"), loss_val[5]
)
torch.testing.assert_close(
loss_val_td.get("action_log_prob_actor"), loss_val[6]
)
torch.testing.assert_close(loss_val_td.get("next.state_value"), loss_val[7])
torch.testing.assert_close(loss_val_td.get("target_value"), loss_val[8])
# test select
torch.manual_seed(self.seed)
loss.select_out_keys("loss_actor", "loss_alpha")
4 changes: 2 additions & 2 deletions torchrl/modules/tensordict_module/actors.py
@@ -220,11 +220,11 @@ def __init__(
if out_keys is None:
out_keys = ["action"]
if (
"action" in out_keys
len(out_keys) == 1
and spec is not None
and not isinstance(spec, CompositeSpec)
):
spec = CompositeSpec(action=spec)
spec = CompositeSpec({out_keys[0]: spec})

super().__init__(
module,
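
With the actors.py change above, the spec no longer has to be registered under a hard-coded "action" entry: any ProbabilisticActor with a single out_key gets its spec wrapped under that key. Below is a minimal sketch of the behaviour, assuming a toy 3-dimensional observation and 4-dimensional action; the module and key names are illustrative, not taken from the commit.

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data import BoundedTensorSpec
from torchrl.modules import NormalParamWrapper, ProbabilisticActor, TanhNormal

# Backbone producing the distribution parameters "loc" and "scale".
net = NormalParamWrapper(nn.Linear(3, 2 * 4))
module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"])

# A single custom out_key is now enough for the spec to be wrapped as
# CompositeSpec({"action2": action_spec}) rather than CompositeSpec(action=...).
action_spec = BoundedTensorSpec(-torch.ones(4), torch.ones(4), (4,))
actor = ProbabilisticActor(
    module,
    in_keys=["loc", "scale"],
    out_keys=["action2"],
    spec=action_spec,
    distribution_class=TanhNormal,
)

td = actor(TensorDict({"observation": torch.randn(10, 3)}, [10]))
assert "action2" in td.keys()
print(actor.spec)  # CompositeSpec with "action2" as the registered entry
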
110 changes: 72 additions & 38 deletions torchrl/objectives/cql.py
@@ -15,6 +15,7 @@
from tensordict.utils import NestedKey
from torch import Tensor

from torchrl.data import CompositeSpec
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.modules import ProbabilisticActor
@@ -58,6 +59,9 @@ class CQLLoss(LossModule):
Default is None (no minimum value).
max_alpha (float, optional): max value of alpha.
Default is None (no maximum value).
action_spec (TensorSpec, optional): the action tensor spec. If not provided
and the target entropy is ``"auto"``, it will be retrieved from
the actor.
fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
initial value. Otherwise, alpha will be optimized to
match the 'target_entropy' value.
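
To illustrate the new action_spec argument: when target_entropy is left at "auto", the loss resolves it to minus the number of action dimensions, taken either from the spec passed here or from the actor's own spec (see the target_entropy property added further down). A minimal sketch of that computation, with a made-up 6-dimensional action space:

import numpy as np
import torch
from torchrl.data import BoundedTensorSpec, CompositeSpec

# Hypothetical 6-dimensional continuous action under the default "action" key.
action_spec = BoundedTensorSpec(-torch.ones(6), torch.ones(6), (6,))
composite = CompositeSpec(action=action_spec)

# "auto" target entropy: minus the product of the action shape.
target_entropy = -float(np.prod(composite["action"].shape))
print(target_entropy)  # -6.0
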
@@ -237,6 +241,7 @@ def __init__(
alpha_init: float = 1.0,
min_alpha: float = None,
max_alpha: float = None,
action_spec=None,
fixed_alpha: bool = False,
target_entropy: Union[str, float] = "auto",
delay_actor: bool = False,
@@ -312,17 +317,9 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
)

if target_entropy == "auto":
if actor_network.spec is None:
raise RuntimeError(
"Cannot infer the dimensionality of the action. Consider providing "
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
target_entropy = -float(np.prod(actor_network.spec["action"].shape))
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
)
self._target_entropy = target_entropy
self._action_spec = action_spec
self.target_entropy_buffer = None

if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
@@ -342,6 +339,38 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

@property
def target_entropy(self):
target_entropy = self.target_entropy_buffer
if target_entropy is None:
delattr(self, "target_entropy_buffer")
target_entropy = self._target_entropy
action_spec = self._action_spec
actor_network = self.actor_network
device = next(self.parameters()).device
if target_entropy == "auto":
action_spec = (
action_spec
if action_spec is not None
else getattr(actor_network, "spec", None)
)
if action_spec is None:
raise RuntimeError(
"Cannot infer the dimensionality of the action. Consider providing "
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
if not isinstance(action_spec, CompositeSpec):
action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
target_entropy = -float(
np.prod(action_spec[self.tensor_keys.action].shape)
)
self.register_buffer(
"target_entropy_buffer", torch.tensor(target_entropy, device=device)
)
return self.target_entropy_buffer
return target_entropy

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
@@ -489,23 +518,26 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:
tensordict.set(self.tensor_keys.log_prob, log_prob.detach())
return self._alpha * log_prob - min_q_logprob

def _get_policy_actions(self, observation, actor_params, num_actions=10):
observation = (
observation.unsqueeze(-1)
.repeat(1, num_actions, 1)
.view(observation.shape[0] * num_actions, observation.shape[-1])
def _get_policy_actions(self, data, actor_params, num_actions=10):
batch_size = data.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * num_actions]
tensordict = data.select(*self.actor_network.in_keys).apply(
lambda x: x.repeat_interleave(num_actions, dim=data.ndim - 1),
batch_size=batch_size,
)
tensordict = TensorDict({self.actor_network.in_keys[0]: observation}, [])
with torch.no_grad():
with set_exploration_type(ExplorationType.RANDOM):
dist = self.actor_network.get_dist(tensordict, params=actor_params)
action = dist.rsample()
tensordict.set(self.tensor_keys.action, action)
sample_log_prob = dist.log_prob(action)
tensordict.del_("loc")
tensordict.del_("scale")
# tensordict.del_("loc")
# tensordict.del_("scale")

return tensordict, sample_log_prob
return (
tensordict.select(*self.actor_network.in_keys, self.tensor_keys.action),
sample_log_prob,
)

def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
tensordict = tensordict.clone(False)
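
The refactored _get_policy_actions above expands the whole input tensordict (rather than a bare observation tensor) so that several candidate actions can be sampled per state. A minimal sketch of that expansion with made-up shapes:

import torch
from tensordict import TensorDict

# Repeat each entry along the last batch dimension num_actions times,
# mirroring the repeat_interleave call in _get_policy_actions.
num_actions = 3
data = TensorDict({"observation": torch.randn(4, 7)}, batch_size=[4])

batch_size = list(data.batch_size[:-1]) + [data.batch_size[-1] * num_actions]
expanded = data.apply(
    lambda x: x.repeat_interleave(num_actions, dim=data.ndim - 1),
    batch_size=batch_size,
)
print(expanded.batch_size)  # torch.Size([12])
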
@@ -539,7 +571,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):

if self.max_q_backup:
next_tensordict, _ = self._get_policy_actions(
tensordict.get(("next", "observation")),
tensordict.get("next"),
actor_params,
num_actions=self.num_random,
)
@@ -580,18 +612,19 @@ def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
# add CQL
random_actions_tensor = (
torch.FloatTensor(
tensordict.shape[0] * self.num_random, tensordict["action"].shape[-1]
tensordict.shape[0] * self.num_random,
tensordict[self.tensor_keys.action].shape[-1],
)
.uniform_(-1, 1)
.to(tensordict.device)
)
curr_actions_td, curr_log_pis = self._get_policy_actions(
tensordict.get("observation"),
tensordict,
self.actor_network_params,
num_actions=self.num_random,
)
new_curr_actions_td, new_log_pis = self._get_policy_actions(
tensordict.get(("next", "observation")),
tensordict.get("next"),
self.actor_network_params,
num_actions=self.num_random,
)
@@ -608,20 +641,19 @@ def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
)
# select and stack input params
# q value random action
tensordict_q_random = TensorDict(
{self.tensor_keys.action: random_actions_tensor}, []
tensordict_q_random = tensordict.select(*self.actor_network.in_keys)

batch_size = tensordict_q_random.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]
tensordict_q_random = tensordict_q_random.select(
*self.actor_network.in_keys
).apply(
lambda x: x.repeat_interleave(
self.num_random, dim=tensordict_q_random.ndim - 1
),
batch_size=batch_size,
)
current_observation = tensordict.get(*self.actor_network.in_keys)
current_observation = (
current_observation.unsqueeze(-2)
.repeat(1, self.num_random, 1)
.view(
current_observation.shape[0] * self.num_random,
current_observation.shape[-1],
)
)
tensordict_q_random.set(*self.actor_network.in_keys, current_observation)

tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor)
cql_tensordict = torch.cat(
[
tensordict_q_random.expand(
@@ -654,7 +686,9 @@ def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
)

# importance sampled version
random_density = np.log(0.5 ** curr_actions_td["action"].shape[-1])
random_density = np.log(
0.5 ** curr_actions_td[self.tensor_keys.action].shape[-1]
)
cat_q1 = torch.cat(
[
q_random[0] - random_density,
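
For the importance-sampled term above: the random actions are drawn with uniform_(-1, 1), i.e. with density 1/2 per action dimension, so the subtracted log-density is simply action_dim * log(1/2). A quick check with a made-up action dimensionality:

import numpy as np

action_dim = 4  # hypothetical action dimensionality
random_density = np.log(0.5 ** action_dim)
print(random_density)  # ≈ -2.7726, i.e. action_dim * np.log(0.5)
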