Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] Simpler IQL example #998

Merged
merged 83 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
d85e307
fix batch_size
BY571 Mar 28, 2023
bcf6d46
add offline iql example
BY571 Mar 28, 2023
eb7cee0
fix eval reward sum
BY571 Mar 28, 2023
aaf6e0d
merge main
BY571 Mar 30, 2023
77caf62
update iql online average return
BY571 Mar 30, 2023
3a24bca
Merge branch 'main' into rewrite_iql_example
BY571 Apr 20, 2023
f34efa9
update iql examples
BY571 Apr 20, 2023
ddb7f1a
update rewardscale
BY571 May 22, 2023
f7f4a0c
Merge branch 'main' into rewrite_iql_example
BY571 Jun 1, 2023
c084125
update config, script, clear utils
BY571 Jun 2, 2023
9b21360
fix memmap td
BY571 Jun 2, 2023
f3f68be
update eval
BY571 Jun 2, 2023
2af47dc
update logger
BY571 Jun 2, 2023
9880756
undo change
BY571 Jun 2, 2023
e476641
fix
BY571 Jun 2, 2023
22cc5df
update scripts
BY571 Jun 2, 2023
d4ca3a6
Merge branch 'rewrite_iql_example' of https://github.com/BY571/rl int…
BY571 Jun 2, 2023
2cc511f
update gym version
BY571 Jun 2, 2023
32f844f
merge main
BY571 Jun 12, 2023
26d8f4f
fix
BY571 Jun 12, 2023
155b4da
Merge branch 'main' into rewrite_iql_example
BY571 Jun 15, 2023
bf80dba
fix logging and adapt config
BY571 Jun 15, 2023
3eaa1e1
update cql and iql offline example
BY571 Jun 15, 2023
8c73156
add example script tests
BY571 Jun 15, 2023
4ce418f
Merge branch 'main' into rewrite_iql_example
BY571 Jun 16, 2023
a01f45f
merge main
BY571 Sep 15, 2023
5e8dc39
update namings and add time
BY571 Sep 15, 2023
ae82555
fixes
BY571 Sep 15, 2023
bbc85da
update offline
BY571 Sep 15, 2023
6f461de
update cql
BY571 Sep 18, 2023
2ab87b8
fixes
BY571 Sep 18, 2023
7b1af77
update tests and config
BY571 Sep 21, 2023
874fcc4
update
BY571 Sep 21, 2023
4dae15e
update
BY571 Sep 21, 2023
caa39b7
update iql offline config
BY571 Sep 21, 2023
438ad1b
update set gym backend
BY571 Sep 21, 2023
d05fd91
Merge branch 'main' into rewrite_iql_example
BY571 Sep 26, 2023
686d307
update cql bc loss
BY571 Sep 26, 2023
5b63e0a
config fix
BY571 Oct 3, 2023
6ea2176
Merge branch 'main' into rewrite_iql_example
BY571 Oct 3, 2023
4cd605f
observation transform fix
BY571 Oct 3, 2023
ab0ca80
Merge branch 'main' into rewrite_iql_example
BY571 Oct 4, 2023
0fd374c
delete file
BY571 Oct 4, 2023
38d4220
Delete .circleci/config.yml
vmoens Oct 5, 2023
0ad0323
amend
vmoens Oct 5, 2023
ace65ac
amend
vmoens Oct 5, 2023
6601235
Merge remote-tracking branch 'origin/main' into rewrite_iql_example
vmoens Oct 5, 2023
444d05c
update cql separate loss
BY571 Nov 8, 2023
4d7909f
fix
BY571 Nov 8, 2023
0cbe069
Merge branch 'rewrite_iql_example' of https://github.com/BY571/rl int…
BY571 Nov 8, 2023
a8e4e64
update iql loss separation
BY571 Nov 8, 2023
0d70875
merge main and fixes
BY571 Nov 8, 2023
5d97fb4
fix backend
BY571 Nov 8, 2023
93c2b1c
fixes
BY571 Nov 8, 2023
6704d37
fix logger none
BY571 Nov 8, 2023
aeae390
Merge branch 'main' into rewrite_iql_example
BY571 Nov 9, 2023
90fb686
fix cql tests and loss
BY571 Nov 10, 2023
fe14afd
delay_qvalue fix
BY571 Nov 10, 2023
8ebad7a
fix priority setting
BY571 Nov 10, 2023
6736e56
fix naming discrete continuous for helper functions
BY571 Nov 10, 2023
85fc878
small fixes
BY571 Nov 10, 2023
7f27b0f
fix example run tests
BY571 Nov 10, 2023
237fe76
fix num_workers cfg
BY571 Nov 10, 2023
d806994
collector device fix
BY571 Nov 10, 2023
bc209ed
fix
BY571 Nov 10, 2023
c774a3d
fixes
BY571 Nov 10, 2023
b40bf10
device fixes tests
BY571 Nov 10, 2023
433be98
logger fixes tests
BY571 Nov 10, 2023
7fdaf04
td clone fix
BY571 Nov 26, 2023
11967e0
add cql bc loss comment
BY571 Nov 26, 2023
254f8d3
clamp cql lagrange fix
BY571 Nov 26, 2023
5089035
max clamp fix
BY571 Nov 26, 2023
03b865f
fixes
BY571 Nov 26, 2023
6d0c1f0
update metadataupdates
BY571 Nov 30, 2023
76eb7d5
Merge branch 'main' into rewrite_iql_example
BY571 Nov 30, 2023
e80fdcb
merge main
BY571 Dec 7, 2023
2651c3b
fix cql objective actor parameter to module
BY571 Dec 7, 2023
cc83496
fix cql objective actor parameter to module
BY571 Dec 7, 2023
d1be2c6
Merge remote-tracking branch 'origin/main' into rewrite_iql_example
vmoens Dec 14, 2023
ec38f7b
amend
vmoens Dec 14, 2023
826d094
amend
vmoens Dec 14, 2023
fdea50e
amend
vmoens Dec 14, 2023
a85baad
fix cql batch size
vmoens Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add offline iql example
  • Loading branch information
