Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Logger #1858

Merged
merged 49 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
012f9c0
init
vmoens Jan 19, 2024
3485d6d
Merge remote-tracking branch 'origin/main' into remove-deprecs
vmoens Jan 31, 2024
b1344e6
Merge remote-tracking branch 'origin/main' into remove-deprecs
vmoens Jan 31, 2024
4839c09
amend
vmoens Jan 31, 2024
95ac71d
amend
vmoens Jan 31, 2024
067f4da
Merge remote-tracking branch 'origin/main' into remove-deprecs
vmoens Jan 31, 2024
e31e41b
amend
vmoens Jan 31, 2024
d7d2621
amend
vmoens Jan 31, 2024
f5b195d
amend
vmoens Jan 31, 2024
96d3a18
amend
vmoens Jan 31, 2024
4b41b5c
amend
vmoens Jan 31, 2024
755b7f4
amend
vmoens Jan 31, 2024
4b8c89b
amend
vmoens Jan 31, 2024
2fc17c2
amend
vmoens Jan 31, 2024
ba8dada
amend
vmoens Jan 31, 2024
3120f22
amend
vmoens Jan 31, 2024
ff27094
amend
vmoens Jan 31, 2024
6bdb2c4
amend
vmoens Jan 31, 2024
5dbc588
init
vmoens Jan 31, 2024
f30e02a
amend
vmoens Jan 31, 2024
b1c69b1
amend
vmoens Jan 31, 2024
ab07abe
amend
vmoens Jan 31, 2024
c7e8278
amend
vmoens Jan 31, 2024
f984105
amend
vmoens Jan 31, 2024
deb8b2e
amend
vmoens Jan 31, 2024
d0efa38
Merge remote-tracking branch 'origin/remove-deprecs' into logger
vmoens Jan 31, 2024
62b1dc8
amend
vmoens Jan 31, 2024
bd498ab
amend
vmoens Jan 31, 2024
b35c26a
Merge branch 'remove-deprecs' into logger
vmoens Jan 31, 2024
bf4a0d9
amend
vmoens Jan 31, 2024
1903d10
Merge branch 'remove-deprecs' into logger
vmoens Jan 31, 2024
e4bdde2
amend
vmoens Jan 31, 2024
ba63298
Merge remote-tracking branch 'origin/main' into logger
vmoens Jan 31, 2024
9c36712
amend
vmoens Jan 31, 2024
fdc4557
amend
vmoens Jan 31, 2024
c8d6441
amend
vmoens Jan 31, 2024
f32ce83
amend
vmoens Jan 31, 2024
f68dd4f
amend
vmoens Jan 31, 2024
2c2c9fb
amend
vmoens Jan 31, 2024
31b866a
amend
vmoens Jan 31, 2024
656e75b
amend
vmoens Jan 31, 2024
22cd51b
empty
vmoens Jan 31, 2024
23171f7
amend
vmoens Jan 31, 2024
707747e
amend
vmoens Jan 31, 2024
f44fe53
amend
vmoens Jan 31, 2024
2c737d6
amend
vmoens Jan 31, 2024
2797de8
amend
vmoens Jan 31, 2024
03c201c
amend
vmoens Jan 31, 2024
4b746f6
amend
vmoens Jan 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Jan 31, 2024
commit 62b1dc8434da524855c9736c015d091d4c79c3a7
16 changes: 8 additions & 8 deletions examples/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main(cfg: "DictConfig"): # noqa: F821
total_frames = cfg.collector.total_frames // frame_skip
frames_per_batch = cfg.collector.frames_per_batch // frame_skip
mini_batch_size = cfg.loss.mini_batch_size // frame_skip
test_interval = cfg.torchrl_logger.test_interval // frame_skip
test_interval = cfg.logger.test_interval // frame_skip

# Create models (check utils_atari.py)
actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
Expand Down Expand Up @@ -87,20 +87,20 @@ def main(cfg: "DictConfig"): # noqa: F821
eps=cfg.optim.eps,
)

