Skip to content

Commit

Permalink
[BugFix,Refactor] Dreamer refactor (pytorch#1918)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
3 people authored Apr 23, 2024
1 parent fcbc6ea commit bfadce9
Show file tree
Hide file tree
Showing 19 changed files with 1,252 additions and 902 deletions.
48 changes: 22 additions & 26 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
# logger.record_video=True \
# logger.record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=200 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
model_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120 \
rssm_hidden_dim=17
collector.total_frames=200 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=4 \
optimization.optim_steps_per_batch=1 \
logger.video=True \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
Expand Down Expand Up @@ -223,19 +221,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq

# With single envs
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=200 \
num_workers=2 \
env_per_collector=1 \
collector_device=cuda:0 \
model_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120 \
rssm_hidden_dim=17
collector.total_frames=200 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.backend=csv \
logger.video=True \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ Domain-specific

ModelBasedEnvBase
model_based.dreamer.DreamerEnv
model_based.dreamer.DreamerDecoder


Libraries
Expand Down
105 changes: 66 additions & 39 deletions sota-implementations/dreamer/config.yaml
Original file line number Diff line number Diff line change
@@ -1,39 +1,66 @@
env_name: cheetah
env_task: run
env_library: dm_control
catframes: 1
async_collection: True
record_video: 0
frame_skip: 2
batch_size: 50
batch_length: 50
total_frames: 5000000
world_model_lr: 6e-4
actor_value_lr: 8e-5
from_pixels: True
# we want 50 frames / traj in the replay buffer. Given the frame_skip=2 this makes each traj 100 steps long
env_per_collector: 8
num_workers: 8
collector_device: cuda:1
model_device: cuda:0
frames_per_batch: 800
optim_steps_per_batch: 80
record_interval: 30
max_frames_per_traj: 1000
record_frames: 1000
batch_transform: 1
state_dim: 30
rssm_hidden_dim: 200
grad_clip: 100
grayscale: False
image_size : 64
buffer_size: 20000
init_env_steps: 1000
init_random_frames: 5000
logger: csv
offline_logging: False
project_name: torchrl_example_dreamer
normalize_rewards_online: True
normalize_rewards_online_scale: 5.0
normalize_rewards_online_decay: 0.99999
reward_scaling: 1.0
env:
name: cheetah
task: run
seed: 0
backend: dm_control
frame_skip: 2
from_pixels: True
grayscale: False
image_size : 64
horizon: 500
n_parallel_envs: 8
device:
_target_: dreamer_utils._default_device
device: null

collector:
total_frames: 5_000_000
init_random_frames: 3000
frames_per_batch: 1000
device:
_target_: dreamer_utils._default_device
device: null

optimization:
train_every: 1000
grad_clip: 100

world_model_lr: 6e-4
actor_lr: 8e-5
value_lr: 8e-5
kl_scale: 1.0
free_nats: 3.0
optim_steps_per_batch: 80
gamma: 0.99
lmbda: 0.95
imagination_horizon: 15
compile: False
compile_backend: inductor
use_autocast: True

networks:
exploration_noise: 0.3
device:
_target_: dreamer_utils._default_device
device: null
state_dim: 30
rssm_hidden_dim: 200
hidden_dim: 400
activation: "elu"


replay_buffer:
batch_size: 2500
buffer_size: 1000000
batch_length: 50
scratch_dir: null

logger:
backend: wandb
project: dreamer-v1
exp_name: ${env.name}-${env.task}-${env.seed}
mode: online
# eval interval, in collection counts
eval_iter: 10
eval_rollout_steps: 500
video: False
Loading

0 comments on commit bfadce9

Please sign in to comment.