[BugFix] Patch SAC to allow state_dict manipulation before exec #1607
Conversation
qvalue_network=value,
action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
state = loss.state_dict()
Maybe let's add a forward call, or an access to the entropy, before saving.
Then we're sure it is instantiated, and the test loses its value, no?
What we want to test is the case where the buffer has been instantiated in the saved loss and must then be loaded into a new loss that has just been initialized (see the sketch below).
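A minimal sketch of that scenario, reusing the toy modules from the repro later in this thread (the make_loss helper and the module shapes are illustrative assumptions, not the actual test code):

import torch
from tensordict.nn import TensorDictModule
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.modules import ProbabilisticActor, TanhDelta, ValueOperator
from torchrl.objectives import SACLoss

def make_loss():
    actor_module = TensorDictModule(
        torch.nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
    )
    policy = ProbabilisticActor(
        module=actor_module,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=TanhDelta,
    )
    value = ValueOperator(
        module=torch.nn.Linear(3, 1), in_keys=["obs"], out_keys=["value"]
    )
    return SACLoss(
        actor_network=policy,
        qvalue_network=value,
        action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
    )

loss = make_loss()
_ = loss.target_entropy            # instantiate the lazy buffer in the saved loss
state = loss.state_dict()

fresh_loss = make_loss()           # freshly init, buffer not instantiated yet
fresh_loss.load_state_dict(state)  # must succeed: this is what the patch enables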
raise RuntimeError(
    "Cannot infer the dimensionality of the action. Consider providing "
    "the target entropy explicitly or providing the spec of the "
    "action tensor in the actor network."
There is an interesting tangential bug related to this. To see it, try:
import torch
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhDelta, ValueOperator
from torchrl.objectives import SACLoss

if __name__ == "__main__":
    model = torch.nn.Linear(1, 1)
    actor_module = TensorDictModule(
        torch.nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
    )
    policy = ProbabilisticActor(
        module=actor_module,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=TanhDelta,
    )
    value = ValueOperator(module=model, in_keys=["obs"], out_keys=["value"])
    loss = SACLoss(
        actor_network=policy,
        qvalue_network=value,
        # deliberately not passing the spec:
        # action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
    )
    _ = loss.target_entropy
Running this will trigger the error above, but the user will instead see:
File "/Users/matbet/PycharmProjects/rl/prova.py", line 29, in <module>
loss.target_entropy
File "/Users/matbet/PycharmProjects/rl/torchrl/objectives/common.py", line 348, in __getattr__
return super().__getattr__(item)
File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SACLoss' object has no attribute 'target_entropy'
So the message here is never shown. This is valid for all @property methods. Isn't this curious?
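A minimal sketch of the likely masking mechanism, simplified to a plain torch.nn.Module (torchrl's LossModule.__getattr__ delegates to it, per the traceback above); the Demo class and missing_spec attribute are hypothetical:

import torch

class Demo(torch.nn.Module):
    @property
    def target_entropy(self):
        # this inner lookup fails with AttributeError ...
        return self.missing_spec

demo = Demo()
demo.target_entropy
# ... but an AttributeError escaping a property getter makes Python fall
# back to __getattr__, so the user only sees the generic message:
# AttributeError: 'Demo' object has no attribute 'target_entropy'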
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The __getattr__ of torch.nn.Module isn't its greatest feature! Let me think of a fix...
The same also extends to EnvBase and other components, so the scope of this might be outside this PR.
LGTM, thanks a lot.
The example in #1594 still fails; to test that, you can either use that script or do a forward pass before saving the loss.
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Good catch, I had forgotten an "_" in the delezify.
Good to go for me; let's maybe add the example in #1594 to the tests.
Already done :)
Description
Closes #1594