Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Rename specs to simpler names #2368

Merged
merged 23 commits on Aug 7, 2024
Merged
Previous commit
Next commit
amend
  • Loading branch information
vmoens committed Aug 6, 2024
commit 496938b27ab298968a3f5d3e3c0032737571b231
3 changes: 3 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
3 changes: 3 additions & 0 deletions sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
4 changes: 4 additions & 0 deletions sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
4 changes: 4 additions & 0 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)

pbar.close()
if not test_env.is_closed:
test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/decision_transformer/online_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)

pbar.close()
if not test_env.is_closed:
test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


Expand Down
4 changes: 4 additions & 0 deletions sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
3 changes: 3 additions & 0 deletions sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()

end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/dqn/dqn_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
4 changes: 4 additions & 0 deletions sota-implementations/iql/iql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)

pbar.close()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


Expand Down
6 changes: 6 additions & 0 deletions sota-implementations/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time

if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()

torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


Expand Down
6 changes: 6 additions & 0 deletions sota-implementations/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ def train(cfg: "DictConfig"): # noqa: F821
logger.experiment.log({}, commit=True)
sampling_start = time.time()

collector.shutdown()
if not env.is_closed:
env.close()
if not env_test.is_closed:
env_test.close()


if __name__ == "__main__":
train()
5 changes: 5 additions & 0 deletions sota-implementations/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
collector.shutdown()
if not env.is_closed:
env.close()
if not env_test.is_closed:
env_test.close()


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions sota-implementations/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
collector.shutdown()
if not env.is_closed:
env.close()
if not env_test.is_closed:
env_test.close()


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions sota-implementations/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
collector.shutdown()
if not env.is_closed:
env.close()
if not env_test.is_closed:
env_test.close()


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions sota-implementations/multiagent/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
collector.shutdown()
if not env.is_closed:
env.close()
if not env_test.is_closed:
env_test.close()


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()

end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
3 changes: 3 additions & 0 deletions sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
4 changes: 4 additions & 0 deletions sota-implementations/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
4 changes: 4 additions & 0 deletions sota-implementations/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/td3_bc/td3_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def main(cfg: "DictConfig"): # noqa: F821
if logger is not None:
log_metrics(logger, to_log, i)

if not eval_env.is_closed:
eval_env.close()
pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")

Expand Down
Loading