-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Algorithm] Simpler IQL example #998
Merged
Merged
Changes from 1 commit
Commits
Show all changes
83 commits
Select commit
Hold shift + click to select a range
d85e307
fix batch_size
BY571 bcf6d46
add offline iql example
BY571 eb7cee0
fix eval reward sum
BY571 aaf6e0d
merge main
BY571 77caf62
update iql online average return
BY571 3a24bca
Merge branch 'main' into rewrite_iql_example
BY571 f34efa9
update iql examples
BY571 ddb7f1a
update rewardscale
BY571 f7f4a0c
Merge branch 'main' into rewrite_iql_example
BY571 c084125
update config, script, clear utils
BY571 9b21360
fix memmap td
BY571 f3f68be
update eval
BY571 2af47dc
udpate logger
BY571 9880756
undo change
BY571 e476641
fix
BY571 22cc5df
update scripts
BY571 d4ca3a6
Merge branch 'rewrite_iql_example' of https://github.com/BY571/rl int…
BY571 2cc511f
update gym version
BY571 32f844f
merge main
BY571 26d8f4f
fix
BY571 155b4da
Merge branch 'main' into rewrite_iql_example
BY571 bf80dba
fix logging and adapt config
BY571 3eaa1e1
update cql and iql offline example
BY571 8c73156
add example script tests
BY571 4ce418f
Merge branch 'main' into rewrite_iql_example
BY571 a01f45f
merge main
BY571 5e8dc39
update namings andadd time
BY571 ae82555
fixes
BY571 bbc85da
update offline
BY571 6f461de
update cql
BY571 2ab87b8
fixes
BY571 7b1af77
update tests and config
BY571 874fcc4
update
BY571 4dae15e
update
BY571 caa39b7
update iql offline config
BY571 438ad1b
update set gym backend
BY571 d05fd91
Merge branch 'main' into rewrite_iql_example
BY571 686d307
update cql bc loss
BY571 5b63e0a
config fix
BY571 6ea2176
Merge branch 'main' into rewrite_iql_example
BY571 4cd605f
observation transform fix
BY571 ab0ca80
Merge branch 'main' into rewrite_iql_example
BY571 0fd374c
delete file
BY571 38d4220
Delete .circleci/config.yml
vmoens 0ad0323
amend
vmoens ace65ac
amend
vmoens 6601235
Merge remote-tracking branch 'origin/main' into rewrite_iql_example
vmoens 444d05c
update cql separate loss
BY571 4d7909f
fix
BY571 0cbe069
Merge branch 'rewrite_iql_example' of https://github.com/BY571/rl int…
BY571 a8e4e64
update iql loss separation
BY571 0d70875
merge main and fixes
BY571 5d97fb4
fix backend
BY571 93c2b1c
fixes
BY571 6704d37
fix logger none
BY571 aeae390
Merge branch 'main' into rewrite_iql_example
BY571 90fb686
fix cql tests and loss
BY571 fe14afd
delay_qvalue fix
BY571 8ebad7a
fix priority setting
BY571 6736e56
fix naming discrete continuous for helper functions
BY571 85fc878
small fixes
BY571 7f27b0f
fix example run tests
BY571 237fe76
fix num_workers cfg
BY571 d806994
collector device fix
BY571 bc209ed
fix
BY571 c774a3d
fixes
BY571 b40bf10
device fixes tests
BY571 433be98
logger fixes tests
BY571 7fdaf04
td clone fix
BY571 11967e0
add cql bc loss comment
BY571 254f8d3
clamp cql lagrange fix
BY571 5089035
max clamp fix
BY571 03b865f
fixes
BY571 6d0c1f0
update metadataupdates
BY571 76eb7d5
Merge branch 'main' into rewrite_iql_example
BY571 e80fdcb
merge main
BY571 2651c3b
fix cql objective actor parameter to module
BY571 cc83496
fix cql objective actor parameter to module
BY571 d1be2c6
Merge remote-tracking branch 'origin/main' into rewrite_iql_example
vmoens ec38f7b
amend
vmoens 826d094
amend
vmoens fdea50e
amend
vmoens a85baad
fix cql batch size
vmoens File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
add offline iql example
- Loading branch information
commit bcf6d46ea553415d60f2e26f531590be1daa6a1a
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
"""IQL Example. | ||
|
||
This is a self-contained example of an offline IQL training script. | ||
|
||
The helper functions are coded in the utils.py associated with this script. | ||
|
||
""" | ||
|
||
import hydra | ||
import torch | ||
import tqdm | ||
from torchrl.envs.utils import set_exploration_mode | ||
|
||
from utils import ( | ||
get_stats, | ||
make_iql_model, | ||
make_iql_optimizer, | ||
make_logger, | ||
make_loss, | ||
make_offline_replay_buffer, | ||
make_parallel_env, | ||
) | ||
|
||
|
||
@hydra.main(config_path=".", config_name="offline_config") | ||
def main(cfg: "DictConfig"): # noqa: F821 | ||
|
||
model_device = cfg.optim.device | ||
|
||
state_dict = get_stats(cfg.env) | ||
evaluation_env = make_parallel_env(cfg.env, state_dict=state_dict) | ||
logger = make_logger(cfg.logger) | ||
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, state_dict) | ||
|
||
actor_network, qvalue_network, value_network = make_iql_model(cfg) | ||
policy = actor_network.to(model_device) | ||
qvalue_network = qvalue_network.to(model_device) | ||
value_network = value_network.to(model_device) | ||
|
||
loss, target_net_updater = make_loss( | ||
cfg.loss, policy, qvalue_network, value_network | ||
) | ||
optim = make_iql_optimizer(cfg.optim, policy, qvalue_network, value_network) | ||
|
||
pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) | ||
|
||
r0 = None | ||
l0 = None | ||
|
||
for i in range(cfg.optim.gradient_steps): | ||
pbar.update(i) | ||
data = replay_buffer.sample() | ||
# loss | ||
loss_vals = loss(data) | ||
# backprop | ||
actor_loss = loss_vals["loss_actor"] | ||
q_loss = loss_vals["loss_qvalue"] | ||
value_loss = loss_vals["loss_value"] | ||
loss_val = actor_loss + q_loss + value_loss | ||
|
||
optim.zero_grad() | ||
loss_val.backward() | ||
optim.step() | ||
target_net_updater.step() | ||
|
||
# evaluation | ||
if i % cfg.env.evaluation_interval == 0: | ||
with set_exploration_mode("random"), torch.no_grad(): | ||
eval_td = evaluation_env.rollout( | ||
max_steps=1000, policy=policy, auto_cast_to_device=True | ||
) | ||
|
||
if r0 is None: | ||
r0 = eval_td["reward"].mean().item() | ||
if l0 is None: | ||
l0 = loss_val.item() | ||
|
||
for key, value in loss_vals.items(): | ||
logger.log_scalar(key, value.item(), i) | ||
logger.log_scalar("reward_evaluation", eval_td["reward"].mean().item(), i) | ||
|
||
pbar.set_description( | ||
f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), reward: {eval_td['reward'].mean(): 4.4f} (init={r0: 4.4f})" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Task and env | ||
env: | ||
env_name: Hopper-v3 | ||
env_task: "" | ||
env_library: gym | ||
record_video: 0 | ||
n_samples_stats: 1000 | ||
frame_skip: 1 | ||
from_pixels: False | ||
num_eval_envs: 1 | ||
reward_scaling: | ||
noop: 1 | ||
seed: 0 | ||
evaluation_interval: 1000 | ||
|
||
# Eval | ||
recorder: | ||
video: False | ||
interval: 10000 # record interval in frames | ||
frames: 10000 | ||
|
||
# logger | ||
logger: | ||
backend: wandb | ||
exp_name: iql_hopper-medium-v2 | ||
|
||
# Buffer | ||
replay_buffer: | ||
dataset: hopper-medium-v2 | ||
batch_size: 256 | ||
|
||
# Optimization | ||
optim: | ||
device: cpu | ||
lr: 3e-4 | ||
weight_decay: 0.0 | ||
batch_size: 256 | ||
lr_scheduler: "" | ||
gradient_steps: 1000000 | ||
|
||
|
||
# Policy and model | ||
model: | ||
activation: relu | ||
default_policy_scale: 1.0 | ||
scale_lb: 0.1 | ||
|
||
# loss | ||
loss: | ||
loss_function: smooth_l1 | ||
gamma: 0.99 | ||
tau: 0.05 | ||
# IQL hyperparameter | ||
temperature: 3.0 | ||
expectile: 0.7 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in our examples, we could default to the version that does not require the d4rl library, wdyt?
It's pretty annoying to install and the dataset works without it.