
[Feature] Implicit Q-Learning (IQL) #933

Merged: 19 commits, merged on Mar 14, 2023
Changes from 1 commit
add wandb logging mode
BY571 committed Mar 1, 2023
commit b42f3c46ebbdad4928006b796972f43d775077da
1 change: 1 addition & 0 deletions examples/td3/config.yaml
@@ -50,6 +50,7 @@ gSDE: 0

 # Logging
 logger: wandb
+mode: online

 # Extra
 batch_transform: 1
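For reference, the new `mode` key is handed straight to wandb, which accepts "online", "offline", and "disabled". A minimal sketch of what each value means when a run is initialised (the project name and metric below are illustrative, not from this commit):

```python
# Minimal sketch of wandb's logging modes (illustrative project and metric).
# "online" streams metrics to the wandb server, "offline" stores the run
# locally for a later `wandb sync`, and "disabled" makes logging a no-op.
import wandb

run = wandb.init(project="td3_logging", mode="offline")
run.log({"train_reward": 0.0})
run.finish()
```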
11 changes: 4 additions & 7 deletions examples/td3/td3.py
@@ -68,7 +68,6 @@ def make_replay_buffer(
 ):
     if prb:
         replay_buffer = TensorDictPrioritizedReplayBuffer(
-            buffer_size,
             alpha=0.7,
             beta=0.5,
             pin_memory=False,
@@ -81,7 +80,6 @@
         )
     else:
         replay_buffer = TensorDictReplayBuffer(
-            buffer_size,
             pin_memory=False,
             prefetch=make_replay_buffer,
             storage=LazyMemmapStorage(
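The positional `buffer_size` argument is dropped in both branches, presumably because the buffer's capacity is defined by its storage rather than by the buffer itself. A minimal sketch of the resulting construction, with illustrative sizes and a made-up scratch directory (the prioritized variant changes in the same way):

```python
# Sketch only: capacity is carried by LazyMemmapStorage, not by the buffer.
# Sizes, prefetch value and scratch_dir are illustrative, not from the commit.
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

buffer_size = 1_000_000
replay_buffer = TensorDictReplayBuffer(
    pin_memory=False,
    prefetch=3,  # number of batches prefetched in the background
    storage=LazyMemmapStorage(
        buffer_size,  # maximum number of transitions kept in the memmap storage
        scratch_dir="/tmp/td3_buffer",
        device="cpu",
    ),
)
```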
@@ -106,7 +104,10 @@ def main(cfg: "DictConfig"): # noqa: F821

     exp_name = generate_exp_name("TD3", cfg.exp_name)
     logger = get_logger(
-        logger_type=cfg.logger, logger_name="td3_logging", experiment_name=exp_name
+        logger_type=cfg.logger,
+        logger_name="td3_logging",
+        experiment_name=exp_name,
+        wandb_kwargs={"mode": cfg.mode},
     )

     torch.manual_seed(cfg.seed)
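This hunk threads the wandb mode from the config through `get_logger`. A hedged sketch of the same plumbing outside the Hydra config (the import path and literal values are assumptions; only the `wandb_kwargs` keyword is taken from the diff above):

```python
# Sketch of forwarding a wandb mode to the TorchRL logger helper.
# The import path may differ across torchrl versions; values are illustrative.
from torchrl.record.loggers import generate_exp_name, get_logger

exp_name = generate_exp_name("TD3", "demo")  # "demo" stands in for cfg.exp_name
logger = get_logger(
    logger_type="wandb",
    logger_name="td3_logging",
    experiment_name=exp_name,
    wandb_kwargs={"mode": "offline"},  # "online", "offline" or "disabled"
)
logger.log_scalar("train_reward", 0.0, step=0)
```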
@@ -150,7 +151,6 @@ def main(cfg: "DictConfig"): # noqa: F821
     dist_kwargs = {
         "min": action_spec.space.minimum,
         "max": action_spec.space.maximum,
-        "tanh_loc": False,
     }

     in_keys_actor = in_keys
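The remaining `min`/`max` entries are read off the environment's bounded action spec. A small self-contained sketch, with a hypothetical six-dimensional spec standing in for the real environment's (the spec constructor is an assumption about the torchrl API of that time):

```python
# Hypothetical action spec standing in for the environment's; the dist_kwargs
# construction mirrors the diff above.
import torch
from torchrl.data import BoundedTensorSpec

action_spec = BoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=torch.Size([6]))
dist_kwargs = {
    "min": action_spec.space.minimum,  # per-dimension lower bound of the actions
    "max": action_spec.space.maximum,  # per-dimension upper bound of the actions
}
```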
@@ -258,7 +258,6 @@ def main(cfg: "DictConfig"): # noqa: F821
     target_net_updater.init_()

     collected_frames = 0
-    episodes = 0
     pbar = tqdm.tqdm(total=cfg.total_frames)
     r0 = None
     q_loss = None
@@ -282,7 +281,6 @@ def main(cfg: "DictConfig"): # noqa: F821
         current_frames = tensordict.numel()
         replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames
-        episodes += torch.unique(tensordict["traj_ids"]).shape[0]

         # optimization steps
         if collected_frames >= cfg.init_random_frames:
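The removed counter tallied episodes by counting the distinct trajectory ids in each collected batch and accumulating the result across batches. A standalone sketch of that idea, with a fabricated tensordict in place of the collector output:

```python
# Sketch of the removed episode counter: accumulate the number of distinct
# trajectory ids seen in each collected batch. The tensordict is fabricated.
import torch
from tensordict import TensorDict

episodes = 0
batch = TensorDict({"traj_ids": torch.tensor([0, 0, 1, 1, 2])}, batch_size=[5])
episodes += torch.unique(batch["traj_ids"]).shape[0]  # 3 trajectories in this batch
```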
@@ -323,7 +321,6 @@ def main(cfg: "DictConfig"): # noqa: F821
         train_log = {
             "train_reward": rewards[-1][1],
             "collected_frames": collected_frames,
-            "episodes": episodes,
         }
         if q_loss is not None:
             train_log.update(