[Algorithm] Update TD3 Example #1523
Conversation
LGTM! Some minor comments
test/test_cost.py
Outdated
for i in loss_val:
    assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}"
# for i, key in enumerate(loss_val_td.keys()):
#     torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
This is dangerous: keys in the tensordict are ordered by name, but the output tuple loss_val is not. So for now I'm just checking that every value in the loss_val tuple is also present in loss_val_td.
# actor metadata
metadata_actor = {
    "state_action_value_actor": state_action_value_actor.mean().detach(),
}
# value metadata
metadata_value = {
    "td_error": td_error,
    "next_state_value": next_target_qvalue.mean().detach(),
    "pred_value": current_qvalue.mean().detach(),
    "target_value": target_value.mean().detach(),
}
# out tensordict
td_out = TensorDict(
source={
"loss_actor": loss_actor,
"loss_qvalue": loss_qval,
**metadata_actor,
**metadata_value,
},
batch_size=[],
)
loss_vals will be returned in this order: (loss_actor, loss_qvalue, state_action_value_actor, next_state_value, pred_value, target_value).
However, since the items in the TD are ordered by key, the output tensordict actually has this order:
(loss_actor, loss_qvalue, next_state_value, pred_value, state_action_value_actor, target_value)
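The ordering mismatch can be illustrated without tensordict at all. A minimal sketch, using a plain dict as a stand-in for the output tensordict and sorting its keys to mimic the name ordering described above (key names taken from the snippet):

```python
# Stand-in sketch: a plain dict models the output tensordict; sorting
# its keys reproduces the name ordering the comment describes.
insertion_order = [
    "loss_actor",
    "loss_qvalue",
    "state_action_value_actor",  # actor metadata
    "next_state_value",          # value metadata
    "pred_value",
    "target_value",
]
td_like = dict.fromkeys(insertion_order)
sorted_keys = sorted(td_like)  # keys as the tensordict reports them
print(sorted_keys)
# ['loss_actor', 'loss_qvalue', 'next_state_value', 'pred_value',
#  'state_action_value_actor', 'target_value']
```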
dispatch returns the keys in the order of out_keys.
So this is predictable; we can just do:
for i, key in enumerate(loss.out_keys):
    torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
does that solve the problem?
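To make the suggestion concrete, here is a minimal, self-contained sketch of that assertion loop; the out_keys list and the values are hypothetical placeholders standing in for the real loss module's outputs:

```python
# Hedged sketch: pair each positional output with the key at the same
# index in out_keys, since dispatch returns values in out_keys order.
out_keys = ["loss_actor", "loss_qvalue", "target_value"]  # hypothetical
loss_val = (0.1, 0.2, 0.3)                                # tuple from dispatch
loss_val_td = {"loss_actor": 0.1, "loss_qvalue": 0.2, "target_value": 0.3}

for i, key in enumerate(out_keys):
    # compare the tensordict entry with the positional output
    assert loss_val_td[key] == loss_val[i], f"mismatch at {key}"
print("all outputs match their keys")
```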
# Conflicts:
#	examples/td3/utils.py
LGTM
let's wait for the tests to pass!
Description
Updated the TD3 example script in line with the PPO update: added time logging, more comments, a cleaner structure, and assorted fixes. Running some tests right now to verify performance.
What do you think, @vmoens, @albertbou92, @matteobettini? How could we improve the example further?
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!