BY571 committed Mar 28, 2023
commit bcf6d46ea553415d60f2e26f531590be1daa6a1a
92 changes: 92 additions & 0 deletions examples/iql/iql_offline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IQL Example.

This is a self-contained example of an offline IQL training script.

The helper functions are coded in the utils.py associated with this script.

"""

import hydra
import torch
import tqdm
from torchrl.envs.utils import set_exploration_mode

from utils import (
get_stats,
make_iql_model,
make_iql_optimizer,
make_logger,
make_loss,
make_offline_replay_buffer,
make_parallel_env,
)


@hydra.main(config_path=".", config_name="offline_config")
def main(cfg: "DictConfig"):  # noqa: F821
    """Run offline IQL training with periodic evaluation rollouts.

    The function builds the evaluation environment, logger, offline replay
    buffer, IQL networks, loss module and optimizer from ``cfg`` using the
    helpers imported from ``utils``, then performs ``cfg.optim.gradient_steps``
    optimisation steps on batches sampled from the replay buffer.

    Args:
        cfg: Hydra configuration. The sections read here are ``env``,
            ``logger``, ``replay_buffer``, ``optim`` and ``loss``; the whole
            object is also forwarded to ``make_iql_model``.
    """
    model_device = cfg.optim.device

    # The same statistics are handed to both the evaluation env and the
    # offline buffer so they can apply matching transforms.
    state_dict = get_stats(cfg.env)
    evaluation_env = make_parallel_env(cfg.env, state_dict=state_dict)
    logger = make_logger(cfg.logger)
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, state_dict)

    actor_network, qvalue_network, value_network = make_iql_model(cfg)
    policy = actor_network.to(model_device)
    qvalue_network = qvalue_network.to(model_device)
    value_network = value_network.to(model_device)

    loss, target_net_updater = make_loss(
        cfg.loss, policy, qvalue_network, value_network
    )
    optim = make_iql_optimizer(cfg.optim, policy, qvalue_network, value_network)

    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

    # First observed evaluation reward / training loss, shown in the progress
    # bar as a baseline next to the current values.
    r0 = None
    l0 = None

    for i in range(cfg.optim.gradient_steps):
        # tqdm.update() takes an *increment*, not an absolute position:
        # passing the loop index would advance the bar by 0 + 1 + 2 + ...
        pbar.update(1)
        # NOTE(review): the sampled batch is not moved to `model_device`
        # here; presumably fine for the default cpu device -- confirm
        # before training on another device.
        data = replay_buffer.sample()
        # loss
        loss_vals = loss(data)
        # backprop
        actor_loss = loss_vals["loss_actor"]
        q_loss = loss_vals["loss_qvalue"]
        value_loss = loss_vals["loss_value"]
        loss_val = actor_loss + q_loss + value_loss

        optim.zero_grad()
        loss_val.backward()
        optim.step()
        target_net_updater.step()

        # evaluation
        # The condition holds at i == 0, so `eval_td` is always defined
        # before it is read below; between evaluations the most recent
        # rollout is reused for logging.
        if i % cfg.env.evaluation_interval == 0:
            # NOTE(review): evaluation runs with exploration mode "random";
            # a deterministic mode may have been intended -- confirm.
            with set_exploration_mode("random"), torch.no_grad():
                eval_td = evaluation_env.rollout(
                    max_steps=1000, policy=policy, auto_cast_to_device=True
                )

        if r0 is None:
            r0 = eval_td["reward"].mean().item()
        if l0 is None:
            l0 = loss_val.item()

        for key, value in loss_vals.items():
            logger.log_scalar(key, value.item(), i)
        logger.log_scalar("reward_evaluation", eval_td["reward"].mean().item(), i)

        pbar.set_description(
            f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), reward: {eval_td['reward'].mean(): 4.4f} (init={r0: 4.4f})"
        )


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion examples/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
"""IQL Example.

