diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index e75f4b1bc1c..1d11d481e3c 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -32,7 +32,7 @@ python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_te # ================================ gym 0.23 ========================================== # # With batched environments -python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/decision_transformer/dt.py \ optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ @@ -40,18 +40,18 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans logger.backend= \ env.backend=gymnasium \ env.name=HalfCheetah-v4 -python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/decision_transformer/online_dt.py \ optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ optim.device=cuda:0 \ env.backend=gymnasium \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_offline.py \ optim.gradient_steps=55 \ optim.device=cuda:0 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offline.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ optim.device=cuda:0 \ logger.backend= @@ -59,13 +59,13 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offlin # ==================================================================================== # # ================================ Gymnasium ========================================= # -python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \ collector.total_frames=80 \ collector.frames_per_batch=20 \ collector.num_workers=1 \ logger.backend= \ logger.test_interval=10 -python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ppo/ppo_mujoco.py \ env.env_name=HalfCheetah-v4 \ collector.total_frames=40 \ collector.frames_per_batch=20 \ @@ -73,14 +73,14 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco loss.ppo_epochs=2 \ logger.backend= \ logger.test_interval=10 -python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ppo/ppo_atari.py \ collector.total_frames=80 \ collector.frames_per_batch=20 \ loss.mini_batch_size=20 \ loss.ppo_epochs=2 \ logger.backend= \ logger.test_interval=10 -python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ optim.batch_size=10 \ @@ -94,20 +94,20 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ logger.backend= # record_video=True \ # record_frames=4 \ -python .github/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c_mujoco.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/a2c/a2c_mujoco.py \ env.env_name=HalfCheetah-v4 \ collector.total_frames=40 \ collector.frames_per_batch=20 \ loss.mini_batch_size=10 \ logger.backend= \ logger.test_interval=40 -python .github/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c_atari.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/a2c/a2c_atari.py \ collector.total_frames=80 \ collector.frames_per_batch=20 \ loss.mini_batch_size=20 \ logger.backend= \ logger.test_interval=40 -python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dqn/dqn_atari.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ @@ -116,7 +116,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari. loss.num_updates=1 \ logger.backend= \ buffer.buffer_size=120 -python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/discrete_cql_online.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ optim.batch_size=10 \ @@ -125,7 +125,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_c collector.device=cuda:0 \ replay_buffer.size=120 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \ num_workers=4 \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -138,7 +138,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ logger.record_frames=4 \ buffer.size=120 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/sac/sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ @@ -150,7 +150,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ env.name=Pendulum-v1 \ network.device=cuda:0 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/discrete_sac/discrete_sac.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/discrete_sac/discrete_sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ @@ -166,7 +166,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/discrete_sac/d logger.backend= # logger.record_video=True \ # logger.record_frames=4 \ -python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ total_frames=200 \ init_random_frames=10 \ batch_size=10 \ @@ -180,7 +180,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame record_frames=4 \ buffer_size=120 \ rssm_hidden_dim=17 -python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ optim.batch_size=10 \ @@ -193,7 +193,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ logger.mode=offline \ env.name=Pendulum-v1 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_online.py \ collector.total_frames=48 \ optim.batch_size=10 \ collector.frames_per_batch=16 \ @@ -202,7 +202,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online collector.device=cuda:0 \ logger.mode=offline \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_iql.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/discrete_iql.py \ collector.total_frames=48 \ optim.batch_size=10 \ collector.frames_per_batch=16 \ @@ -211,7 +211,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_i collector.device=cuda:0 \ logger.mode=offline \ logger.backend= - python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \ + python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \ collector.total_frames=48 \ optim.batch_size=10 \ collector.frames_per_batch=16 \ @@ -222,7 +222,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_i logger.backend= # With single envs -python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ total_frames=200 \ init_random_frames=10 \ batch_size=10 \ @@ -236,7 +236,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame record_frames=4 \ buffer_size=120 \ rssm_hidden_dim=17 -python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ optim.batch_size=10 \ @@ -250,7 +250,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ logger.backend= # record_video=True \ # record_frames=4 \ -python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dqn/dqn_atari.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ @@ -259,7 +259,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari. loss.num_updates=1 \ logger.backend= \ buffer.buffer_size=120 -python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \ num_workers=2 \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -272,7 +272,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ logger.record_frames=4 \ buffer.size=120 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_online.py \ collector.total_frames=48 \ optim.batch_size=10 \ collector.frames_per_batch=16 \ @@ -281,7 +281,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online optim.device=cuda:0 \ collector.device=cuda:0 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \ collector.total_frames=48 \ optim.batch_size=10 \ collector.frames_per_batch=16 \ @@ -290,7 +290,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online optim.device=cuda:0 \ collector.device=cuda:0 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ @@ -302,38 +302,38 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ env.name=Pendulum-v1 \ network.device=cuda:0 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/mappo_ippo.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/mappo_ippo.py \ collector.n_iters=2 \ collector.frames_per_batch=200 \ train.num_epochs=3 \ train.minibatch_size=100 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/maddpg_iddpg.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/maddpg_iddpg.py \ collector.n_iters=2 \ collector.frames_per_batch=200 \ train.num_epochs=3 \ train.minibatch_size=100 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/iql.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/iql.py \ collector.n_iters=2 \ collector.frames_per_batch=200 \ train.num_epochs=3 \ train.minibatch_size=100 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/qmix_vdn.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/qmix_vdn.py \ collector.n_iters=2 \ collector.frames_per_batch=200 \ train.num_epochs=3 \ train.minibatch_size=100 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac.py \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/sac.py \ collector.n_iters=2 \ collector.frames_per_batch=200 \ train.num_epochs=3 \ train.minibatch_size=100 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100 +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/bandits/dqn.py --n_steps=100 ## RLHF # RLHF tests are executed in the dedicated workflow diff --git a/README.md b/README.md index 6adbc2decfe..53ce0127bec 100644 --- a/README.md +++ b/README.md @@ -315,7 +315,7 @@ And it is `functorch` and `torch.compile` compatible! ``` - Check our [distributed collector examples](examples/distributed/collectors) to + Check our [distributed collector examples](sota-implementations/distributed/collectors) to learn more about ultra-fast data collection with TorchRL. - efficient(2) and generic(1) [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage: @@ -496,22 +496,22 @@ If you would like to contribute to new features, check our [call for contributio ## Examples, tutorials and demos A series of [examples](examples/) are provided with an illustrative purpose: -- [DQN](examples/dqn) -- [DDPG](examples/ddpg/ddpg.py) -- [IQL](examples/iql/iql.py) -- [CQL](examples/iql/cql.py) -- [TD3](examples/td3/td3.py) +- [DQN](sota-implementations/dqn) +- [DDPG](sota-implementations/ddpg/ddpg.py) +- [IQL](sota-implementations/iql/iql.py) +- [CQL](sota-implementations/iql/cql.py) +- [TD3](sota-implementations/td3/td3.py) - [A2C](examples/a2c_old/a2c.py) -- [PPO](examples/ppo/ppo.py) -- [SAC](examples/sac/sac.py) -- [REDQ](examples/redq/redq.py) -- [Dreamer](examples/dreamer/dreamer.py) -- [Decision Transformers](examples/decision_transformer) +- [PPO](sota-implementations/ppo/ppo.py) +- [SAC](sota-implementations/sac/sac.py) +- [REDQ](sota-implementations/redq/redq.py) +- [Dreamer](sota-implementations/dreamer/dreamer.py) +- [Decision Transformers](sota-implementations/decision_transformer) - [RLHF](examples/rlhf) and many more to come! -Check the [examples markdown](examples/EXAMPLES.md) directory for more details +Check the [examples markdown](sota-implementations/SOTA-IMPLEMENTATIONS.md) directory for more details about handling the various configuration settings. We also provide [tutorials and demos](https://pytorch.org/rl/#tutorials) that give a sense of diff --git a/examples/EXAMPLES.md b/examples/EXAMPLES.md deleted file mode 100644 index f875829b6e6..00000000000 --- a/examples/EXAMPLES.md +++ /dev/null @@ -1,108 +0,0 @@ -# Examples - -We provide examples to train the following algorithms: -- [DQN](dqn/dqn.py) -- [DDPG](ddpg/ddpg.py) -- [SAC](sac/sac.py) -- [REDQ](redq/redq.py) -- [PPO](ppo/ppo.py) - -To run these examples, make sure you have installed hydra: -``` -pip install hydra-core -``` - -Scripts can be run from the directory of interest using: -``` -python sac.py -``` -or similar. Hyperparameters can be easily changed by providing the arguments to hydra: -``` -python sac.py collector.frames_per_batch=63 -``` -# Results - -Here we can see some results for the SAC and REDQ algorithm. -We average the results over 5 different seeds and plot the standard error. -## Gym's HalfCheetah-v4 - -

