[Examples] Move examples to sota-implementations #2016

Merged
merged 3 commits into from
Mar 18, 2024
Changes from 1 commit
init
vmoens committed Mar 18, 2024
commit 9f604290aba37f3419fa52f77eb37bd4e86d4ec3
66 changes: 33 additions & 33 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -32,55 +32,55 @@ python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_te
# ================================ gym 0.23 ========================================== #

# With batched environments
-python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/decision_transformer/dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4
-python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/decision_transformer/online_dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
env.backend=gymnasium \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offline.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=

# ==================================================================================== #
# ================================ Gymnasium ========================================= #

-python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
collector.num_workers=1 \
logger.backend= \
logger.test_interval=10
-python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ppo/ppo_mujoco.py \
env.env_name=HalfCheetah-v4 \
collector.total_frames=40 \
collector.frames_per_batch=20 \
loss.mini_batch_size=10 \
loss.ppo_epochs=2 \
logger.backend= \
logger.test_interval=10
-python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ppo/ppo_atari.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
loss.mini_batch_size=20 \
loss.ppo_epochs=2 \
logger.backend= \
logger.test_interval=10
-python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optim.batch_size=10 \
@@ -94,20 +94,20 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
logger.backend=
# record_video=True \
# record_frames=4 \
-python .github/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c_mujoco.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/a2c/a2c_mujoco.py \
env.env_name=HalfCheetah-v4 \
collector.total_frames=40 \
collector.frames_per_batch=20 \
loss.mini_batch_size=10 \
logger.backend= \
logger.test_interval=40
-python .github/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c_atari.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/a2c/a2c_atari.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
loss.mini_batch_size=20 \
logger.backend= \
logger.test_interval=40
-python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dqn/dqn_atari.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
@@ -116,7 +116,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
-python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/discrete_cql_online.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optim.batch_size=10 \
@@ -125,7 +125,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_c
collector.device=cuda:0 \
replay_buffer.size=120 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \
num_workers=4 \
collector.total_frames=48 \
collector.init_random_frames=10 \
@@ -138,7 +138,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
logger.record_frames=4 \
buffer.size=120 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/sac/sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
@@ -150,7 +150,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/discrete_sac/discrete_sac.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/discrete_sac/discrete_sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
@@ -166,7 +166,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/discrete_sac/d
logger.backend=
# logger.record_video=True \
# logger.record_frames=4 \
-python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
batch_size=10 \
@@ -180,7 +180,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame
record_frames=4 \
buffer_size=120 \
rssm_hidden_dim=17
-python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optim.batch_size=10 \
@@ -193,7 +193,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
logger.mode=offline \
env.name=Pendulum-v1 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
@@ -202,7 +202,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_iql.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/discrete_iql.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
@@ -211,7 +211,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_i
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
@@ -222,7 +222,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_i
logger.backend=

# With single envs
-python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
batch_size=10 \
@@ -236,7 +236,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame
record_frames=4 \
buffer_size=120 \
rssm_hidden_dim=17
-python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optim.batch_size=10 \
@@ -250,7 +250,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
logger.backend=
# record_video=True \
# record_frames=4 \
-python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dqn/dqn_atari.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
@@ -259,7 +259,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
-python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \
num_workers=2 \
collector.total_frames=48 \
collector.init_random_frames=10 \
@@ -272,7 +272,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
logger.record_frames=4 \
buffer.size=120 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
@@ -281,7 +281,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
@@ -290,7 +290,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
@@ -302,38 +302,38 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/mappo_ippo.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/mappo_ippo.py \
collector.n_iters=2 \
collector.frames_per_batch=200 \
train.num_epochs=3 \
train.minibatch_size=100 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/maddpg_iddpg.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/maddpg_iddpg.py \
collector.n_iters=2 \
collector.frames_per_batch=200 \
train.num_epochs=3 \
train.minibatch_size=100 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/iql.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/iql.py \
collector.n_iters=2 \
collector.frames_per_batch=200 \
train.num_epochs=3 \
train.minibatch_size=100 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/qmix_vdn.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/qmix_vdn.py \
collector.n_iters=2 \
collector.frames_per_batch=200 \
train.num_epochs=3 \
train.minibatch_size=100 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac.py \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/sac.py \
collector.n_iters=2 \
collector.frames_per_batch=200 \
train.num_epochs=3 \
train.minibatch_size=100 \
logger.backend=

-python .github/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/bandits/dqn.py --n_steps=100

## RLHF
# RLHF tests are executed in the dedicated workflow
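
Every change to `run_test.sh` in this commit is the same mechanical substitution: an `examples/<dir>/` prefix becomes `sota-implementations/<dir>/`. The sketch below shows one way such a rewrite could be scripted. It is not the method used in this PR; `MOVED_DIRS` and `rewrite_example_paths` are hypothetical names, and the directory list is an assumption collected from the paths visible in this diff.

```python
import re

# Assumption: directory names collected from the paths shown in this diff.
# Directories not listed here (e.g. a2c_old, rlhf) keep their examples/ prefix.
MOVED_DIRS = (
    "a2c", "bandits", "cql", "ddpg", "decision_transformer", "discrete_sac",
    "distributed", "dqn", "dreamer", "impala", "iql", "multiagent", "ppo",
    "redq", "sac", "td3",
)


def rewrite_example_paths(text, moved_dirs=MOVED_DIRS,
                          old_root="examples", new_root="sota-implementations"):
    """Replace `old_root/<dir>` with `new_root/<dir>` for the listed directories."""
    names = "|".join(re.escape(name) for name in moved_dirs)
    pattern = re.compile(
        # Do not match inside a longer path component such as linux_examples/.
        r"(?<![\w./-])" + re.escape(old_root) + "/(" + names + ")"
        # The directory name must end here, so a2c does not match a2c_old.
        r"(?=[/\s)]|$)"
    )
    return pattern.sub(lambda m: f"{new_root}/{m.group(1)}", text)


if __name__ == "__main__":
    line = "python run.py examples/ddpg/ddpg.py \\"
    print(rewrite_example_paths(line))
```

The lookbehind and lookahead keep the rewrite narrow, so paths such as `.github/unittest/linux_examples/scripts/run_test.sh` and `examples/a2c_old/a2c.py` are left alone.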
24 changes: 12 additions & 12 deletions README.md
@@ -315,7 +315,7 @@ And it is `functorch` and `torch.compile` compatible!
```
</details>

-Check our [distributed collector examples](examples/distributed/collectors) to
+Check our [distributed collector examples](sota-implementations/distributed/collectors) to
learn more about ultra-fast data collection with TorchRL.

- efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
@@ -496,22 +496,22 @@ If you would like to contribute to new features, check our [call for contributio
## Examples, tutorials and demos

A series of [examples](examples/) are provided with an illustrative purpose:
-- [DQN](examples/dqn)
-- [DDPG](examples/ddpg/ddpg.py)
-- [IQL](examples/iql/iql.py)
-- [CQL](examples/iql/cql.py)
-- [TD3](examples/td3/td3.py)
+- [DQN](sota-implementations/dqn)
+- [DDPG](sota-implementations/ddpg/ddpg.py)
+- [IQL](sota-implementations/iql/iql.py)
+- [CQL](sota-implementations/iql/cql.py)
+- [TD3](sota-implementations/td3/td3.py)
- [A2C](examples/a2c_old/a2c.py)
-- [PPO](examples/ppo/ppo.py)
-- [SAC](examples/sac/sac.py)
-- [REDQ](examples/redq/redq.py)
-- [Dreamer](examples/dreamer/dreamer.py)
-- [Decision Transformers](examples/decision_transformer)
+- [PPO](sota-implementations/ppo/ppo.py)
+- [SAC](sota-implementations/sac/sac.py)
+- [REDQ](sota-implementations/redq/redq.py)
+- [Dreamer](sota-implementations/dreamer/dreamer.py)
+- [Decision Transformers](sota-implementations/decision_transformer)
- [RLHF](examples/rlhf)

and many more to come!

-Check the [examples markdown](examples/EXAMPLES.md) directory for more details
+Check the [examples markdown](sota-implementations/SOTA-IMPLEMENTATIONS.md) directory for more details
about handling the various configuration settings.

We also provide [tutorials and demos](https://pytorch.org/rl/#tutorials) that give a sense of