diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index c86ba9a543c..ac4cd71ddad 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -80,8 +80,10 @@ class BraxWrapper(_EnvWrapper): Examples: >>> import brax.envs >>> from torchrl.envs import BraxWrapper + >>> import torch + >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> base_env = brax.envs.get_environment("ant") - >>> env = BraxWrapper(base_env) + >>> env = BraxWrapper(base_env, device=device) >>> env.set_seed(0) >>> td = env.reset() >>> td["action"] = env.action_spec.rand() @@ -111,7 +113,9 @@ class BraxWrapper(_EnvWrapper): and report the execution time for a short rollout: Examples: + >>> import torch >>> from torch.utils.benchmark import Timer + >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> for batch_size in [4, 16, 128]: ... timer = Timer(''' ... env.rollout(100) @@ -119,7 +123,7 @@ class BraxWrapper(_EnvWrapper): ... setup=f''' ... import brax.envs ... from torchrl.envs import BraxWrapper - ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}]) + ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}") ... env.set_seed(0) ... env.rollout(2) ... ''') @@ -459,7 +463,9 @@ class BraxEnv(BraxWrapper): Examples: >>> from torchrl.envs import BraxEnv - >>> env = BraxEnv("ant") + >>> import torch + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + >>> env = BraxEnv("ant", device=device) >>> env.set_seed(0) >>> td = env.reset() >>> td["action"] = env.action_spec.rand() @@ -489,13 +495,16 @@ class BraxEnv(BraxWrapper): and report the execution time for a short rollout: Examples: + >>> import torch + >>> from torch.utils.benchmark import Timer + >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> for batch_size in [4, 16, 128]: ... timer = Timer(''' ... env.rollout(100) ... ''', ... setup=f''' ... from torchrl.envs import BraxEnv - ... env = BraxEnv("ant", batch_size=[{batch_size}]) + ... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}") ... env.set_seed(0) ... env.rollout(2) ... ''')