
[BugFix] Patch SAC to allow state_dict manipulation before exec #1607

Merged 4 commits into main on Oct 5, 2023

Conversation

@vmoens (Contributor) commented Oct 5, 2023

Description

Closes #1594

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 5, 2023
@vmoens vmoens requested a review from matteobettini October 5, 2023 14:10
@vmoens vmoens added the bug Something isn't working label Oct 5, 2023
```python
    qvalue_network=value,
    action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
state = loss.state_dict()
```
Contributor
Maybe let's add a forward call or an access to the entropy before saving.

Contributor Author

Then we're sure it is instantiated, and the test loses its value, no?

Contributor

What we want to test is the case where the buffer is instantiated in the saved loss and has to be loaded into a new loss that has just been initialized.
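The scenario above can be sketched without torchrl; `LazyLoss` and its names are hypothetical stand-ins for illustration, mirroring the lazily instantiated target-entropy buffer discussed in this PR:

```python
class LazyLoss:
    """Toy stand-in for SACLoss (hypothetical): a buffer is created
    lazily on first access to the entropy property."""

    def __init__(self):
        self._buffers = {}

    @property
    def entropy(self):
        # lazily instantiate the buffer on first access
        if "entropy" not in self._buffers:
            self._buffers["entropy"] = -2.0
        return self._buffers["entropy"]

    def state_dict(self):
        return dict(self._buffers)

    def load_state_dict(self, state):
        # a freshly initialized loss must accept a state_dict whose
        # lazily created buffer it has not instantiated itself yet
        self._buffers.update(state)


saved_loss = LazyLoss()
_ = saved_loss.entropy            # force lazy instantiation before saving
state = saved_loss.state_dict()

fresh_loss = LazyLoss()           # has never touched the buffer
fresh_loss.load_state_dict(state)
print(fresh_loss.entropy)         # -2.0
```

The point of the test is exactly the last three lines: the fresh loss must accept a state containing a buffer it never created.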

Comment on lines +416 to +419
```python
raise RuntimeError(
    "Cannot infer the dimensionality of the action. Consider providing "
    "the target entropy explicitly or provide the spec of the "
    "action tensor in the actor network."
```
Contributor
@matteobettini Oct 5, 2023
There is an interesting tangential bug related to this.

To see it, try:

```python
import torch
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhDelta, ValueOperator
from torchrl.objectives import SACLoss

if __name__ == "__main__":
    model = torch.nn.Linear(1, 1)
    actor_module = TensorDictModule(
        torch.nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
    )
    policy = ProbabilisticActor(
        module=actor_module,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=TanhDelta,
    )
    value = ValueOperator(module=model, in_keys=["obs"], out_keys=["value"])

    loss = SACLoss(
        actor_network=policy,
        qvalue_network=value,
        # action_spec=UnboundedContinuousTensorSpec(shape=(2,)), not passing the spec
    )
    _ = loss.target_entropy
```

It will trigger this error, but the user will instead see:

```
  File "/Users/matbet/PycharmProjects/rl/prova.py", line 29, in <module>
    loss.target_entropy
  File "/Users/matbet/PycharmProjects/rl/torchrl/objectives/common.py", line 348, in __getattr__
    return super().__getattr__(item)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SACLoss' object has no attribute 'target_entropy'
```

So the message here is never shown.

This is valid for all `@property` methods. Isn't this curious?
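The masking can be reproduced without torchrl. A minimal sketch, assuming a `__getattr__` that behaves like `torch.nn.Module`'s (raise `AttributeError` when nothing is found) — the class names here are hypothetical:

```python
class ModuleLike:
    # Mimics torch.nn.Module.__getattr__: raise AttributeError when the
    # attribute is not found among parameters/buffers/submodules.
    def __getattr__(self, item):
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{item}'"
        )


class LossDemo(ModuleLike):
    @property
    def target_entropy(self):
        # Any AttributeError raised while computing the property -- e.g.
        # touching a buffer that was never registered -- makes Python
        # fall back to __getattr__("target_entropy"), which replaces the
        # informative message with a generic one.
        return self._target_entropy_buffer  # does not exist


try:
    LossDemo().target_entropy
except AttributeError as err:
    message = str(err)

print(message)  # 'LossDemo' object has no attribute 'target_entropy'
```

The original error about `_target_entropy_buffer` is swallowed: Python treats the `AttributeError` escaping the property getter as "attribute not found" and invokes `__getattr__` with the property's own name.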

Contributor Author

The `__getattr__` of `torch.nn.Module` isn't its greatest feature!
Let me think of a fix...
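One possible workaround (a sketch only, not necessarily the fix that landed in this PR): catch the `AttributeError` inside the property and re-raise it as a non-`AttributeError`, so Python never falls back to `__getattr__` and the informative message survives. `ModuleLike` and `LossWithFix` are hypothetical stand-ins:

```python
class ModuleLike:
    # Mimics torch.nn.Module.__getattr__'s behavior (hypothetical stand-in).
    def __getattr__(self, item):
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{item}'"
        )


class LossWithFix(ModuleLike):
    @property
    def target_entropy(self):
        try:
            return self._target_entropy_buffer  # may not exist yet
        except AttributeError as err:
            # Re-raising as RuntimeError prevents the __getattr__
            # fallback, so the user sees this message instead of a
            # generic "object has no attribute 'target_entropy'".
            raise RuntimeError(
                "Cannot infer the dimensionality of the action. Consider "
                "providing the target entropy explicitly or provide the "
                "spec of the action tensor in the actor network."
            ) from err


try:
    LossWithFix().target_entropy
except RuntimeError as err:
    fix_message = str(err)

print(fix_message)
```

Only `AttributeError` triggers the `__getattr__` fallback, so any other exception type propagates intact.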

Contributor

The same also extends to EnvBase and other components, so the scope of this might be outside this PR.

Contributor
@matteobettini left a comment

LGTM, thanks a lot.

@matteobettini (Contributor)

The example in #1594 still fails; to test that, you can either use that script or do a forward pass before saving the loss.

Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
@vmoens (Contributor Author) commented Oct 5, 2023

Good catch, I had forgotten an "_" in the delezify.

@matteobettini (Contributor)

Good to go for me; let's maybe add the example in #1594 to the tests.

@vmoens (Contributor Author) commented Oct 5, 2023

already done :)

@vmoens vmoens merged commit 6a3e9f8 into main Oct 5, 2023
34 of 47 checks passed
@vmoens vmoens deleted the refactor_target_entropy_sac branch October 5, 2023 14:57
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
…rch#1607)

Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Labels
bug: Something isn't working
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] SACLoss cannot be loaded from state_dict
3 participants