# Create torchrl_logger
# Create logger
logger = None
if cfg.torchrl_logger.backend:
if cfg.logger.backend:
exp_name = generate_exp_name(
"A2C", f"{cfg.torchrl_logger.exp_name}_{cfg.env.env_name}"
"A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
)
logger = get_logger(
cfg.torchrl_logger.backend,
cfg.logger.backend,
logger_name="a2c",
experiment_name=exp_name,
wandb_kwargs={
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"group": cfg.torchrl_logger.group_name,
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

Expand Down Expand Up @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.torchrl_logger.num_test_episodes
actor, test_env, num_episodes=cfg.logger.num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
Expand Down
18 changes: 9 additions & 9 deletions examples/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,20 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr)

# Create torchrl_logger
# Create logger
logger = None
if cfg.torchrl_logger.backend:
if cfg.logger.backend:
exp_name = generate_exp_name(
"A2C", f"{cfg.torchrl_logger.exp_name}_{cfg.env.env_name}"
"A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
)
logger = get_logger(
cfg.torchrl_logger.backend,
cfg.logger.backend,
logger_name="a2c",
experiment_name=exp_name,
wandb_kwargs={
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"group": cfg.torchrl_logger.group_name,
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

Expand Down Expand Up @@ -180,13 +180,13 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if ((i - 1) * frames_in_batch) // cfg.torchrl_logger.test_interval < (
if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < (
i * frames_in_batch
) // cfg.torchrl_logger.test_interval:
) // cfg.logger.test_interval:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.torchrl_logger.num_test_episodes
actor, test_env, num_episodes=cfg.logger.num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
Expand Down
2 changes: 1 addition & 1 deletion examples/a2c/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ collector:
frames_per_batch: 80
total_frames: 40_000_000

# torchrl_logger
# logger
logger:
backend: wandb
project_name: torchrl_example_a2c
Expand Down
2 changes: 1 addition & 1 deletion examples/a2c/config_mujoco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ collector:
frames_per_batch: 64
total_frames: 1_000_000

# torchrl_logger
# logger
logger:
backend: wandb
project_name: torchrl_example_a2c
Expand Down
20 changes: 10 additions & 10 deletions examples/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@

@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
# Create torchrl_logger
exp_name = generate_exp_name("CQL-offline", cfg.torchrl_logger.exp_name)
# Create logger
exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
logger = None
if cfg.torchrl_logger.backend:
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.torchrl_logger.backend,
logger_type=cfg.logger.backend,
logger_name="cql_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.torchrl_logger.mode,
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"group": cfg.torchrl_logger.group_name,
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)
# Set seeds
Expand All @@ -52,7 +52,7 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.optim.device)

# Create env
train_env, eval_env = make_environment(cfg, cfg.torchrl_logger.eval_envs)
train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
Expand All @@ -75,8 +75,8 @@ def main(cfg: "DictConfig"): # noqa: F821

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.torchrl_logger.eval_iter
eval_steps = cfg.torchrl_logger.eval_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
Expand Down
16 changes: 8 additions & 8 deletions examples/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@

@hydra.main(version_base="1.1", config_path=".", config_name="online_config")
def main(cfg: "DictConfig"): # noqa: F821
# Create torchrl_logger
exp_name = generate_exp_name("CQL-online", cfg.torchrl_logger.exp_name)
# Create logger
exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
logger = None
if cfg.torchrl_logger.backend:
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.torchrl_logger.backend,
logger_type=cfg.logger.backend,
logger_name="cql_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.torchrl_logger.mode,
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"group": cfg.torchrl_logger.group_name,
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

Expand Down Expand Up @@ -99,7 +99,7 @@ def main(cfg: "DictConfig"): # noqa: F821
* cfg.optim.utd_ratio
)
prb = cfg.replay_buffer.prb
eval_iter = cfg.torchrl_logger.eval_iter
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.collector.max_frames_per_traj

Expand Down
14 changes: 7 additions & 7 deletions examples/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.optim.device)

# Create torchrl_logger
exp_name = generate_exp_name("DiscreteCQL", cfg.torchrl_logger.exp_name)
# Create logger
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
logger = None
if cfg.torchrl_logger.backend:
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.torchrl_logger.backend,
logger_type=cfg.logger.backend,
logger_name="discretecql_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.torchrl_logger.mode,
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"project": cfg.logger.project_name,
},
)

