Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Adds value clipping in ClipPPOLoss loss #2005

Merged
merged 44 commits into from
Mar 18, 2024
Merged
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
88b1ea3
clipping code
albertbou92 Mar 9, 2024
704f7da
fix test
albertbou92 Mar 9, 2024
fee26d6
fix test
albertbou92 Mar 9, 2024
0cd5af0
fix test
albertbou92 Mar 9, 2024
27e1822
extend logiv
albertbou92 Mar 10, 2024
3af7f0c
fix test
albertbou92 Mar 10, 2024
91577c9
register param
albertbou92 Mar 10, 2024
42071f3
register param
albertbou92 Mar 10, 2024
5e33992
minor fix
albertbou92 Mar 10, 2024
47ed022
added param for a2c and reinforce
albertbou92 Mar 11, 2024
c363a37
fix test
albertbou92 Mar 11, 2024
c0d03aa
fix test
albertbou92 Mar 11, 2024
102235e
fix test
albertbou92 Mar 11, 2024
8bf1ae2
fix test
albertbou92 Mar 11, 2024
025ceb4
fix test
albertbou92 Mar 11, 2024
3b8728b
fix test
albertbou92 Mar 11, 2024
fdcbf9f
fix test
albertbou92 Mar 11, 2024
5d75566
docstrings
albertbou92 Mar 11, 2024
4c8c783
docstrings
albertbou92 Mar 11, 2024
19e57b5
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
245e562
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
315c9e2
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
b79c99e
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
48f06ce
Update torchrl/objectives/ppo.py
albertbou92 Mar 12, 2024
c356670
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
fc3c3dd
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
7d04eca
Update torchrl/objectives/a2c.py
albertbou92 Mar 12, 2024
28c9cae
integrate feedback
albertbou92 Mar 12, 2024
9e273d1
fix test
albertbou92 Mar 12, 2024
c836bd8
fix test
albertbou92 Mar 12, 2024
1672c48
format
albertbou92 Mar 12, 2024
cc04ec4
return clip fractions
albertbou92 Mar 12, 2024
78ad4d8
fix test
albertbou92 Mar 12, 2024
ad1425c
update feedback
albertbou92 Mar 14, 2024
82e9efd
fix test
albertbou92 Mar 14, 2024
3723dc2
fix test
albertbou92 Mar 14, 2024
bd1c974
fix test
albertbou92 Mar 14, 2024
3e7b824
Update torchrl/objectives/ppo.py
albertbou92 Mar 14, 2024
7a5c9ee
Update torchrl/objectives/a2c.py
albertbou92 Mar 14, 2024
1374d80
Update torchrl/objectives/ppo.py
albertbou92 Mar 14, 2024
cb83230
Update torchrl/objectives/reinforce.py
albertbou92 Mar 14, 2024
021aa71
minor fixes
albertbou92 Mar 14, 2024
c123231
Merge remote-tracking branch 'origin/main' into clip_value_loss
vmoens Mar 18, 2024
c72347c
amend
vmoens Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix test
  • Loading branch information
albertbou92 committed Mar 10, 2024
commit 3af7f0c34243ca518335dcfd7fae14788982ca7b
17 changes: 9 additions & 8 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6912,14 +6912,15 @@ def test_ppo_value_clipping(self, clip_value_loss):

value = td.pop(loss_fn.tensor_keys.value)

with pytest.raises(
KeyError,
match="clip_value_loss is set to True, but the key "
"state_value was not found in the input tensordict. "
"Make sure that the value_key passed to PPO exists in "
"the input tensordict.",
):
loss = loss_fn(td)
if clip_value_loss:
with pytest.raises(
KeyError,
match="clip_value_loss is set to True, but the key "
"state_value was not found in the input tensordict. "
"Make sure that the value_key passed to PPO exists in "
"the input tensordict.",
):
loss = loss_fn(td)

# Add value to td
td.set(loss_fn.tensor_keys.value, value)
Expand Down