Skip to content

Commit

Permalink
[BugFix] Fix EXAMPLES.md (pytorch#1649)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 25, 2023
1 parent 105e861 commit e353b20
Show file tree
Hide file tree
Showing 29 changed files with 1,538 additions and 314 deletions.
91 changes: 46 additions & 45 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.collector_device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
Expand Down Expand Up @@ -107,23 +107,24 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
record_frames=4 \
buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
buffer.batch_size=10 \
optim.steps_per_batch=1 \
logger.record_video=True \
logger.record_frames=4 \
buffer.size=120 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.collector_device=cuda:0 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
Expand Down Expand Up @@ -152,21 +153,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
collector.collector_device=cuda:0 \
collector.device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
logger.mode=offline \
env.name=Pendulum-v1 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
total_frames=48 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
device=cuda:0 \
mode=offline \
logger=
collector.total_frames=48 \
buffer.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
network.device=cuda:0 \
logger.mode=offline \
logger.backend=

# With single envs
python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
Expand All @@ -188,7 +189,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.collector_device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
Expand All @@ -209,23 +210,24 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
record_frames=4 \
buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=2 \
env_per_collector=1 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
buffer.batch_size=10 \
collector.device=cuda:0 \
optim.steps_per_batch=1 \
logger.record_video=True \
logger.record_frames=4 \
buffer.size=120 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.collector_device=cuda:0 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
network.device=cuda:0 \
Expand All @@ -235,24 +237,23 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
env.name=Pendulum-v1 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
total_frames=48 \
batch_size=10 \
frames_per_batch=16 \
num_workers=2 \
env_per_collector=1 \
mode=offline \
device=cuda:0 \
collector_device=cuda:0 \
logger=
collector.total_frames=48 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
network.device=cuda:0 \
buffer.batch_size=10 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
logger.mode=offline \
collector.collector_device=cuda:0 \
optim.batch_size=10 \
env.name=Pendulum-v1 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/mappo_ippo.py \
Expand Down
76 changes: 49 additions & 27 deletions examples/EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ python sac.py
```
or similar. Hyperparameters can be easily changed by providing the arguments to hydra:
```
python sac.py frames_per_batch=63
python sac.py collector.frames_per_batch=63
```
# Results

Expand All @@ -32,11 +32,11 @@ We average the results over 5 different seeds and plot the standard error.
To reproduce a single run:

```
python sac/sac.py env_name="HalfCheetah-v4" env_task="" env_library="gym"
python sac/sac.py env.name="HalfCheetah-v4" env.task="" env.library="gym"
```

```
python redq/redq.py env_name="HalfCheetah-v4" env_task="" env_library="gym"
python redq/redq.py env.name="HalfCheetah-v4" env.library="gymnasium"
```


Expand All @@ -48,39 +48,61 @@ python redq/redq.py env_name="HalfCheetah-v4" env_task="" env_library="gym"
To reproduce a single run:

```
python sac/sac.py env_name="cheetah" env_task="run" env_library="dm_control"
python sac/sac.py env.name="cheetah" env.task="run" env.library="dm_control"
```

```
python redq/redq.py env_name="cheetah" env_task="run" env_library="dm_control"
python redq/redq.py env.name="cheetah" env.task="run" env.library="dm_control"
```

## Gym's Ant-v4
[//]: # (TODO: adapt these scripts)
[//]: # (## Gym's Ant-v4)

<p align="center">
<img src="media/ant_chart.png" width="600px">
</p>
To reproduce a single run:
[//]: # ()
[//]: # (<p align="center">)

```
python sac/sac.py env_name="Ant-v4" env_task="" env_library="gym"
```
[//]: # (<img src="media/ant_chart.png" width="600px">)

```
python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym"
```
[//]: # (</p>)

## Gym's Walker2D-v4
[//]: # (To reproduce a single run:)

<p align="center">
<img src="media/walker2d_chart.png" width="600px">
</p>
To reproduce a single run:
[//]: # ()
[//]: # (```)

```
python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym"
```
[//]: # (python sac/sac.py env.name="Ant-v4" env.task="" env.library="gym")

```
python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym"
```
[//]: # (```)

[//]: # ()
[//]: # (``` )

[//]: # (python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym")

[//]: # (```)

[//]: # ()
[//]: # (## Gym's Walker2D-v4)

[//]: # ()
[//]: # (<p align="center">)

[//]: # (<img src="media/walker2d_chart.png" width="600px">)

[//]: # (</p>)

[//]: # (To reproduce a single run:)

[//]: # ()
[//]: # (```)

[//]: # (python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym")

[//]: # (```)

[//]: # ()
[//]: # (``` )

[//]: # (python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym")

[//]: # (```)
2 changes: 1 addition & 1 deletion examples/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)


@hydra.main(config_path=".", config_name="offline_config")
@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
logger = None
Expand Down
2 changes: 1 addition & 1 deletion examples/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)


@hydra.main(config_path=".", config_name="online_config")
@hydra.main(version_base="1.1", config_path=".", config_name="online_config")
def main(cfg: "DictConfig"): # noqa: F821
exp_name = generate_exp_name("CQL-online", cfg.env.exp_name)
logger = None
Expand Down
2 changes: 1 addition & 1 deletion examples/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ collector:
multi_step: 0
init_random_frames: 1000
env_per_collector: 1
collector_device: cpu
device: cpu
max_frames_per_traj: 200

# logger
Expand Down
27 changes: 21 additions & 6 deletions examples/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.envs import (
CatTensors,
Compose,
DMControlEnv,
DoubleToFloat,
EnvCreator,
ParallelEnv,
RewardScaling,
TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import CQLLoss, SoftUpdate
Expand All @@ -32,8 +34,21 @@
# -----------------


def env_maker(task, frame_skip=1, device="cpu", from_pixels=False):
return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels)
def env_maker(cfg, device="cpu"):
lib = cfg.env.library
if lib in ("gym", "gymnasium"):
with set_gym_backend(lib):
return GymEnv(
cfg.env.name,
device=device,
)
elif lib == "dm_control":
env = DMControlEnv(cfg.env.name, cfg.env.task)
return TransformedEnv(
env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
)
else:
raise NotImplementedError(f"Unknown lib {lib}.")


def apply_env_transforms(env, reward_scaling=1.0):
Expand All @@ -51,7 +66,7 @@ def make_environment(cfg, num_envs=1):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
num_envs,
EnvCreator(lambda: env_maker(task=cfg.env.name)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -60,7 +75,7 @@ def make_environment(cfg, num_envs=1):
eval_env = TransformedEnv(
ParallelEnv(
num_envs,
EnvCreator(lambda: env_maker(task=cfg.env.name)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
),
train_env.transform.clone(),
)
Expand All @@ -80,7 +95,7 @@ def make_collector(cfg, train_env, actor_model_explore):
frames_per_batch=cfg.collector.frames_per_batch,
max_frames_per_traj=cfg.collector.max_frames_per_traj,
total_frames=cfg.collector.total_frames,
device=cfg.collector.collector_device,
device=cfg.collector.device,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down
2 changes: 1 addition & 1 deletion examples/ddpg/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
reset_at_each_iter: False
collector_device: cpu
device: cpu
env_per_collector: 1


Expand Down
Loading

0 comments on commit e353b20

Please sign in to comment.