[Algorithm] Simpler IQL example #998

Merged 83 commits on Dec 14, 2023

Commits
d85e307
fix batch_size
BY571 Mar 28, 2023
bcf6d46
add offline iql example
BY571 Mar 28, 2023
eb7cee0
fix eval reward sum
BY571 Mar 28, 2023
aaf6e0d
merge main
BY571 Mar 30, 2023
77caf62
update iql online average return
BY571 Mar 30, 2023
3a24bca
Merge branch 'main' into rewrite_iql_example
BY571 Apr 20, 2023
f34efa9
update iql examples
BY571 Apr 20, 2023
ddb7f1a
update rewardscale
BY571 May 22, 2023
f7f4a0c
Merge branch 'main' into rewrite_iql_example
BY571 Jun 1, 2023
c084125
update config, script, clear utils
BY571 Jun 2, 2023
9b21360
fix memmap td
BY571 Jun 2, 2023
f3f68be
update eval
BY571 Jun 2, 2023
2af47dc
udpate logger
BY571 Jun 2, 2023
9880756
undo change
BY571 Jun 2, 2023
e476641
fix
BY571 Jun 2, 2023
22cc5df
update scripts
BY571 Jun 2, 2023
d4ca3a6
Merge branch 'rewrite_iql_example' of https://github.com/BY571/rl int…
BY571 Jun 2, 2023
2cc511f
update gym version
BY571 Jun 2, 2023
32f844f
merge main
BY571 Jun 12, 2023
26d8f4f
fix
BY571 Jun 12, 2023
155b4da
Merge branch 'main' into rewrite_iql_example
BY571 Jun 15, 2023
bf80dba
fix logging and adapt config
BY571 Jun 15, 2023
3eaa1e1
update cql and iql offline example
BY571 Jun 15, 2023
8c73156
add example script tests
BY571 Jun 15, 2023
4ce418f
Merge branch 'main' into rewrite_iql_example
BY571 Jun 16, 2023
a01f45f
merge main
BY571 Sep 15, 2023
5e8dc39
update namings andadd time
BY571 Sep 15, 2023
ae82555
fixes
BY571 Sep 15, 2023
bbc85da
update offline
BY571 Sep 15, 2023
6f461de
update cql
BY571 Sep 18, 2023
2ab87b8
fixes
BY571 Sep 18, 2023
7b1af77
update tests and config
BY571 Sep 21, 2023
874fcc4
update
BY571 Sep 21, 2023
4dae15e
update
BY571 Sep 21, 2023
caa39b7
update iql offline config
BY571 Sep 21, 2023
438ad1b
update set gym backend
BY571 Sep 21, 2023
d05fd91
Merge branch 'main' into rewrite_iql_example
BY571 Sep 26, 2023
686d307
update cql bc loss
BY571 Sep 26, 2023
5b63e0a
config fix
BY571 Oct 3, 2023
6ea2176
Merge branch 'main' into rewrite_iql_example
BY571 Oct 3, 2023
4cd605f
observation transform fix
BY571 Oct 3, 2023
ab0ca80
Merge branch 'main' into rewrite_iql_example
BY571 Oct 4, 2023
0fd374c
delete file
BY571 Oct 4, 2023
38d4220
Delete .circleci/config.yml
vmoens Oct 5, 2023
0ad0323
amend
vmoens Oct 5, 2023
ace65ac
amend
vmoens Oct 5, 2023
6601235
Merge remote-tracking branch 'origin/main' into rewrite_iql_example
vmoens Oct 5, 2023
444d05c
update cql separate loss
BY571 Nov 8, 2023
4d7909f
fix
BY571 Nov 8, 2023
0cbe069
Merge branch 'rewrite_iql_example' of https://github.com/BY571/rl int…
BY571 Nov 8, 2023
a8e4e64
update iql loss separation
BY571 Nov 8, 2023
0d70875
merge main and fixes
BY571 Nov 8, 2023
5d97fb4
fix backend
BY571 Nov 8, 2023
93c2b1c
fixes
BY571 Nov 8, 2023
6704d37
fix logger none
BY571 Nov 8, 2023
aeae390
Merge branch 'main' into rewrite_iql_example
BY571 Nov 9, 2023
90fb686
fix cql tests and loss
BY571 Nov 10, 2023
fe14afd
delay_qvalue fix
BY571 Nov 10, 2023
8ebad7a
fix priority setting
BY571 Nov 10, 2023
6736e56
fix naming discrete continuous for helper functions
BY571 Nov 10, 2023
85fc878
small fixes
BY571 Nov 10, 2023
7f27b0f
fix example run tests
BY571 Nov 10, 2023
237fe76
fix num_workers cfg
BY571 Nov 10, 2023
d806994
collector device fix
BY571 Nov 10, 2023
bc209ed
fix
BY571 Nov 10, 2023
c774a3d
fixes
BY571 Nov 10, 2023
b40bf10
device fixes tests
BY571 Nov 10, 2023
433be98
logger fixes tests
BY571 Nov 10, 2023
7fdaf04
td clone fix
BY571 Nov 26, 2023
11967e0
add cql bc loss comment
BY571 Nov 26, 2023
254f8d3
clamp cql lagrange fix
BY571 Nov 26, 2023
5089035
max clamp fix
BY571 Nov 26, 2023
03b865f
fixes
BY571 Nov 26, 2023
6d0c1f0
update metadataupdates
BY571 Nov 30, 2023
76eb7d5
Merge branch 'main' into rewrite_iql_example
BY571 Nov 30, 2023
e80fdcb
merge main
BY571 Dec 7, 2023
2651c3b
fix cql objective actor parameter to module
BY571 Dec 7, 2023
cc83496
fix cql objective actor parameter to module
BY571 Dec 7, 2023
d1be2c6
Merge remote-tracking branch 'origin/main' into rewrite_iql_example
vmoens Dec 14, 2023
ec38f7b
amend
vmoens Dec 14, 2023
826d094
amend
vmoens Dec 14, 2023
fdea50e
amend
vmoens Dec 14, 2023
a85baad
fix cql batch size
vmoens Dec 14, 2023
fix eval reward sum
BY571 committed Mar 28, 2023
commit eb7cee021bcac313c3b376117a9da4d19d432861
8 changes: 5 additions & 3 deletions examples/iql/iql_offline.py
@@ -75,16 +75,18 @@ def main(cfg: "DictConfig"):  # noqa: F821
         )

         if r0 is None:
-            r0 = eval_td["reward"].mean().item()
+            r0 = eval_td["reward"].sum(1).mean().item()
         if l0 is None:
             l0 = loss_val.item()

         for key, value in loss_vals.items():
             logger.log_scalar(key, value.item(), i)
-        logger.log_scalar("reward_evaluation", eval_td["reward"].mean().item(), i)
+        logger.log_scalar(
+            "reward_evaluation", eval_td["reward"].sum(1).mean().item(), i
+        )

         pbar.set_description(
-            f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), reward: {eval_td['reward'].mean(): 4.4f} (init={r0: 4.4f})"
+            f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), reward: {eval_td['reward'].sum(1).mean(): 4.4f} (init={r0: 4.4f})"
         )
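The change above replaces a per-step reward average with an average of summed rewards. The sketch below illustrates the difference in plain Python. It assumes the evaluation reward has shape [batch, time, 1], so that dimension 1 is the time axis; that layout, the helper names, and the reward values are illustrative assumptions and are not taken from the PR.

```python
# Plain-Python stand-in for a reward tensor of assumed shape [batch, time, 1].
# The values are made up for illustration.
rewards = [
    [[1.0], [1.0], [1.0], [1.0]],  # episode 0
    [[0.0], [2.0], [0.0], [2.0]],  # episode 1
]


def per_step_mean(batch):
    """Analogue of eval_td["reward"].mean(): average over every step."""
    flat = [step[0] for episode in batch for step in episode]
    return sum(flat) / len(flat)


def mean_episode_return(batch):
    """Analogue of eval_td["reward"].sum(1).mean(): sum over the time
    axis first, then average the per-episode totals."""
    returns = [sum(step[0] for step in episode) for episode in batch]
    return sum(returns) / len(returns)


print("per-step mean:      ", per_step_mean(rewards))
print("mean episode return:", mean_episode_return(rewards))
```

Under these assumptions the first value does not depend on episode length, while the second grows with it, which is the quantity usually reported as evaluation return.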
2 changes: 1 addition & 1 deletion examples/iql/utils.py
@@ -268,7 +268,7 @@ def make_offline_replay_buffer(rb_cfg, state_dict):
     )
     data.append_transform(
         ObservationNorm(
-            in_keys=["observation_vector", ("next", "observation_vector")],
+            in_keys=["observation_vector"],
             loc=state_dict["transforms.2.loc"],
             scale=state_dict["transforms.2.scale"],
             standard_normal=state_dict["transforms.2.standard_normal"],