[BugFix] brax examples (pytorch#2318)
Jendker authored Jul 24, 2024
1 parent c6ef080 commit 8a74642
Showing 1 changed file with 13 additions and 4 deletions.
torchrl/envs/libs/brax.py
@@ -80,8 +80,10 @@ class BraxWrapper(_EnvWrapper):
     Examples:
         >>> import brax.envs
         >>> from torchrl.envs import BraxWrapper
+        >>> import torch
+        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
         >>> base_env = brax.envs.get_environment("ant")
-        >>> env = BraxWrapper(base_env)
+        >>> env = BraxWrapper(base_env, device=device)
         >>> env.set_seed(0)
         >>> td = env.reset()
         >>> td["action"] = env.action_spec.rand()
@@ -111,15 +113,17 @@ class BraxWrapper(_EnvWrapper):
     and report the execution time for a short rollout:
 
     Examples:
+        >>> import torch
         >>> from torch.utils.benchmark import Timer
+        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
         >>> for batch_size in [4, 16, 128]:
         ...     timer = Timer('''
         ... env.rollout(100)
         ... ''',
         ...     setup=f'''
         ... import brax.envs
         ... from torchrl.envs import BraxWrapper
-        ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}])
+        ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}")
         ... env.set_seed(0)
         ... env.rollout(2)
         ... ''')
@@ -459,7 +463,9 @@ class BraxEnv(BraxWrapper):
 
     Examples:
         >>> from torchrl.envs import BraxEnv
-        >>> env = BraxEnv("ant")
+        >>> import torch
+        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+        >>> env = BraxEnv("ant", device=device)
         >>> env.set_seed(0)
         >>> td = env.reset()
         >>> td["action"] = env.action_spec.rand()
@@ -489,13 +495,16 @@ class BraxEnv(BraxWrapper):
     and report the execution time for a short rollout:
 
     Examples:
+        >>> import torch
+        >>> from torch.utils.benchmark import Timer
+        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
         >>> for batch_size in [4, 16, 128]:
         ...     timer = Timer('''
         ... env.rollout(100)
         ... ''',
         ...     setup=f'''
         ... from torchrl.envs import BraxEnv
-        ... env = BraxEnv("ant", batch_size=[{batch_size}])
+        ... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}")
         ... env.set_seed(0)
         ... env.rollout(2)
         ... ''')
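For reference, a minimal, self-contained sketch of the usage pattern the fixed examples converge on, combining the device handling and batching from the hunks above (assuming brax and torchrl are installed; the "ant" task, the batch size of 4, and the rollout length of 10 are arbitrary illustrative choices):

    >>> import torch
    >>> from torchrl.envs import BraxEnv
    >>> device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
    >>> env = BraxEnv("ant", batch_size=[4], device=device)
    >>> env.set_seed(0)
    >>> td = env.rollout(10)  # TensorDict batched as [4, 10]: env batch dims first, then time

Threading a single device value through construction, as the updated docstrings do, keeps the rollout outputs on that device.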