This is a self-contained example of a IQL training script.
This is a self-contained example of an online IQL training script.

It works across Gym and DM-control over a variety of tasks.

Expand Down
55 changes: 55 additions & 0 deletions examples/iql/offline_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Task and env
env:
  env_name: Hopper-v3
  env_task: ""
  env_library: gym
  record_video: 0
  n_samples_stats: 1000
  frame_skip: 1
  from_pixels: False
  # Number of environments in the evaluation ParallelEnv
  # (read by make_parallel_env in utils.py).
  num_eval_envs: 1
  reward_scaling:
  noop: 1
  seed: 0
  # iql_offline.py runs an evaluation rollout whenever the gradient-step
  # index is a multiple of this value.
  evaluation_interval: 1000

# Eval
recorder:
  video: False
  interval: 10000 # record interval in frames
  frames: 10000

# logger
logger:
  backend: wandb
  exp_name: iql_hopper-medium-v2

# Buffer
# Both values are forwarded to D4RLExperienceReplay by
# make_offline_replay_buffer in utils.py.
replay_buffer:
  dataset: hopper-medium-v2
  batch_size: 256

# Optimization
optim:
  # Device the actor, Q-value and value networks are moved to.
  device: cpu
  lr: 3e-4
  weight_decay: 0.0
  # NOTE(review): the offline script samples with replay_buffer.batch_size;
  # this duplicate is presumably kept for parity with the online config -- confirm.
  batch_size: 256
  lr_scheduler: ""
  # Number of optimisation steps performed by iql_offline.py.
  gradient_steps: 1000000


# Policy and model
model:
  activation: relu
  default_policy_scale: 1.0
  scale_lb: 0.1

# loss
loss:
  loss_function: smooth_l1
  gamma: 0.99
  tau: 0.05
  # IQL hyperparameter
  temperature: 3.0
  expectile: 0.7
1 change: 0 additions & 1 deletion examples/iql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ optim:
batch_size: 256
lr_scheduler: ""
optim_steps_per_batch: 1000
policy_update_delay: 2

# Policy and model
model:
Expand Down
44 changes: 43 additions & 1 deletion examples/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
MultiStep,
TensorDictReplayBuffer,
)
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.envs import (
CatFrames,
Expand All @@ -22,6 +24,7 @@
NoopResetEnv,
ObservationNorm,
ParallelEnv,
RenameTransform,
Resize,
RewardScaling,
ToTensorImage,
Expand Down Expand Up @@ -169,7 +172,7 @@ def make_transformed_env_states(base_env, env_cfg):


def make_parallel_env(env_cfg, state_dict):
num_envs = env_cfg.num_envs
num_envs = env_cfg.num_eval_envs
env = make_transformed_env(
ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg
)
Expand Down Expand Up @@ -242,6 +245,45 @@ def make_replay_buffer(rb_cfg):
)


def make_offline_replay_buffer(rb_cfg, state_dict):
    """Create the D4RL-backed replay buffer used for offline training.

    Args:
        rb_cfg: Replay-buffer config; ``dataset`` and ``batch_size`` are read.
        state_dict: Mapping holding the ``loc`` / ``scale`` /
            ``standard_normal`` entries under the ``transforms.0.*`` (reward
            scaling) and ``transforms.2.*`` (observation normalisation) keys.

    Returns:
        The ``D4RLExperienceReplay`` instance with four transforms appended,
        in order: reward scaling, renaming of the observation keys to
        ``observation_vector``, observation normalisation, and
        double-to-float casting of the observation entries.
    """
    obs_keys = ("observation_vector", ("next", "observation_vector"))

    # NOTE(review): a reviewer suggested that the examples default to the
    # dataset-loading path that does not require the d4rl library, since d4rl
    # is annoying to install and the dataset works without it.
    buffer = D4RLExperienceReplay(
        rb_cfg.dataset,
        split_trajs=False,
        batch_size=rb_cfg.batch_size,
        sampler=SamplerWithoutReplacement(drop_last=False),
    )

    # Order matters: the rename must come before the transforms that address
    # the "observation_vector" keys.
    pipeline = (
        RewardScaling(
            loc=state_dict["transforms.0.loc"],
            scale=state_dict["transforms.0.scale"],
            standard_normal=state_dict["transforms.0.standard_normal"],
        ),
        RenameTransform(
            ["observation", ("next", "observation")],
            list(obs_keys),
        ),
        ObservationNorm(
            in_keys=list(obs_keys),
            loc=state_dict["transforms.2.loc"],
            scale=state_dict["transforms.2.scale"],
            standard_normal=state_dict["transforms.2.standard_normal"],
        ),
        DoubleToFloat(
            in_keys=list(obs_keys),
            in_keys_inv=[],
        ),
    )
    for transform in pipeline:
        buffer.append_transform(transform)

    return buffer


# ====================================================================
# Model
# -----
Expand Down