Expand Down Expand Up @@ -92,7 +92,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.torchrl_logger.eval_iter
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch

start_time = sampling_start = time.time()
Expand Down
2 changes: 1 addition & 1 deletion examples/cql/offline_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ env:
seed: 0
backend: gymnasium

# torchrl_logger
# logger
logger:
backend: wandb
project_name: torchrl_example_cql
Expand Down
2 changes: 1 addition & 1 deletion examples/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ collector:
max_frames_per_traj: 1000


# torchrl_logger
# logger
logger:
backend: wandb
project_name: torchrl_example_cql
Expand Down
16 changes: 8 additions & 8 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create torchrl_logger
exp_name = generate_exp_name("DDPG", cfg.torchrl_logger.exp_name)
# Create logger
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
logger = None
if cfg.torchrl_logger.backend:
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.torchrl_logger.backend,
logger_type=cfg.logger.backend,
logger_name="ddpg_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.torchrl_logger.mode,
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"group": cfg.torchrl_logger.group_name,
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

Expand Down Expand Up @@ -94,7 +94,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
prb = cfg.replay_buffer.prb
frames_per_batch = cfg.collector.frames_per_batch
eval_iter = cfg.torchrl_logger.eval_iter
eval_iter = cfg.logger.eval_iter
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
Expand Down
4 changes: 2 additions & 2 deletions examples/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def main(cfg: "DictConfig"): # noqa: F821

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
clip_grad = cfg.optim.clip_grad
eval_steps = cfg.torchrl_logger.eval_steps
pretrain_log_interval = cfg.torchrl_logger.pretrain_log_interval
eval_steps = cfg.logger.eval_steps
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

torchrl_logger.info(" ***Pretraining*** ")
Expand Down
4 changes: 2 additions & 2 deletions examples/decision_transformer/online_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def main(cfg: "DictConfig"): # noqa: F821

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
clip_grad = cfg.optim.clip_grad
eval_steps = cfg.torchrl_logger.eval_steps
pretrain_log_interval = cfg.torchrl_logger.pretrain_log_interval
eval_steps = cfg.logger.eval_steps
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

torchrl_logger.info(" ***Pretraining*** ")
Expand Down
12 changes: 6 additions & 6 deletions examples/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,19 +493,19 @@ def make_dt_optimizer(optim_cfg, loss_module):


def make_logger(cfg):
if not cfg.torchrl_logger.backend:
if not cfg.logger.backend:
return None
exp_name = generate_exp_name(
cfg.torchrl_logger.model_name, cfg.torchrl_logger.exp_name
cfg.logger.model_name, cfg.logger.exp_name
)
logger = get_logger(
cfg.torchrl_logger.backend,
logger_name=cfg.torchrl_logger.model_name,
cfg.logger.backend,
logger_name=cfg.logger.model_name,
experiment_name=exp_name,
wandb_kwargs={
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"group": cfg.torchrl_logger.group_name,
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)
return logger
Expand Down
14 changes: 7 additions & 7 deletions examples/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create torchrl_logger
exp_name = generate_exp_name("DiscreteSAC", cfg.torchrl_logger.exp_name)
exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name)
logger = None
if cfg.torchrl_logger.backend:
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.torchrl_logger.backend,
logger_type=cfg.logger.backend,
logger_name="DiscreteSAC_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.torchrl_logger.mode,
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.torchrl_logger.project_name,
"group": cfg.torchrl_logger.group_name,
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

Expand Down Expand Up @@ -96,7 +96,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.torchrl_logger.eval_iter
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch

sampling_start = time.time()
Expand Down
2 changes: 1 addition & 1 deletion examples/dqn/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ buffer:
batch_size: 32
scratch_dir: null

# torchrl_logger
# logger
logger:
backend: wandb
project_name: torchrl_example_dqn
Expand Down
2 changes: 1 addition & 1 deletion examples/dqn/config_cartpole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ buffer:
buffer_size: 10_000
batch_size: 128

# torchrl_logger
# logger
logger:
backend: wandb
project_name: torchrl_example_dqn
Expand Down
Loading
Loading