Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow usage of a different device on main and sub-envs in ParallelEnv and SerialEnv #1626

Merged
merged 71 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
1f539dd
init
vmoens Oct 6, 2023
d4c16e1
amend
vmoens Oct 6, 2023
d2321aa
amend
vmoens Oct 6, 2023
565115a
amend
vmoens Oct 6, 2023
3c46136
amend
vmoens Oct 6, 2023
78cfa41
amend
vmoens Oct 6, 2023
a6bd8eb
amend
vmoens Oct 6, 2023
04d4ae7
amend
vmoens Oct 6, 2023
3e31963
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 10, 2023
f1b0ea4
tensordict_
vmoens Oct 10, 2023
16b3538
amend rollout logic
vmoens Oct 10, 2023
7ad0864
amend
vmoens Oct 10, 2023
bcac398
amend
vmoens Oct 10, 2023
428f8ee
inference
vmoens Oct 10, 2023
6fbd0bd
cpu -> cuda
vmoens Oct 10, 2023
02db623
checks
vmoens Oct 10, 2023
08a8f47
using pipe instead of event
vmoens Oct 10, 2023
45e64f7
amend
vmoens Oct 10, 2023
7dd4821
amend
vmoens Oct 10, 2023
e1a2206
rm cuda event
vmoens Oct 10, 2023
dc2caab
amend
vmoens Oct 10, 2023
01ffbf9
amend
vmoens Oct 10, 2023
ac76ec3
amend
vmoens Oct 10, 2023
ceab010
amend
vmoens Oct 10, 2023
f0327c9
amend
vmoens Oct 10, 2023
518b3d1
amend
vmoens Oct 10, 2023
47dd93b
amend
vmoens Oct 10, 2023
354fb6f
amend
vmoens Oct 10, 2023
53d5f9a
amend
vmoens Oct 10, 2023
78c00e8
amend
vmoens Oct 10, 2023
f63480e
amend
vmoens Oct 10, 2023
9a3631f
amend
vmoens Oct 10, 2023
2ceb438
amend
vmoens Oct 10, 2023
6ecebda
amend
vmoens Oct 10, 2023
5c613c3
amend
vmoens Oct 10, 2023
9f97e58
amend
vmoens Oct 10, 2023
9cbcbb0
amend
vmoens Oct 10, 2023
72c4163
amend
vmoens Oct 10, 2023
6f4c374
amend
vmoens Oct 10, 2023
3657b41
empty
vmoens Oct 11, 2023
9206b93
amend
vmoens Oct 11, 2023
897123f
Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor
vmoens Nov 10, 2023
2bac78c
amend
vmoens Nov 10, 2023
51a8856
amend
vmoens Nov 10, 2023
7a300f5
amend
vmoens Nov 10, 2023
b25921d
amend
vmoens Nov 10, 2023
d42348d
amend
vmoens Nov 10, 2023
f3421aa
amend
vmoens Nov 10, 2023
939ece4
amend
vmoens Nov 10, 2023
082ba9a
amend
vmoens Nov 10, 2023
4fd670f
amend
vmoens Nov 10, 2023
2a773f3
amend
vmoens Nov 10, 2023
e7cb5dd
Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor
vmoens Nov 27, 2023
ea88bb4
Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor
vmoens Nov 27, 2023
7c04f62
amend
vmoens Nov 27, 2023
b422775
amend
vmoens Nov 27, 2023
ed0287d
amend
vmoens Nov 27, 2023
05b0c08
amend
vmoens Nov 27, 2023
33de0fc
amend
vmoens Nov 28, 2023
492a884
amend
vmoens Nov 28, 2023
ff4799d
amend
vmoens Nov 28, 2023
f25b957
amend
vmoens Nov 28, 2023
8899dbd
amend
vmoens Nov 28, 2023
65c9deb
amend
vmoens Nov 28, 2023
77c2d6b
amend
vmoens Nov 28, 2023
7928744
amend
vmoens Nov 28, 2023
e7fda36
amend
vmoens Nov 29, 2023
fb9a03a
amend
vmoens Nov 29, 2023
d73ca22
amend
vmoens Nov 29, 2023
c4d4c6b
amend
vmoens Nov 29, 2023
b2840b0
doc
vmoens Nov 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
cpu -> cuda
  • Loading branch information
vmoens committed Oct 10, 2023
commit 6fbd0bd78116d90afe1ad3c932d89eb95e59e8ce
31 changes: 9 additions & 22 deletions benchmarks/ecosystem/gym_env_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ def make(envname=envname, gym_backend=gym_backend):
log.flush()

# regular parallel env
for device in (
"cuda:0",
"cpu",
):
for device in ("cpu", "cuda:0"):

def make(envname=envname, gym_backend=gym_backend, device=device):
with set_gym_backend(gym_backend):
Expand All @@ -88,7 +85,9 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
t0 = time.time()
data = None
for _ in range(100):
data = penv.rollout(100, break_when_any_done=False, out=data)
data = penv.rollout(
100, break_when_any_done=False, out=data
)
pbar.update(100 * num_workers)
log.write(
f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
Expand All @@ -98,7 +97,7 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
timeit.print()
del penv

for device in ("cuda:0", "cpu"):
for device in ("cpu", "cuda:0"):

def make(envname=envname, gym_backend=gym_backend, device=device):
with set_gym_backend(gym_backend):
Expand Down Expand Up @@ -131,10 +130,7 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in ("cpu", "cuda:0"):
# gym parallel env
def make_env(
envname=envname,
Expand Down Expand Up @@ -201,10 +197,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in ("cpu", "cuda:0"):
# async collector
# + gym async env
def make_env(
Expand Down Expand Up @@ -248,10 +241,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in ("cpu", "cuda:0"):
# sync collector
# + torchrl parallel env
def make_env(
Expand Down Expand Up @@ -288,10 +278,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in ("cpu", "cuda:0"):
# sync collector
# + gym async env
def make_env(
Expand Down
6 changes: 2 additions & 4 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,7 +1650,7 @@ def rollout(
break_when_any_done: bool = True,
return_contiguous: bool = True,
tensordict: Optional[TensorDictBase] = None,
out = None,
out=None,
):
"""Executes a rollout in the environment.

Expand Down Expand Up @@ -1787,9 +1787,7 @@ def rollout(
raise RuntimeError("tensordict must be provided when auto_reset is False")
if policy is None:

def policy(td):
self.rand_action(td)
return td
policy = self.rand_action

kwargs = {
"tensordict": tensordict,
Expand Down
Loading