[BugFix] DDPG select also critic input for actor loss #1563

Merged · 5 commits · Sep 22, 2023
Changes from 1 commit
init
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini committed Sep 22, 2023
commit 5a8b9cb975df1e2dc6ce59adae07fa93a711afeb
18 changes: 9 additions & 9 deletions torchrl/objectives/ddpg.py
@@ -11,10 +11,10 @@
 from typing import Tuple

 import torch

 from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase
-from tensordict.utils import NestedKey
-
+from tensordict.utils import NestedKey, unravel_key
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -236,15 +236,15 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
         self._set_in_keys()

     def _set_in_keys(self):
-        keys = [
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
+        keys = {
+            unravel_key(("next", self.tensor_keys.reward)),
+            unravel_key(("next", self.tensor_keys.done)),
Comment on lines +240 to +241
Contributor
We should check, but since it's a TensorDictModuleBase the in_keys are unravelled by default:
https://github.com/pytorch-labs/tensordict/blob/accd8a4a31ec749f52e75a87a875424652069163/tensordict/nn/common.py#L474-L495

So I think we can spare the effort of doing that here.

Contributor Author
But even if they are already unravelled, we are creating a new tuple ("next", already_unravelled_key), which could be ("next", "done") or ("next", ("nested", "done")).

Contributor Author
That is why I am only unravelling the ones where we are putting "next".

Contributor
Look at the link: in theory they will be unravelled after you set them.

Contributor Author
Funky stuff! I'll make sure to test it.

Contributor
Ah, maybe not for properties, this one's harder actually.
Let's keep it as it is.

             *self.actor_in_keys,
-            *[("next", key) for key in self.actor_in_keys],
+            *[unravel_key(("next", key)) for key in self.actor_in_keys],
             *self.value_network.in_keys,
-            *[("next", key) for key in self.value_network.in_keys],
-        ]
-        self._in_keys = list(set(keys))
+            *[unravel_key(("next", key)) for key in self.value_network.in_keys],
+        }
+        self._in_keys = sorted(keys, key=str)

     @property
     def in_keys(self):
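For context on the review thread above: prepending "next" to a key that is itself a tuple produces a nested tuple, and unravel_key flattens such keys into a single flat tuple. A minimal sketch of that behaviour, assuming tensordict's unravel_key helper as imported in this diff (the assertions reflect its documented flattening semantics):

from tensordict.utils import unravel_key

# A plain string key passes through unchanged.
assert unravel_key("done") == "done"

# A flat tuple key is already unravelled.
assert unravel_key(("next", "done")) == ("next", "done")

# The case raised in the thread: composing "next" with an
# already-unravelled nested key yields a nested tuple, which
# unravel_key flattens into one flat tuple.
nested_key = ("nested", "done")
assert unravel_key(("next", nested_key)) == ("next", "nested", "done")

This is why the diff only wraps the composed ("next", key) tuples in unravel_key: keys taken directly from self.actor_in_keys or self.value_network.in_keys are expected to be unravelled already.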