[Refactor] the usage of tensordict keys in loss modules #1175

Merged
merged 43 commits on May 31, 2023
Changes from 1 commit
Commits
43 commits
83dc591
[Refactor] the usage of tensordict keys in loss modules
Blonck May 22, 2023
09ced18
Add more loss modules
Blonck May 22, 2023
bc04cae
Add more loss modules
Blonck May 23, 2023
75c8ea1
Refactor remaining loss modules
Blonck May 23, 2023
5a74a16
Remove unnecessary tests
Blonck May 23, 2023
32725b4
tensordict_keys dict is no longer overwritten by child classes
Blonck May 23, 2023
ab94848
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
802fe48
Harmonize key name for "state_value"
Blonck May 23, 2023
c6186fc
Polish refactoring
Blonck May 23, 2023
b694e8c
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
9150b74
Apply suggestions from code review
Blonck May 23, 2023
bcd8a28
Use abstract staticmethod to provide default values
Blonck May 23, 2023
6f10920
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
67941df
Merge branch 'main' and rename tensordict_keys to loss_keys
Blonck May 24, 2023
7f3e129
Use simple set_keys on all loss modules
Blonck May 24, 2023
427c1e8
Implement tensor_keys via _AcceptedKeys dataclass
Blonck May 24, 2023
66fb949
Extended _AcceptedKeys to all loss modules
Blonck May 25, 2023
526ab36
Refactor unit test for tensordict keys
Blonck May 25, 2023
08e20da
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 25, 2023
0d476ca
WIP
Blonck May 25, 2023
9bb616a
Fix .in_keys of ValueEstimatorBase
Blonck May 25, 2023
5d00ca0
Move tensordict key logic to base class
Blonck May 25, 2023
4db47e5
Fix make_value_estimator of a2c.py
Blonck May 25, 2023
6b422f9
Remove '_key' from keynames in ppo.py + polish
Blonck May 26, 2023
317755d
Remove '_key' from keynames in ddpg.py + polish
Blonck May 26, 2023
fe9fba0
Fix documentation in advantages.py
Blonck May 26, 2023
34091e0
Remove '_key' from keynames in dqn.py + polish
Blonck May 26, 2023
4baa5dc
Remove '_key' from keynames in dreamer.py + polish
Blonck May 26, 2023
4595546
Remove '_key' from keynames in iql.py and redq.py + polish
Blonck May 26, 2023
8ae6ad9
Remove tensor_keys from advantage ctor
Blonck May 26, 2023
a15e220
Add documentation to a2c.py
Blonck May 26, 2023
f1187f3
Change documentation of loss modules
Blonck May 26, 2023
3e09c58
Add unit test for advantages tensordict keys
Blonck May 26, 2023
e52a3f2
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 26, 2023
2dc81c9
Improve wording of docstrings
Blonck May 26, 2023
655c28d
Apply suggestions from code review
Blonck May 28, 2023
226d4d3
Merge branch 'pytorch:main' into refactor_loss_keys
Blonck May 28, 2023
75d33c6
Apply code review changes
Blonck May 28, 2023
4320db6
Merge branch 'main' into refactor_loss_keys_github
Blonck May 30, 2023
cf4cd09
Change line breaking in docstrings for _AcceptedKeys
Blonck May 30, 2023
81c0413
LossModule is no longer an abstract base class.
Blonck May 31, 2023
6e753a4
Merge branch 'main' into refactor_loss_keys_github
Blonck May 31, 2023
cc784a1
Merge branch 'main' into refactor_loss_keys
vmoens May 31, 2023
Remove '_key' from keynames in iql.py and redq.py + polish
Blonck committed May 26, 2023
commit 459554617b7a210e15138b3990de74835f81ab22
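As an orienting note (not part of the diff): this commit applies the same rename pattern as the earlier ppo/ddpg/dqn commits, dropping the '_key' suffix from the accessor names on tensor_keys while leaving the default key values untouched. A rough sketch, assuming loss_fn is one of the refactored loss modules and td a tensordict that has already been passed through it, as in the tests below:

    # New accessor name after this commit; the stored default value "td_error"
    # is unchanged, only the attribute on tensor_keys loses its "_key" suffix.
    assert loss_fn.tensor_keys.priority in td.keys()
    # The pre-refactor spelling, now removed:
    # assert loss_fn.tensor_keys.priority_key in td.keys()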
36 changes: 18 additions & 18 deletions test/test_cost.py
@@ -1124,7 +1124,7 @@ def test_td3_batcher(

with _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority_key in ms_td.keys()
assert loss_fn.tensor_keys.priority in ms_td.keys()

with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
@@ -2192,7 +2192,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est):
loss = loss_fn(td)

# check td is left untouched
assert loss_fn.tensor_keys.priority_key in td.keys()
assert loss_fn.tensor_keys.priority in td.keys()

# check that losses are independent
for k in loss.keys():
@@ -2317,7 +2317,7 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device):
loss_fn.zero_grad()

# check td is left untouched
assert loss_fn.tensor_keys.priority_key in td.keys()
assert loss_fn.tensor_keys.priority in td.keys()

sum([item for _, item in loss.items()]).backward()
named_parameters = list(loss_fn.named_parameters())
@@ -2440,7 +2440,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):

with _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority_key in ms_td.keys()
assert loss_fn.tensor_keys.priority in ms_td.keys()

with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
@@ -2514,13 +2514,13 @@ def test_redq_tensordict_keys(self, td_est):
)

default_keys = {
"priority_key": "td_error",
"action_key": "action",
"value_key": "state_value",
"sample_log_prob_key": "sample_log_prob",
"state_action_value_key": "state_action_value",
"priority": "td_error",
"action": "action",
"value": "state_value",
"sample_log_prob": "sample_log_prob",
"state_action_value": "state_action_value",
}
key_mapping = {"value_key": "value_key"}
key_mapping = {"value": "value_key"}
self.tensordict_keys_test(
loss_fn,
default_keys=default_keys,
@@ -4150,7 +4150,7 @@ def test_iql(

with _check_td_steady(td):
loss = loss_fn(td)
assert loss_fn.tensor_keys.priority_key in td.keys()
assert loss_fn.tensor_keys.priority in td.keys()

# check that losses are independent
for k in loss.keys():
@@ -4271,7 +4271,7 @@ def test_iql_batcher(
np.random.seed(0)
with _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority_key in ms_td.keys()
assert loss_fn.tensor_keys.priority in ms_td.keys()

with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
@@ -4338,13 +4338,13 @@ def test_iql_tensordict_keys(self, td_est):
)

default_keys = {
"priority_key": "td_error",
"log_prob_key": "_log_prob",
"action_key": "action",
"state_action_value_key": "state_action_value",
"value_key": "state_value",
"priority": "td_error",
"log_prob": "_log_prob",
"action": "action",
"state_action_value": "state_action_value",
"value": "state_value",
}
key_mapping = {"value_key": "value_key"}
key_mapping = {"value": "value_key"}
self.tensordict_keys_test(
loss_fn,
default_keys=default_keys,
61 changes: 41 additions & 20 deletions torchrl/objectives/iql.py
@@ -61,11 +61,32 @@ class IQLLoss(LossModule):

@dataclass
class _AcceptedKeys:
priority_key: NestedKey = "td_error"
log_prob_key: NestedKey = "_log_prob"
action_key: NestedKey = "action"
state_action_value_key: NestedKey = "state_action_value"
value_key: NestedKey = "state_value"
"""Stores default values for all configurable tensordict keys.

This class is used to define and store which tensordict keys are configurable
via `.set_keys(key_name=key_value)` and their default values.

Attributes:
------------
value : NestedKey
The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
action : NestedKey
The input tensordict key where the action is expected. Defaults to ``"action"``.
log_prob : NestedKey
The input tensordict key where the log probability is expected. Defaults to ``"_log_prob"``.
priority : NestedKey
The input tensordict key where the target priority is written to. Defaults to ``"td_error"``.
state_action_value : NestedKey
The input tensordict key where the state action value is expected.
Will be used for the underlying value estimator as value key. Defaults to ``"state_action_value"``.
"""

value: NestedKey = "state_value"
action: NestedKey = "action"
log_prob: NestedKey = "_log_prob"
priority: NestedKey = "td_error"
state_action_value: NestedKey = "state_action_value"

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
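Usage note (not part of the diff): with _AcceptedKeys in place, tensordict keys are remapped after construction through set_keys rather than through the deprecated *_key constructor arguments. A minimal sketch, assuming actor, qvalue and value are TensorDictModule-compatible networks already built for IQL:

    from torchrl.objectives import IQLLoss

    loss_fn = IQLLoss(actor_network=actor, qvalue_network=qvalue, value_network=value)
    # Defaults come from _AcceptedKeys, e.g. loss_fn.tensor_keys.value == "state_value".
    loss_fn.set_keys(value="my_state_value", priority="my_td_error")
    # Any value estimator created afterwards picks up the remapped value key
    # through _forward_value_estimator_keys / make_value_estimator.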
@@ -86,7 +107,7 @@ def __init__(
if not _has_functorch:
raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
self._set_deprecated_ctor_keys(priority_key=priority_key)
self._set_deprecated_ctor_keys(priority=priority_key)

# IQL parameter
self.temperature = temperature
@@ -143,7 +164,7 @@ def loss_value_diff(diff, expectile=0.8):
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value_key=self._tensor_keys.value_key,
value_key=self._tensor_keys.value,
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -161,7 +182,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
loss_qvalue, priority = self._loss_qvalue(td_device)
loss_value = self._loss_value(td_device)

tensordict_reshape.set(self.tensor_keys.priority_key, priority)
tensordict_reshape.set(self.tensor_keys.priority, priority)
if (loss_actor.shape != loss_qvalue.shape) or (
loss_value is not None and loss_actor.shape != loss_value.shape
):
@@ -174,7 +195,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"loss_actor": loss_actor.mean(),
"loss_qvalue": loss_qvalue.mean(),
"loss_value": loss_value.mean(),
"entropy": -td_device.get(self.tensor_keys.log_prob_key).mean().detach(),
"entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(),
}

return TensorDict(
@@ -189,14 +210,14 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:
params=self.actor_network_params,
)

log_prob = dist.log_prob(tensordict[self.tensor_keys.action_key])
log_prob = dist.log_prob(tensordict[self.tensor_keys.action])

# Min Q value
td_q = tensordict.select(*self.qvalue_network.in_keys)
td_q = vmap(self.qvalue_network, (None, 0))(
td_q, self.target_qvalue_network_params
)
min_q = td_q.get(self.tensor_keys.state_action_value_key).min(0)[0].squeeze(-1)
min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)

if log_prob.shape != min_q.shape:
raise RuntimeError(
@@ -209,15 +230,15 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:
td_copy,
params=self.value_network_params,
)
value = td_copy.get(self.tensor_keys.value_key).squeeze(
value = td_copy.get(self.tensor_keys.value).squeeze(
-1
) # assert has no gradient

exp_a = torch.exp((min_q - value) * self.temperature)
exp_a = torch.min(exp_a, torch.FloatTensor([100.0]).to(self.device))

# write log_prob in tensordict for alpha loss
tensordict.set(self.tensor_keys.log_prob_key, log_prob.detach())
tensordict.set(self.tensor_keys.log_prob, log_prob.detach())
return -(exp_a * log_prob).mean()

def _loss_value(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
@@ -226,21 +247,21 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
td_q = vmap(self.qvalue_network, (None, 0))(
td_q, self.target_qvalue_network_params
)
min_q = td_q.get(self.tensor_keys.state_action_value_key).min(0)[0].squeeze(-1)
min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
# state value
td_copy = tensordict.select(*self.value_network.in_keys)
self.value_network(
td_copy,
params=self.value_network_params,
)
value = td_copy.get(self.tensor_keys.value_key).squeeze(-1)
value = td_copy.get(self.tensor_keys.value).squeeze(-1)
value_loss = self.loss_value_diff(min_q - value, self.expectile).mean()
return value_loss

def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
obs_keys = self.actor_network.in_keys
# TODO (refactor key usage): what to do with dynamically generated keys
tensordict = tensordict.select("next", *obs_keys, self.tensor_keys.action_key)
tensordict = tensordict.select("next", *obs_keys, self.tensor_keys.action)

target_value = self.value_estimator.value_estimate(
tensordict, target_params=self.target_value_network_params
@@ -249,9 +270,9 @@ def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
tensordict.select(*self.qvalue_network.in_keys),
self.qvalue_network_params,
)
pred_val = tensordict_expand.get(
self.tensor_keys.state_action_value_key
).squeeze(-1)
pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
-1
)
td_error = abs(pred_val - target_value)
loss_qval = (
distance_loss(
@@ -276,7 +297,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
hp.update(hyperparams)
tensor_keys = {
"value_target_key": "value_target",
"value_key": self.tensor_keys.value_key,
"value_key": self.tensor_keys.value,
}
if value_type is ValueEstimators.TD1:
self._value_estimator = TD1Estimator(
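Continuing the sketch above (same assumptions), the remapped key is what the value estimator ends up using:

    from torchrl.objectives import ValueEstimators

    loss_fn.set_keys(value="my_state_value")
    loss_fn.make_value_estimator(ValueEstimators.TD1, gamma=0.99)
    # The TD1 estimator now reads "my_state_value" as its value key and writes its
    # target to the default "value_target" entry, per the tensor_keys dict above.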
52 changes: 36 additions & 16 deletions torchrl/objectives/redq.py
@@ -82,11 +82,31 @@ class REDQLoss(LossModule):

@dataclass
class _AcceptedKeys:
priority_key: NestedKey = "td_error"
action_key: NestedKey = "action"
value_key: NestedKey = "state_value"
sample_log_prob_key: NestedKey = "sample_log_prob"
state_action_value_key: NestedKey = "state_action_value"
"""Stores default values for all configurable tensordict keys.

This class is used to define and store which tensordict keys are configurable
via `.set_keys(key_name=key_value)` and their default values.

Attributes:
------------
value : NestedKey
The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
action : NestedKey
The input tensordict key where the action is expected. Defaults to ``"action"``.
sample_log_prob : NestedKey
The input tensordict key where the sample log probability is expected. Defaults to ``"sample_log_prob"``.
priority : NestedKey
The input tensordict key where the target priority is written to. Defaults to ``"td_error"``.
state_action_value : NestedKey
The input tensordict key where the state action value is expected. Defaults to ``"state_action_value"``.
"""

action: NestedKey = "action"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
priority: NestedKey = "td_error"
state_action_value: NestedKey = "state_action_value"

default_keys = _AcceptedKeys()
delay_actor: bool = False
@@ -169,7 +189,7 @@ def __init__(
"action tensor in the actor network."
)
target_entropy = -float(
np.prod(actor_network.spec[self.tensor_keys.action_key].shape)
np.prod(actor_network.spec[self.tensor_keys.action].shape)
)
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
@@ -182,7 +202,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value_key=self._tensor_keys.value_key,
value_key=self._tensor_keys.value,
)

@property
@@ -195,7 +215,7 @@ def alpha(self):
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
obs_keys = self.actor_network.in_keys
tensordict_select = tensordict.clone(False).select(
"next", *obs_keys, self.tensor_keys.action_key
"next", *obs_keys, self.tensor_keys.action
)
selected_models_idx = torch.randperm(self.num_qvalue_nets)[
: self.sub_sample_len
@@ -227,18 +247,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
actor_params,
)
if isinstance(self.actor_network, TensorDictSequential):
sample_key = self.tensor_keys.action_key
sample_key = self.tensor_keys.action
tensordict_actor_dist = self.actor_network.build_dist_from_params(
td_params
)
else:
sample_key = self.tensor_keys.action_key
sample_key = self.tensor_keys.action
tensordict_actor_dist = self.actor_network.build_dist_from_params(
td_params
)
tensordict_actor.set(sample_key, tensordict_actor_dist.rsample())
tensordict_actor.set(
self.tensor_keys.sample_log_prob_key,
self.tensor_keys.sample_log_prob,
tensordict_actor_dist.log_prob(tensordict_actor.get(sample_key)),
)

@@ -277,7 +297,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)

state_action_value = tensordict_qval.get(
self.tensor_keys.state_action_value_key
self.tensor_keys.state_action_value
).squeeze(-1)
(
state_action_value_actor,
@@ -288,7 +308,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
dim=0,
)
sample_log_prob = tensordict_actor.get(
self.tensor_keys.sample_log_prob_key
self.tensor_keys.sample_log_prob
).squeeze(-1)
(
action_log_prob_actor,
@@ -305,7 +325,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
next_state_value = next_state_value.min(0)[0]

tensordict_select.set(
("next", self.tensor_keys.value_key), next_state_value.unsqueeze(-1)
("next", self.tensor_keys.value), next_state_value.unsqueeze(-1)
)
target_value = self.value_estimator.value_estimate(tensordict_select).squeeze(
-1
@@ -319,7 +339,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
loss_function=self.loss_function,
).mean(0)

tensordict.set(self.tensor_keys.priority_key, td_error.detach().max(0)[0])
tensordict.set(self.tensor_keys.priority, td_error.detach().max(0)[0])

loss_alpha = self._loss_alpha(sample_log_prob)
if not loss_qval.shape == loss_actor.shape:
@@ -364,7 +384,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
hp.update(hyperparams)
tensor_keys = {"value_key": self.tensor_keys.value_key}
tensor_keys = {"value_key": self.tensor_keys.value}
# we do not need a value network bc the next state value is already passed
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(