- -

-To reproduce a single run: - -``` -python sac/sac.py env.name="HalfCheetah-v4" env.task="" env.library="gym" -``` - -``` -python redq/redq.py env.name="HalfCheetah-v4" env.library="gymnasium" -``` - - -## dm_control's cheetah-run - -

- -

-To reproduce a single run: - -``` -python sac/sac.py env.name="cheetah" env.task="run" env.library="dm_control" -``` - -``` -python redq/redq.py env.name="cheetah" env.task="run" env.library="dm_control" -``` - -[//]: # (TODO: adapt these scripts) -[//]: # (## Gym's Ant-v4) - -[//]: # () -[//]: # (

) - -[//]: # () - -[//]: # (

) - -[//]: # (To reproduce a single run:) - -[//]: # () -[//]: # (```) - -[//]: # (python sac/sac.py env.name="Ant-v4" env.task="" env.library="gym") - -[//]: # (```) - -[//]: # () -[//]: # (``` ) - -[//]: # (python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym") - -[//]: # (```) - -[//]: # () -[//]: # (## Gym's Walker2D-v4) - -[//]: # () -[//]: # (

) - -[//]: # () - -[//]: # (

) - -[//]: # (To reproduce a single run:) - -[//]: # () -[//]: # (```) - -[//]: # (python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym") - -[//]: # (```) - -[//]: # () -[//]: # (``` ) - -[//]: # (python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym") - -[//]: # (```) diff --git a/sota-check/run_a2c_atari.sh b/sota-check/run_a2c_atari.sh index 610cf5389f8..225ba717ed2 100644 --- a/sota-check/run_a2c_atari.sh +++ b/sota-check/run_a2c_atari.sh @@ -12,7 +12,7 @@ project_name="torchrl-example-check-$current_commit" group_name="a2c_atari" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/a2c/a2c_atari.py \ +python $PYTHONPATH/sota-implementations/a2c/a2c_atari.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_a2c_mujoco.sh b/sota-check/run_a2c_mujoco.sh index f26bc96fe01..ce316c9cfb1 100644 --- a/sota-check/run_a2c_mujoco.sh +++ b/sota-check/run_a2c_mujoco.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="a2c_mujoco" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/a2c/a2c_mujoco.py \ +python $PYTHONPATH/sota-implementations/a2c/a2c_mujoco.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_cql_offline.sh b/sota-check/run_cql_offline.sh index fa3a42c7429..7dc36a91192 100644 --- a/sota-check/run_cql_offline.sh +++ b/sota-check/run_cql_offline.sh @@ -12,7 +12,7 @@ project_name="torchrl-example-check-$current_commit" group_name="cql_offline" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/cql/cql_offline.py \ +python $PYTHONPATH/sota-implementations/cql/cql_offline.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_cql_online.sh b/sota-check/run_cql_online.sh index 78548d9f418..f9bd31e3760 100644 --- a/sota-check/run_cql_online.sh +++ b/sota-check/run_cql_online.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="cql_online" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/cql/cql_online.py \ +python $PYTHONPATH/sota-implementations/cql/cql_online.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_ddpg.sh b/sota-check/run_ddpg.sh index 7131db8a6e7..0d7b1b873e1 100644 --- a/sota-check/run_ddpg.sh +++ b/sota-check/run_ddpg.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="ddpg" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/ddpg/ddpg.py \ +python $PYTHONPATH/sota-implementations/ddpg/ddpg.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_discrete_sac.sh b/sota-check/run_discrete_sac.sh index dfb6a68ce02..7defd381109 100644 --- a/sota-check/run_discrete_sac.sh +++ b/sota-check/run_discrete_sac.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="discrete_sac" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/discrete_sac/discrete_sac.py \ +python $PYTHONPATH/sota-implementations/discrete_sac/discrete_sac.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_dqn_atari.sh b/sota-check/run_dqn_atari.sh index 35aa2adb3be..eb43968e133 100644 --- a/sota-check/run_dqn_atari.sh +++ b/sota-check/run_dqn_atari.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="dqn_atari" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/dqn/dqn_atari.py \ +python $PYTHONPATH/sota-implementations/dqn/dqn_atari.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_dqn_cartpole.sh b/sota-check/run_dqn_cartpole.sh index cfe954a4f09..ece33cd297e 100644 --- a/sota-check/run_dqn_cartpole.sh +++ b/sota-check/run_dqn_cartpole.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="dqn_cartpole" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/dqn/dqn_cartpole.py \ +python $PYTHONPATH/sota-implementations/dqn/dqn_cartpole.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_dt.sh b/sota-check/run_dt.sh index 41ec685664d..37159bbb27d 100644 --- a/sota-check/run_dt.sh +++ b/sota-check/run_dt.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="dt_offline" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/decision_transformer/dt.py \ +python $PYTHONPATH/sota-implementations/decision_transformer/dt.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_dt_online.sh b/sota-check/run_dt_online.sh index 2f116aa3bcf..fa9947707ce 100644 --- a/sota-check/run_dt_online.sh +++ b/sota-check/run_dt_online.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="dt_online" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/decision_transformer/online_dt.py \ +python $PYTHONPATH/sota-implementations/decision_transformer/online_dt.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_impala_single_node.sh b/sota-check/run_impala_single_node.sh index 3dc3cd56ac2..23bcbc83fcf 100644 --- a/sota-check/run_impala_single_node.sh +++ b/sota-check/run_impala_single_node.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="impala_1node" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/impala/impala_single_node.py \ +python $PYTHONPATH/sota-implementations/impala/impala_single_node.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_iql_discrete.sh b/sota-check/run_iql_discrete.sh index b659ed6dc31..9ad8de024d4 100644 --- a/sota-check/run_iql_discrete.sh +++ b/sota-check/run_iql_discrete.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="iql_discrete" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/iql/discrete_iql.py \ +python $PYTHONPATH/sota-implementations/iql/discrete_iql.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_iql_offline.sh b/sota-check/run_iql_offline.sh index bd4ef8f6e69..dc50babd13d 100644 --- a/sota-check/run_iql_offline.sh +++ b/sota-check/run_iql_offline.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="iql_offline" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/iql/iql_offline.py \ +python $PYTHONPATH/sota-implementations/iql/iql_offline.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_iql_online.sh b/sota-check/run_iql_online.sh index 702d2b8cbff..cdaf0a989a7 100644 --- a/sota-check/run_iql_online.sh +++ b/sota-check/run_iql_online.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="iql_online" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/iql/iql_online.py \ +python $PYTHONPATH/sota-implementations/iql/iql_online.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_multiagent_iddpg.sh b/sota-check/run_multiagent_iddpg.sh index 4629fbff228..c30bcdf4068 100644 --- a/sota-check/run_multiagent_iddpg.sh +++ b/sota-check/run_multiagent_iddpg.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="marl_iddpg" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/multiagent/maddpg_iddpg.py \ +python $PYTHONPATH/sota-implementations/multiagent/maddpg_iddpg.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_multiagent_ippo.sh b/sota-check/run_multiagent_ippo.sh index 036f739e2e2..8e0dd880f56 100644 --- a/sota-check/run_multiagent_ippo.sh +++ b/sota-check/run_multiagent_ippo.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="mappo_ippo" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/multiagent/mappo_ippo.py \ +python $PYTHONPATH/sota-implementations/multiagent/mappo_ippo.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_multiagent_iql.sh b/sota-check/run_multiagent_iql.sh index f5bb6a7af23..dc5a05df8dd 100644 --- a/sota-check/run_multiagent_iql.sh +++ b/sota-check/run_multiagent_iql.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="marl_iql" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/multiagent/iql.py \ +python $PYTHONPATH/sota-implementations/multiagent/iql.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_multiagent_qmix.sh b/sota-check/run_multiagent_qmix.sh index 08b32ce257a..61bb2e3d281 100644 --- a/sota-check/run_multiagent_qmix.sh +++ b/sota-check/run_multiagent_qmix.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="marl_qmix_vdn" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/multiagent/qmix_vdn.py \ +python $PYTHONPATH/sota-implementations/multiagent/qmix_vdn.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_multiagent_sac.sh b/sota-check/run_multiagent_sac.sh index 10e1bbb2d4d..77e3ec6ca84 100644 --- a/sota-check/run_multiagent_sac.sh +++ b/sota-check/run_multiagent_sac.sh @@ -12,7 +12,7 @@ project_name="torchrl-example-check-$current_commit" group_name="marl_sac" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/multiagent/sac.py \ +python $PYTHONPATH/sota-implementations/multiagent/sac.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_ppo_atari.sh b/sota-check/run_ppo_atari.sh index 764727acb7e..bc9699290ba 100644 --- a/sota-check/run_ppo_atari.sh +++ b/sota-check/run_ppo_atari.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="ppo_atari" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/ppo/ppo_atari.py \ +python $PYTHONPATH/sota-implementations/ppo/ppo_atari.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_ppo_mujoco.sh b/sota-check/run_ppo_mujoco.sh index 0e35974ffcc..e9fbe34e650 100644 --- a/sota-check/run_ppo_mujoco.sh +++ b/sota-check/run_ppo_mujoco.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="ppo_mujoco" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/ppo/ppo_mujoco.py \ +python $PYTHONPATH/sota-implementations/ppo/ppo_mujoco.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_sac.sh b/sota-check/run_sac.sh index 8c7b8ffa5ab..a2811a4e1a4 100644 --- a/sota-check/run_sac.sh +++ b/sota-check/run_sac.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="sac" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/sac/sac.py \ +python $PYTHONPATH/sota-implementations/sac/sac.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-check/run_td3.sh b/sota-check/run_td3.sh index 314ba68b4ac..e13cdb4fbcf 100644 --- a/sota-check/run_td3.sh +++ b/sota-check/run_td3.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="td3" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/examples/td3/td3.py \ +python $PYTHONPATH/sota-implementations/td3/td3.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" diff --git a/sota-implementations/SOTA-IMPLEMENTATIONS.md b/sota-implementations/SOTA-IMPLEMENTATIONS.md new file mode 100644 index 00000000000..1bdcc50ed76 --- /dev/null +++ b/sota-implementations/SOTA-IMPLEMENTATIONS.md @@ -0,0 +1,143 @@ +# Examples + +We provide examples to train the following algorithms: +- [CQL](../sota-implementations/cql/) +- [DDPG](ddpg/ddpg.py) +- [DQN](../sota-implementations/dqn/) +- [Decision Transformers](../sota-implementations/decision_transformer) +- [Decision Transformers](../sota-implementations/decision_transformer) +- [Discrete SAC](discrete_sac/discrete_sac.py) +- [Dreamer](../sota-implementations/dreamer) +- [IQL](iql/) +- [Impala](impala/) +- [PPO](../sota-implementations/ppo/ppo.py) +- [REDQ](redq/redq.py) +- [SAC](sac/sac.py) +- [TD3](../sota-implementations/td3/td3.py) +- [Various multiagent examples](multiagent/) + +To run these examples, make sure you have installed hydra: +``` +pip install hydra-core +``` + +Scripts can be run from the directory of interest using: +``` +python sac.py +``` +or similar. Hyperparameters can be easily changed by providing the arguments to hydra: +``` +python sac.py collector.frames_per_batch=63 +``` + +[//]: # (# Results) + +[//]: # () +[//]: # (Here we can see some results for the SAC and REDQ algorithm.) + +[//]: # (We average the results over 5 different seeds and plot the standard error.) + +[//]: # (## Gym's HalfCheetah-v4) + +[//]: # () +[//]: # (

) + +[//]: # () + +[//]: # (

) + +[//]: # (To reproduce a single run:) + +[//]: # () +[//]: # (```) + +[//]: # (python sac/sac.py env.name="HalfCheetah-v4" env.task="" env.library="gym") + +[//]: # (```) + +[//]: # () +[//]: # (``` ) + +[//]: # (python redq/redq.py env.name="HalfCheetah-v4" env.library="gymnasium") + +[//]: # (```) + +[//]: # () +[//]: # () +[//]: # (## dm_control's cheetah-run) + +[//]: # () +[//]: # (

) + +[//]: # () + +[//]: # (

) + +[//]: # (To reproduce a single run:) + +[//]: # () +[//]: # (```) + +[//]: # (python sac/sac.py env.name="cheetah" env.task="run" env.library="dm_control") + +[//]: # (```) + +[//]: # () +[//]: # (``` ) + +[//]: # (python redq/redq.py env.name="cheetah" env.task="run" env.library="dm_control") + +[//]: # (```) + +[//]: # () +[//]: # ([//]: # (TODO: adapt these scripts)) +[//]: # ([//]: # (## Gym's Ant-v4)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # ([//]: # (

)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # () +[//]: # ([//]: # (

)) +[//]: # () +[//]: # ([//]: # (To reproduce a single run:)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # ([//]: # (```)) +[//]: # () +[//]: # ([//]: # (python sac/sac.py env.name="Ant-v4" env.task="" env.library="gym")) +[//]: # () +[//]: # ([//]: # (```)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # ([//]: # (``` )) +[//]: # () +[//]: # ([//]: # (python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym")) +[//]: # () +[//]: # ([//]: # (```)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # ([//]: # (## Gym's Walker2D-v4)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # ([//]: # (

)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # () +[//]: # ([//]: # (

)) +[//]: # () +[//]: # ([//]: # (To reproduce a single run:)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # ([//]: # (```)) +[//]: # () +[//]: # ([//]: # (python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym")) +[//]: # () +[//]: # ([//]: # (```)) +[//]: # () +[//]: # ([//]: # ()) +[//]: # ([//]: # (``` )) +[//]: # () +[//]: # ([//]: # (python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym")) +[//]: # () +[//]: # ([//]: # (```)) diff --git a/examples/a2c/README.md b/sota-implementations/a2c/README.md similarity index 100% rename from examples/a2c/README.md rename to sota-implementations/a2c/README.md diff --git a/examples/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py similarity index 98% rename from examples/a2c/a2c_atari.py rename to sota-implementations/a2c/a2c_atari.py index d6e78ad1575..7ad39ed43e5 100644 --- a/examples/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -6,7 +6,7 @@ from torchrl._utils import logger as torchrl_logger -@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +@hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time diff --git a/examples/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py similarity index 98% rename from examples/a2c/a2c_mujoco.py rename to sota-implementations/a2c/a2c_mujoco.py index 6a95814fe4e..7b4a153e150 100644 --- a/examples/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -6,7 +6,7 @@ from torchrl._utils import logger as torchrl_logger -@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") +@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time diff --git a/examples/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml similarity index 100% rename from examples/a2c/config_atari.yaml rename to sota-implementations/a2c/config_atari.yaml diff --git a/examples/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml similarity index 100% rename from examples/a2c/config_mujoco.yaml rename to sota-implementations/a2c/config_mujoco.yaml diff --git a/examples/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py similarity index 100% rename from examples/a2c/utils_atari.py rename to sota-implementations/a2c/utils_atari.py diff --git a/examples/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py similarity index 100% rename from examples/a2c/utils_mujoco.py rename to sota-implementations/a2c/utils_mujoco.py diff --git a/examples/bandits/README.md b/sota-implementations/bandits/README.md similarity index 100% rename from examples/bandits/README.md rename to sota-implementations/bandits/README.md diff --git a/examples/bandits/dqn.py b/sota-implementations/bandits/dqn.py similarity index 100% rename from examples/bandits/dqn.py rename to sota-implementations/bandits/dqn.py diff --git a/examples/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py similarity index 98% rename from examples/cql/cql_offline.py rename to sota-implementations/cql/cql_offline.py index e0f59a5f406..99b391b9db8 100644 --- a/examples/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -29,7 +29,7 @@ ) -@hydra.main(config_path=".", config_name="offline_config", version_base="1.1") +@hydra.main(config_path="", config_name="offline_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 # Create logger exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name) diff --git a/examples/cql/cql_online.py b/sota-implementations/cql/cql_online.py similarity index 98% rename from examples/cql/cql_online.py rename to sota-implementations/cql/cql_online.py index fec979e2289..a70f9091cb6 100644 --- a/examples/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -33,7 +33,7 @@ ) -@hydra.main(version_base="1.1", config_path=".", config_name="online_config") +@hydra.main(version_base="1.1", config_path="", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 # Create logger exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name) diff --git a/examples/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml similarity index 100% rename from examples/cql/discrete_cql_config.yaml rename to sota-implementations/cql/discrete_cql_config.yaml diff --git a/examples/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py similarity index 98% rename from examples/cql/discrete_cql_online.py rename to sota-implementations/cql/discrete_cql_online.py index cb15919e252..fd07684774d 100644 --- a/examples/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -33,7 +33,7 @@ ) -@hydra.main(version_base="1.1", config_path=".", config_name="discrete_cql_config") +@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.optim.device) diff --git a/examples/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml similarity index 100% rename from examples/cql/offline_config.yaml rename to sota-implementations/cql/offline_config.yaml diff --git a/examples/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml similarity index 100% rename from examples/cql/online_config.yaml rename to sota-implementations/cql/online_config.yaml diff --git a/examples/cql/utils.py b/sota-implementations/cql/utils.py similarity index 100% rename from examples/cql/utils.py rename to sota-implementations/cql/utils.py diff --git a/examples/ddpg/config.yaml b/sota-implementations/ddpg/config.yaml similarity index 100% rename from examples/ddpg/config.yaml rename to sota-implementations/ddpg/config.yaml diff --git a/examples/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py similarity index 98% rename from examples/ddpg/ddpg.py rename to sota-implementations/ddpg/ddpg.py index e10507cc7f3..e8313e6c342 100644 --- a/examples/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -33,7 +33,7 @@ ) -@hydra.main(version_base="1.1", config_path=".", config_name="config") +@hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) diff --git a/examples/ddpg/utils.py b/sota-implementations/ddpg/utils.py similarity index 100% rename from examples/ddpg/utils.py rename to sota-implementations/ddpg/utils.py diff --git a/examples/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py similarity index 97% rename from examples/decision_transformer/dt.py rename to sota-implementations/decision_transformer/dt.py index 9dc4c855f30..a79c0037205 100644 --- a/examples/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -29,7 +29,7 @@ ) -@hydra.main(config_path=".", config_name="dt_config", version_base="1.1") +@hydra.main(config_path="", config_name="dt_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() diff --git a/examples/decision_transformer/dt_config.yaml b/sota-implementations/decision_transformer/dt_config.yaml similarity index 100% rename from examples/decision_transformer/dt_config.yaml rename to sota-implementations/decision_transformer/dt_config.yaml diff --git a/examples/decision_transformer/lamb.py b/sota-implementations/decision_transformer/lamb.py similarity index 100% rename from examples/decision_transformer/lamb.py rename to sota-implementations/decision_transformer/lamb.py diff --git a/examples/decision_transformer/odt_config.yaml b/sota-implementations/decision_transformer/odt_config.yaml similarity index 100% rename from examples/decision_transformer/odt_config.yaml rename to sota-implementations/decision_transformer/odt_config.yaml diff --git a/examples/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py similarity index 98% rename from examples/decision_transformer/online_dt.py rename to sota-implementations/decision_transformer/online_dt.py index 0ea70c73093..427b5d8eaa3 100644 --- a/examples/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -29,7 +29,7 @@ ) -@hydra.main(config_path=".", config_name="odt_config", version_base="1.1") +@hydra.main(config_path="", config_name="odt_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() diff --git a/examples/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py similarity index 100% rename from examples/decision_transformer/utils.py rename to sota-implementations/decision_transformer/utils.py diff --git a/examples/discrete_sac/config.yaml b/sota-implementations/discrete_sac/config.yaml similarity index 100% rename from examples/discrete_sac/config.yaml rename to sota-implementations/discrete_sac/config.yaml diff --git a/examples/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py similarity index 99% rename from examples/discrete_sac/discrete_sac.py rename to sota-implementations/discrete_sac/discrete_sac.py index 6bc4ad91d1a..40d9a1743c2 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -33,7 +33,7 @@ ) -@hydra.main(version_base="1.1", config_path=".", config_name="config") +@hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) diff --git a/examples/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py similarity index 100% rename from examples/discrete_sac/utils.py rename to sota-implementations/discrete_sac/utils.py diff --git a/examples/distributed/collectors/README.md b/sota-implementations/distributed/collectors/README.md similarity index 100% rename from examples/distributed/collectors/README.md rename to sota-implementations/distributed/collectors/README.md diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/sota-implementations/distributed/collectors/multi_nodes/delayed_dist.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/delayed_dist.py rename to sota-implementations/distributed/collectors/multi_nodes/delayed_dist.py diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/sota-implementations/distributed/collectors/multi_nodes/delayed_rpc.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/delayed_rpc.py rename to sota-implementations/distributed/collectors/multi_nodes/delayed_rpc.py diff --git a/examples/distributed/collectors/multi_nodes/generic.py b/sota-implementations/distributed/collectors/multi_nodes/generic.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/generic.py rename to sota-implementations/distributed/collectors/multi_nodes/generic.py diff --git a/examples/distributed/collectors/multi_nodes/ray.py b/sota-implementations/distributed/collectors/multi_nodes/ray.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/ray.py rename to sota-implementations/distributed/collectors/multi_nodes/ray.py diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/sota-implementations/distributed/collectors/multi_nodes/ray_train.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/ray_train.py rename to sota-implementations/distributed/collectors/multi_nodes/ray_train.py diff --git a/examples/distributed/collectors/multi_nodes/rpc.py b/sota-implementations/distributed/collectors/multi_nodes/rpc.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/rpc.py rename to sota-implementations/distributed/collectors/multi_nodes/rpc.py diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/sota-implementations/distributed/collectors/multi_nodes/sync.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/sync.py rename to sota-implementations/distributed/collectors/multi_nodes/sync.py diff --git a/examples/distributed/collectors/single_machine/generic.py b/sota-implementations/distributed/collectors/single_machine/generic.py similarity index 100% rename from examples/distributed/collectors/single_machine/generic.py rename to sota-implementations/distributed/collectors/single_machine/generic.py diff --git a/examples/distributed/collectors/single_machine/rpc.py b/sota-implementations/distributed/collectors/single_machine/rpc.py similarity index 100% rename from examples/distributed/collectors/single_machine/rpc.py rename to sota-implementations/distributed/collectors/single_machine/rpc.py diff --git a/examples/distributed/collectors/single_machine/sync.py b/sota-implementations/distributed/collectors/single_machine/sync.py similarity index 100% rename from examples/distributed/collectors/single_machine/sync.py rename to sota-implementations/distributed/collectors/single_machine/sync.py diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/sota-implementations/distributed/replay_buffers/distributed_replay_buffer.py similarity index 100% rename from examples/distributed/replay_buffers/distributed_replay_buffer.py rename to sota-implementations/distributed/replay_buffers/distributed_replay_buffer.py diff --git a/examples/dqn/README.md b/sota-implementations/dqn/README.md similarity index 100% rename from examples/dqn/README.md rename to sota-implementations/dqn/README.md diff --git a/examples/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml similarity index 100% rename from examples/dqn/config_atari.yaml rename to sota-implementations/dqn/config_atari.yaml diff --git a/examples/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml similarity index 100% rename from examples/dqn/config_cartpole.yaml rename to sota-implementations/dqn/config_cartpole.yaml diff --git a/examples/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py similarity index 98% rename from examples/dqn/dqn_atari.py rename to sota-implementations/dqn/dqn_atari.py index 1d7f5dd81b5..ba5f7cbf761 100644 --- a/examples/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -26,7 +26,7 @@ from utils_atari import eval_model, make_dqn_model, make_env -@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +@hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.device) diff --git a/examples/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py similarity index 98% rename from examples/dqn/dqn_cartpole.py rename to sota-implementations/dqn/dqn_cartpole.py index 74f5ea99249..cfe734173f5 100644 --- a/examples/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -20,7 +20,7 @@ from utils_cartpole import eval_model, make_dqn_model, make_env -@hydra.main(config_path=".", config_name="config_cartpole", version_base="1.1") +@hydra.main(config_path="", config_name="config_cartpole", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.device) diff --git a/examples/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py similarity index 100% rename from examples/dqn/utils_atari.py rename to sota-implementations/dqn/utils_atari.py diff --git a/examples/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py similarity index 100% rename from examples/dqn/utils_cartpole.py rename to sota-implementations/dqn/utils_cartpole.py diff --git a/examples/dreamer/README.md b/sota-implementations/dreamer/README.md similarity index 100% rename from examples/dreamer/README.md rename to sota-implementations/dreamer/README.md diff --git a/examples/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml similarity index 100% rename from examples/dreamer/config.yaml rename to sota-implementations/dreamer/config.yaml diff --git a/examples/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py similarity index 99% rename from examples/dreamer/dreamer.py rename to sota-implementations/dreamer/dreamer.py index 27732fd96f7..a1d8c8aec4e 100644 --- a/examples/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -70,7 +70,7 @@ def retrieve_stats_from_state_dict(obs_norm_state_dict): } -@hydra.main(version_base="1.1", config_path=".", config_name="config") +@hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 cfg = correct_for_frame_skip(cfg) diff --git a/examples/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py similarity index 100% rename from examples/dreamer/dreamer_utils.py rename to sota-implementations/dreamer/dreamer_utils.py diff --git a/examples/impala/README.md b/sota-implementations/impala/README.md similarity index 100% rename from examples/impala/README.md rename to sota-implementations/impala/README.md diff --git a/examples/impala/config_multi_node_ray.yaml b/sota-implementations/impala/config_multi_node_ray.yaml similarity index 100% rename from examples/impala/config_multi_node_ray.yaml rename to sota-implementations/impala/config_multi_node_ray.yaml diff --git a/examples/impala/config_multi_node_submitit.yaml b/sota-implementations/impala/config_multi_node_submitit.yaml similarity index 100% rename from examples/impala/config_multi_node_submitit.yaml rename to sota-implementations/impala/config_multi_node_submitit.yaml diff --git a/examples/impala/config_single_node.yaml b/sota-implementations/impala/config_single_node.yaml similarity index 100% rename from examples/impala/config_single_node.yaml rename to sota-implementations/impala/config_single_node.yaml diff --git a/examples/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py similarity index 99% rename from examples/impala/impala_multi_node_ray.py rename to sota-implementations/impala/impala_multi_node_ray.py index e52b3af8342..0482a595ffa 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -11,7 +11,7 @@ from torchrl._utils import logger as torchrl_logger -@hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1") +@hydra.main(config_path="", config_name="config_multi_node_ray", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time diff --git a/examples/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py similarity index 99% rename from examples/impala/impala_multi_node_submitit.py rename to sota-implementations/impala/impala_multi_node_submitit.py index acd36f09d5a..ce96cf06ce8 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -12,7 +12,7 @@ @hydra.main( - config_path=".", config_name="config_multi_node_submitit", version_base="1.1" + config_path="", config_name="config_multi_node_submitit", version_base="1.1" ) def main(cfg: "DictConfig"): # noqa: F821 diff --git a/examples/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py similarity index 99% rename from examples/impala/impala_single_node.py rename to sota-implementations/impala/impala_single_node.py index 1faff37d1e0..bb0f314197a 100644 --- a/examples/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -11,7 +11,7 @@ from torchrl._utils import logger as torchrl_logger -@hydra.main(config_path=".", config_name="config_single_node", version_base="1.1") +@hydra.main(config_path="", config_name="config_single_node", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time diff --git a/examples/impala/utils.py b/sota-implementations/impala/utils.py similarity index 100% rename from examples/impala/utils.py rename to sota-implementations/impala/utils.py diff --git a/examples/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py similarity index 99% rename from examples/iql/discrete_iql.py rename to sota-implementations/iql/discrete_iql.py index 9a685faa036..c0101f1c941 100644 --- a/examples/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -34,7 +34,7 @@ ) -@hydra.main(config_path=".", config_name="discrete_iql") +@hydra.main(config_path="", config_name="discrete_iql") def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() diff --git a/examples/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml similarity index 100% rename from examples/iql/discrete_iql.yaml rename to sota-implementations/iql/discrete_iql.yaml diff --git a/examples/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py similarity index 98% rename from examples/iql/iql_offline.py rename to sota-implementations/iql/iql_offline.py index d3e15221f30..da65c5c246e 100644 --- a/examples/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -31,7 +31,7 @@ ) -@hydra.main(config_path=".", config_name="offline_config") +@hydra.main(config_path="", config_name="offline_config") def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() diff --git a/examples/iql/iql_online.py b/sota-implementations/iql/iql_online.py similarity index 99% rename from examples/iql/iql_online.py rename to sota-implementations/iql/iql_online.py index e1295eaabfe..307f6df5e2b 100644 --- a/examples/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -34,7 +34,7 @@ ) -@hydra.main(config_path=".", config_name="online_config") +@hydra.main(config_path="", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() diff --git a/examples/iql/offline_config.yaml b/sota-implementations/iql/offline_config.yaml similarity index 100% rename from examples/iql/offline_config.yaml rename to sota-implementations/iql/offline_config.yaml diff --git a/examples/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml similarity index 100% rename from examples/iql/online_config.yaml rename to sota-implementations/iql/online_config.yaml diff --git a/examples/iql/utils.py b/sota-implementations/iql/utils.py similarity index 100% rename from examples/iql/utils.py rename to sota-implementations/iql/utils.py diff --git a/examples/media/ant_chart.png b/sota-implementations/media/ant_chart.png similarity index 100% rename from examples/media/ant_chart.png rename to sota-implementations/media/ant_chart.png diff --git a/examples/media/cheetah_chart.png b/sota-implementations/media/cheetah_chart.png similarity index 100% rename from examples/media/cheetah_chart.png rename to sota-implementations/media/cheetah_chart.png diff --git a/examples/media/halfcheetah_chart.png b/sota-implementations/media/halfcheetah_chart.png similarity index 100% rename from examples/media/halfcheetah_chart.png rename to sota-implementations/media/halfcheetah_chart.png diff --git a/examples/media/walker2d_chart.png b/sota-implementations/media/walker2d_chart.png similarity index 100% rename from examples/media/walker2d_chart.png rename to sota-implementations/media/walker2d_chart.png diff --git a/examples/multiagent/README.md b/sota-implementations/multiagent/README.md similarity index 100% rename from examples/multiagent/README.md rename to sota-implementations/multiagent/README.md diff --git a/examples/multiagent/iql.py b/sota-implementations/multiagent/iql.py similarity index 99% rename from examples/multiagent/iql.py rename to sota-implementations/multiagent/iql.py index bb374b99941..81551ebefb7 100644 --- a/examples/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -28,7 +28,7 @@ def rendering_callback(env, td): env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) -@hydra.main(version_base="1.1", config_path=".", config_name="iql") +@hydra.main(version_base="1.1", config_path="", config_name="iql") def train(cfg: "DictConfig"): # noqa: F821 # Device cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" diff --git a/examples/multiagent/iql.yaml b/sota-implementations/multiagent/iql.yaml similarity index 100% rename from examples/multiagent/iql.yaml rename to sota-implementations/multiagent/iql.yaml diff --git a/examples/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py similarity index 99% rename from examples/multiagent/maddpg_iddpg.py rename to sota-implementations/multiagent/maddpg_iddpg.py index bed6240d244..c02a5007318 100644 --- a/examples/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -33,7 +33,7 @@ def rendering_callback(env, td): env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) -@hydra.main(version_base="1.1", config_path=".", config_name="maddpg_iddpg") +@hydra.main(version_base="1.1", config_path="", config_name="maddpg_iddpg") def train(cfg: "DictConfig"): # noqa: F821 # Device cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" diff --git a/examples/multiagent/maddpg_iddpg.yaml b/sota-implementations/multiagent/maddpg_iddpg.yaml similarity index 100% rename from examples/multiagent/maddpg_iddpg.yaml rename to sota-implementations/multiagent/maddpg_iddpg.yaml diff --git a/examples/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py similarity index 99% rename from examples/multiagent/mappo_ippo.py rename to sota-implementations/multiagent/mappo_ippo.py index 6b7206511ca..0fe547b3cd6 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -29,7 +29,7 @@ def rendering_callback(env, td): env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) -@hydra.main(version_base="1.1", config_path=".", config_name="mappo_ippo") +@hydra.main(version_base="1.1", config_path="", config_name="mappo_ippo") def train(cfg: "DictConfig"): # noqa: F821 # Device cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" diff --git a/examples/multiagent/mappo_ippo.yaml b/sota-implementations/multiagent/mappo_ippo.yaml similarity index 100% rename from examples/multiagent/mappo_ippo.yaml rename to sota-implementations/multiagent/mappo_ippo.yaml diff --git a/examples/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py similarity index 99% rename from examples/multiagent/qmix_vdn.py rename to sota-implementations/multiagent/qmix_vdn.py index 008e01b28b9..d294a9c783e 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -28,7 +28,7 @@ def rendering_callback(env, td): env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) -@hydra.main(version_base="1.1", config_path=".", config_name="qmix_vdn") +@hydra.main(version_base="1.1", config_path="", config_name="qmix_vdn") def train(cfg: "DictConfig"): # noqa: F821 # Device cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" diff --git a/examples/multiagent/qmix_vdn.yaml b/sota-implementations/multiagent/qmix_vdn.yaml similarity index 100% rename from examples/multiagent/qmix_vdn.yaml rename to sota-implementations/multiagent/qmix_vdn.yaml diff --git a/examples/multiagent/sac.py b/sota-implementations/multiagent/sac.py similarity index 99% rename from examples/multiagent/sac.py rename to sota-implementations/multiagent/sac.py index d76ddd1f913..78756a012a5 100644 --- a/examples/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -30,7 +30,7 @@ def rendering_callback(env, td): env.frames.append(env.render(mode="rgb_array", agent_index_focus=None)) -@hydra.main(version_base="1.1", config_path=".", config_name="sac") +@hydra.main(version_base="1.1", config_path="", config_name="sac") def train(cfg: "DictConfig"): # noqa: F821 # Device cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" diff --git a/examples/multiagent/sac.yaml b/sota-implementations/multiagent/sac.yaml similarity index 100% rename from examples/multiagent/sac.yaml rename to sota-implementations/multiagent/sac.yaml diff --git a/examples/multiagent/utils/__init__.py b/sota-implementations/multiagent/utils/__init__.py similarity index 100% rename from examples/multiagent/utils/__init__.py rename to sota-implementations/multiagent/utils/__init__.py diff --git a/examples/multiagent/utils/logging.py b/sota-implementations/multiagent/utils/logging.py similarity index 100% rename from examples/multiagent/utils/logging.py rename to sota-implementations/multiagent/utils/logging.py diff --git a/examples/multiagent/utils/utils.py b/sota-implementations/multiagent/utils/utils.py similarity index 100% rename from examples/multiagent/utils/utils.py rename to sota-implementations/multiagent/utils/utils.py diff --git a/examples/ppo/README.md b/sota-implementations/ppo/README.md similarity index 100% rename from examples/ppo/README.md rename to sota-implementations/ppo/README.md diff --git a/examples/ppo/config_atari.yaml b/sota-implementations/ppo/config_atari.yaml similarity index 100% rename from examples/ppo/config_atari.yaml rename to sota-implementations/ppo/config_atari.yaml diff --git a/examples/ppo/config_mujoco.yaml b/sota-implementations/ppo/config_mujoco.yaml similarity index 100% rename from examples/ppo/config_mujoco.yaml rename to sota-implementations/ppo/config_mujoco.yaml diff --git a/examples/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py similarity index 99% rename from examples/ppo/ppo_atari.py rename to sota-implementations/ppo/ppo_atari.py index 6b9a18ae5bb..69468e133a8 100644 --- a/examples/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -11,7 +11,7 @@ from torchrl._utils import logger as torchrl_logger -@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +@hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time diff --git a/examples/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py similarity index 99% rename from examples/ppo/ppo_mujoco.py rename to sota-implementations/ppo/ppo_mujoco.py index fa497230b6e..ae4ba9ea9e5 100644 --- a/examples/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -11,7 +11,7 @@ from torchrl._utils import logger as torchrl_logger -@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") +@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time diff --git a/examples/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py similarity index 100% rename from examples/ppo/utils_atari.py rename to sota-implementations/ppo/utils_atari.py diff --git a/examples/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py similarity index 100% rename from examples/ppo/utils_mujoco.py rename to sota-implementations/ppo/utils_mujoco.py diff --git a/examples/redq/README.md b/sota-implementations/redq/README.md similarity index 100% rename from examples/redq/README.md rename to sota-implementations/redq/README.md diff --git a/examples/redq/config.yaml b/sota-implementations/redq/config.yaml similarity index 100% rename from examples/redq/config.yaml rename to sota-implementations/redq/config.yaml diff --git a/examples/redq/redq.py b/sota-implementations/redq/redq.py similarity index 98% rename from examples/redq/redq.py rename to sota-implementations/redq/redq.py index f89098e1441..d9aef64b525 100644 --- a/examples/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -39,7 +39,7 @@ } -@hydra.main(version_base="1.1", config_path=".", config_name="config") +@hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 cfg = correct_for_frame_skip(cfg) diff --git a/examples/redq/utils.py b/sota-implementations/redq/utils.py similarity index 99% rename from examples/redq/utils.py rename to sota-implementations/redq/utils.py index fe78fa83432..37e7da91b4a 100644 --- a/examples/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -971,7 +971,7 @@ def make_collector_offpolicy( cfg: "DictConfig", # noqa: F821 make_env_kwargs: Dict | None = None, ) -> DataCollectorBase: - """Returns a data collector for off-policy algorithms. + """Returns a data collector for off-policy sota-implementations. Args: make_env (Callable): environment creator diff --git a/examples/sac/config.yaml b/sota-implementations/sac/config.yaml similarity index 100% rename from examples/sac/config.yaml rename to sota-implementations/sac/config.yaml diff --git a/examples/sac/sac.py b/sota-implementations/sac/sac.py similarity index 99% rename from examples/sac/sac.py rename to sota-implementations/sac/sac.py index 5b0cad1a7c9..576de96394d 100644 --- a/examples/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -34,7 +34,7 @@ ) -@hydra.main(version_base="1.1", config_path=".", config_name="config") +@hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) diff --git a/examples/sac/utils.py b/sota-implementations/sac/utils.py similarity index 100% rename from examples/sac/utils.py rename to sota-implementations/sac/utils.py diff --git a/examples/td3/config.yaml b/sota-implementations/td3/config.yaml similarity index 100% rename from examples/td3/config.yaml rename to sota-implementations/td3/config.yaml diff --git a/examples/td3/td3.py b/sota-implementations/td3/td3.py similarity index 99% rename from examples/td3/td3.py rename to sota-implementations/td3/td3.py index ef2edd578cb..6b1ee046d55 100644 --- a/examples/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -33,7 +33,7 @@ ) -@hydra.main(version_base="1.1", config_path=".", config_name="config") +@hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) diff --git a/examples/td3/utils.py b/sota-implementations/td3/utils.py similarity index 100% rename from examples/td3/utils.py rename to sota-implementations/td3/utils.py diff --git a/test/mocking_classes.py b/test/mocking_classes.py index d68c7f30aa3..43d7f79c329 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -904,7 +904,7 @@ def _get_in_obs(self, tensordict): class DummyModelBasedEnvBase(ModelBasedEnvBase): - """Dummy environnement for Model Based RL algorithms. + """Dummy environnement for Model Based RL sota-implementations. This class is meant to be used to test the model based environnement. diff --git a/test/test_helpers.py b/test/test_helpers.py index e39e6cc6082..3d97f177f01 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -5,6 +5,7 @@ import argparse import dataclasses +import pathlib import sys from time import sleep @@ -68,12 +69,14 @@ @pytest.fixture def dreamer_constructor_fixture(): - import os # we hack the env constructor import sys - sys.path.append(os.path.dirname(__file__) + "/../examples/dreamer/") + sys.path.append( + str(pathlib.Path(__file__).parent.parent / "sota-implementations" / "dreamer") + ) + from dreamer_utils import transformed_env_constructor yield transformed_env_constructor diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 5cbfe5c62ae..7da775a176a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1771,7 +1771,7 @@ class MultiSyncDataCollector(_MultiDataCollector): The collection starts when the next item of the collector is queried, and no environment step is computed in between the reception of a batch of trajectory and the start of the next collection. - This class can be safely used with online RL algorithms. + This class can be safely used with online RL sota-implementations. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -2008,7 +2008,7 @@ class MultiaSyncDataCollector(_MultiDataCollector): The collection keeps on occuring on all processes even between the time the batch of rollouts is collected and the next call to the iterator. - This class can be safely used with offline RL algorithms. + This class can be safely used with offline RL sota-implementations. Examples: >>> from torchrl.envs.libs.gym import GymEnv diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 5952132ca19..a04607829c6 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -17,7 +17,7 @@ class ModelBasedEnvBase(EnvBase): - """Basic environnement for Model Based RL algorithms. + """Basic environnement for Model Based RL sota-implementations. Wrapper around the model of the MBRL algorithm. It is meant to give an env framework to a world model (including but not limited to observations, reward, done state and safety constraints models). diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 124250157ec..7d8619fcb31 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4049,7 +4049,7 @@ class FrameSkipTransform(Transform): """A frame-skip transform. This transform applies the same action repeatedly in the parent environment, - which improves stability on certain training algorithms. + which improves stability on certain training sota-implementations. Args: frame_skip (int, optional): a positive integer representing the number diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 3822f01cb73..96f15e8ab69 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -36,7 +36,7 @@ class ReinforceLoss(LossModule): """Reinforce loss module. - Presented in "Simple statistical gradient-following algorithms for connectionist reinforcement learning", Williams, 1992 + Presented in "Simple statistical gradient-following sota-implementations for connectionist reinforcement learning", Williams, 1992 https://doi.org/10.1007/BF00992696 diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 7063fb2f1c4..b192d115a54 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -256,7 +256,7 @@ def make_collector_offpolicy( cfg: "DictConfig", # noqa: F821 make_env_kwargs: Optional[Dict] = None, ) -> DataCollectorBase: - """Returns a data collector for off-policy algorithms. + """Returns a data collector for off-policy sota-implementations. Args: make_env (Callable): environment creator diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 03a7be37573..526b3c967e8 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -1003,7 +1003,7 @@ def mask_batch(batch: TensorDictBase) -> TensorDictBase: class BatchSubSampler(TrainerHookBase): - """Data subsampler for online RL algorithms. + """Data subsampler for online RL sota-implementations. This class subsamples a part of a whole batch of data just collected from the environment. diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 4a818474985..17166453cba 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -12,7 +12,7 @@ # Overview # -------- # -# TorchRL separates the training of RL algorithms in various pieces that will be +# TorchRL separates the training of RL sota-implementations in various pieces that will be # assembled in your training script: the environment, the data collection and # storage, the model and finally the loss function. # @@ -168,7 +168,7 @@ # the losses without it. However, we encourage its usage for the following # reason. # -# The reason TorchRL does this is that RL algorithms often execute the same +# The reason TorchRL does this is that RL sota-implementations often execute the same # model with different sets of parameters, called "trainable" and "target" # parameters. # The "trainable" parameters are those that the optimizer needs to fit. The @@ -407,7 +407,7 @@ class DDPGLoss(LossModule): # Environment # ----------- # -# In most algorithms, the first thing that needs to be taken care of is the +# In most sota-implementations, the first thing that needs to be taken care of is the # construction of the environment as it conditions the remainder of the # training script. # @@ -1059,7 +1059,7 @@ def ceil_div(x, y): # Target network updater # ~~~~~~~~~~~~~~~~~~~~~~ # -# Target networks are a crucial part of off-policy RL algorithms. +# Target networks are a crucial part of off-policy RL sota-implementations. # Updating the target network parameters is made easy thanks to the # :class:`~torchrl.objectives.HardUpdate` and :class:`~torchrl.objectives.SoftUpdate` # classes. They're built with the loss module as argument, and the update is diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index eb476dfcc15..dbce0c29804 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -43,7 +43,7 @@ # estimated return; # - how to collect data from your environment efficiently and store them # in a replay buffer; -# - how to use multi-step, a simple preprocessing step for off-policy algorithms; +# - how to use multi-step, a simple preprocessing step for off-policy sota-implementations; # - and finally how to evaluate your model. # # **Prerequisites**: We encourage you to get familiar with torchrl through the @@ -356,7 +356,7 @@ def make_model(dummy_env): # Replay buffers # ~~~~~~~~~~~~~~ # -# Replay buffers play a central role in off-policy RL algorithms such as DQN. +# Replay buffers play a central role in off-policy RL sota-implementations such as DQN. # They constitute the dataset we will be sampling from during training. # # Here, we will use a regular sampling strategy, although a prioritized RB @@ -455,13 +455,13 @@ def get_collector( # Target parameters # ~~~~~~~~~~~~~~~~~ # -# Many off-policy RL algorithms use the concept of "target parameters" when it +# Many off-policy RL sota-implementations use the concept of "target parameters" when it # comes to estimate the value of the next state or state-action pair. # The target parameters are lagged copies of the model parameters. Because # their predictions mismatch those of the current model configuration, they # help learning by putting a pessimistic bound on the value being estimated. # This is a powerful trick (known as "Double Q-Learning") that is ubiquitous -# in similar algorithms. +# in similar sota-implementations. # diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 6f31a0aed1a..b986f1e80bd 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -519,7 +519,7 @@ # Replay buffer # ------------- # -# Replay buffers are a common building piece of off-policy RL algorithms. +# Replay buffers are a common building piece of off-policy RL sota-implementations. # In on-policy contexts, a replay buffer is refilled every time a batch of # data is collected, and its data is repeatedly consumed for a certain number # of epochs. diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 75ccf7cf8e7..27df0fafb6e 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -118,7 +118,7 @@ # Probabilistic policies # ---------------------- # -# Policy-optimization algorithms like +# Policy-optimization sota-implementations like # `PPO `_ require the policy to be # stochastic: unlike in the examples above, the module now encodes a map from # the observation space to a parameter space encoding a distribution over the @@ -162,7 +162,7 @@ # # - Since we asked for it during the construction of the actor, the # log-probability of the actions given the distribution at that time is -# also written. This is necessary for algorithms like PPO. +# also written. This is necessary for sota-implementations like PPO. # - The parameters of the distribution are returned within the output # tensordict too under the ``"loc"`` and ``"scale"`` entries. # diff --git a/tutorials/sphinx-tutorials/getting-started-2.py b/tutorials/sphinx-tutorials/getting-started-2.py index 0a16071bed2..a03df6a32cb 100644 --- a/tutorials/sphinx-tutorials/getting-started-2.py +++ b/tutorials/sphinx-tutorials/getting-started-2.py @@ -40,9 +40,9 @@ # ---------------------- # # In RL, innovation typically involves the exploration of novel methods -# for optimizing a policy (i.e., new algorithms), rather than focusing +# for optimizing a policy (i.e., new sota-implementations), rather than focusing # on new architectures, as seen in other domains. Within TorchRL, -# these algorithms are encapsulated within loss modules. A loss +# these sota-implementations are encapsulated within loss modules. A loss # module orchestrates the various components of your algorithm and # yields a set of loss values that can be backpropagated # through to train the corresponding components. @@ -146,7 +146,7 @@ # ----------------------------------------- # # Another important aspect to consider is the presence of target parameters -# in off-policy algorithms like DDPG. Target parameters typically represent +# in off-policy sota-implementations like DDPG. Target parameters typically represent # a delayed or smoothed version of the parameters over time, and they play # a crucial role in value estimation during policy training. Utilizing target # parameters for policy training often proves to be significantly more diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py index cf80b47f859..a4652c78518 100644 --- a/tutorials/sphinx-tutorials/getting-started-3.py +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -30,7 +30,7 @@ # dataloaders are referred to as ``DataCollectors``. Most of the time, # data collection does not stop at the collection of raw data, # as the data needs to be stored temporarily in a buffer -# (or equivalent structure for on-policy algorithms) before being consumed +# (or equivalent structure for on-policy sota-implementations) before being consumed # by the :ref:`loss module `. This tutorial will explore # these two classes. # @@ -94,7 +94,7 @@ ################################# # Data collectors are very useful when it comes to coding state-of-the-art -# algorithms, as performance is usually measured by the capability of a +# sota-implementations, as performance is usually measured by the capability of a # specific technique to solve a problem in a given number of interactions with # the environment (the ``total_frames`` argument in the collector). # For this reason, most training loops in our examples look like this: diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index 7451d6b39e7..16d264efb65 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -9,14 +9,14 @@ If you are interested in Multi-Agent Reinforcement Learning (MARL) in TorchRL, check out `BenchMARL `__: a benchmarking library where you - can train and compare MARL algorithms, tasks, and models using TorchRL! + can train and compare MARL sota-implementations, tasks, and models using TorchRL! This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to solve a Multi-Agent Reinforcement Learning (MARL) problem. A code-only version of this tutorial is available in the `TorchRL examples `__, -alongside other simple scripts for many MARL algorithms (QMIX, MADDPG, IQL). +alongside other simple scripts for many MARL sota-implementations (QMIX, MADDPG, IQL). For ease of use, this tutorial will follow the general structure of the already available `single agent PPO tutorial `__. @@ -63,7 +63,7 @@ # the foundational policy-optimization algorithm. For more information, see the # `Proximal Policy Optimization Algorithms `_ paper. # -# This type of algorithms is usually trained *on-policy*. This means that, at every learning iteration, we have a +# This type of sota-implementations is usually trained *on-policy*. This means that, at every learning iteration, we have a # **sampling** and a **training** phase. In the **sampling** phase of iteration :math:`t`, rollouts are collected # form agents' interactions in the environment using the current policies :math:`\mathbf{\pi}_t`. # In the **training** phase, all the collected rollouts are immediately fed to the training process to perform @@ -560,7 +560,7 @@ # Replay buffer # ------------- # -# Replay buffers are a common building piece of off-policy RL algorithms. +# Replay buffers are a common building piece of off-policy RL sota-implementations. # In on-policy contexts, a replay buffer is refilled every time a batch of # data is collected, and its data is repeatedly consumed for a certain number # of epochs. @@ -789,7 +789,7 @@ # # Now that you are proficient with multi-agent PPO, you can check out all # `TorchRL multi-agent examples `__. -# These are code-only scripts of many popular MARL algorithms such as the ones seen in this tutorial, +# These are code-only scripts of many popular MARL sota-implementations such as the ones seen in this tutorial, # QMIX, MADDPG, IQL, and many more! # # If you are interested in creating or wrapping your own multi-agent environments in TorchRL, diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 25213503e19..3738542e3a7 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -165,13 +165,13 @@ # │ └── "trainers.py" # └── "version.py" # -# Unlike other domains, RL is less about media than *algorithms*. As such, it +# Unlike other domains, RL is less about media than *sota-implementations*. As such, it # is harder to make truly independent components. # # What TorchRL is not: # -# * a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms, -# but we provide these algorithms only as examples of how to use the library. +# * a collection of sota-implementations: we do not intend to provide SOTA implementations of RL sota-implementations, +# but we provide these sota-implementations only as examples of how to use the library. # # * a research framework: modularity in TorchRL comes in two flavours. First, we try # to build re-usable components, such that they can be easily swapped with each other.