Skip to content

Commit

Permalink
[BugFix] Fix submitit IP address/node name retrieval (#1672)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 3, 2023
1 parent ce8a1c1 commit 8ca7a39
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 17 deletions.
11 changes: 9 additions & 2 deletions examples/distributed/collectors/multi_nodes/delayed_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,22 @@
tcpport=tcp_port,
)
def main():
import gym
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.data import BoundedTensorSpec
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
device_str = "device" if num_workers == 1 else "devices"

def make_env():
    """Build a Pong GymEnv with the legacy ``gym`` backend forced.

    Returns:
        GymEnv: a freshly constructed ``ALE/Pong-v5`` environment.
    """
    # gymnasium breaks when using multiproc, so pin the backend to `gym`
    # for the duration of the environment construction.
    with set_gym_backend(gym):
        return GymEnv("ALE/Pong-v5")

collector = DistributedDataCollector(
[EnvCreator(lambda: GymEnv("ALE/Pong-v5"))] * num_jobs,
[EnvCreator(make_env)] * num_jobs,
policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
launcher="submitit_delayed",
frames_per_batch=frames_per_batch,
Expand Down
11 changes: 9 additions & 2 deletions examples/distributed/collectors/multi_nodes/delayed_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,22 @@
framework="rpc",
)
def main():
import gym
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.data import BoundedTensorSpec
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
device_str = "device" if num_workers == 1 else "devices"

def make_env():
    """Create the ``ALE/Pong-v5`` environment under the ``gym`` backend.

    gymnasium breaks when using multiproc, hence the explicit backend pin.
    """
    with set_gym_backend(gym):
        env = GymEnv("ALE/Pong-v5")
    return env

collector = RPCDataCollector(
[EnvCreator(lambda: GymEnv("ALE/Pong-v5"))] * num_jobs,
[EnvCreator(make_env)] * num_jobs,
policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
launcher="submitit_delayed",
frames_per_batch=frames_per_batch,
Expand Down
10 changes: 8 additions & 2 deletions examples/distributed/collectors/multi_nodes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
from argparse import ArgumentParser

import gym

import tqdm

from torchrl.collectors.collectors import (
Expand All @@ -14,7 +16,7 @@
)
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

parser = ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -90,7 +92,11 @@
f"device assignment not implemented for backend {args.backend}"
)

make_env = EnvCreator(lambda: GymEnv(args.env))
def gym_make():
    """Instantiate ``args.env`` as a GymEnv using the legacy ``gym`` backend."""
    # Force the `gym` backend while building the env; gymnasium breaks
    # under multiprocessing in the distributed collector workers.
    with set_gym_backend(gym):
        return GymEnv(args.env)

make_env = EnvCreator(gym_make)
action_spec = make_env().action_spec

collector = DistributedDataCollector(
Expand Down
10 changes: 8 additions & 2 deletions examples/distributed/collectors/multi_nodes/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
from argparse import ArgumentParser

import gym

import torch
import tqdm

Expand All @@ -15,7 +17,7 @@
)
from torchrl.collectors.distributed import RPCDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

parser = ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -79,7 +81,11 @@
else:
collector_kwargs = {device_str: "cpu", "storing_{device_str}": "cpu"}

make_env = EnvCreator(lambda: GymEnv(args.env))
def gym_make():
    """Return a GymEnv for ``args.env`` built with the ``gym`` backend pinned."""
    with set_gym_backend(gym):
        environment = GymEnv(args.env)
    return environment

make_env = EnvCreator(gym_make)
action_spec = make_env().action_spec

collector = RPCDataCollector(
Expand Down
10 changes: 8 additions & 2 deletions examples/distributed/collectors/multi_nodes/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
from argparse import ArgumentParser

import gym

import tqdm

from torchrl.collectors.collectors import (
Expand All @@ -14,7 +16,7 @@
)
from torchrl.collectors.distributed import DistributedSyncDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

parser = ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -85,7 +87,11 @@
f"device assignment not implemented for backend {args.backend}"
)

make_env = EnvCreator(lambda: GymEnv(args.env))
def gym_make():
    """Instantiate ``args.env`` as a GymEnv using the legacy ``gym`` backend."""
    # Pin the backend to `gym` during construction; gymnasium is known to
    # break when environments are built in multiprocessing workers.
    with set_gym_backend(gym):
        return GymEnv(args.env)

make_env = EnvCreator(gym_make)
action_spec = make_env().action_spec

collector = DistributedSyncDataCollector(
Expand Down
10 changes: 8 additions & 2 deletions examples/distributed/collectors/single_machine/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import time
from argparse import ArgumentParser

import gym

import torch
import tqdm

Expand All @@ -31,7 +33,7 @@
)
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

parser = ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -89,7 +91,11 @@

device_count = torch.cuda.device_count()

make_env = EnvCreator(lambda: GymEnv(args.env))
def gym_make():
    """Instantiate ``args.env`` as a GymEnv using the legacy ``gym`` backend."""
    # Backend pinned to `gym` while constructing: gymnasium breaks under
    # the multiproc workers spawned by the distributed collector.
    with set_gym_backend(gym):
        return GymEnv(args.env)

make_env = EnvCreator(gym_make)
if args.worker_parallelism == "collector" or num_workers == 1:
action_spec = make_env().action_spec
else:
Expand Down
10 changes: 8 additions & 2 deletions examples/distributed/collectors/single_machine/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import time
from argparse import ArgumentParser

import gym

import torch.cuda
import tqdm

from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.distributed import RPCDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

parser = ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -85,7 +87,11 @@
else:
collector_kwargs = {"device": "cpu", "storing_device": "cpu"}

make_env = EnvCreator(lambda: GymEnv(args.env))
def gym_make():
    """Instantiate ``args.env`` as a GymEnv using the legacy ``gym`` backend."""
    # gymnasium breaks when used from multiproc workers, so the backend is
    # explicitly set to `gym` for the duration of the construction.
    with set_gym_backend(gym):
        return GymEnv(args.env)

make_env = EnvCreator(gym_make)
if num_workers == 1:
action_spec = make_env().action_spec
else:
Expand Down
10 changes: 8 additions & 2 deletions examples/distributed/collectors/single_machine/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import time
from argparse import ArgumentParser

import gym

import torch
import tqdm

Expand All @@ -31,7 +33,7 @@
)
from torchrl.collectors.distributed import DistributedSyncDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

parser = ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -84,7 +86,11 @@

device_count = torch.cuda.device_count()

make_env = EnvCreator(lambda: GymEnv(args.env))
def gym_make():
    """Instantiate ``args.env`` as a GymEnv using the legacy ``gym`` backend."""
    # Force the `gym` backend while building the env; gymnasium breaks in
    # multiprocessing contexts used by the sync collector workers.
    with set_gym_backend(gym):
        return GymEnv(args.env)

make_env = EnvCreator(gym_make)
if args.worker_parallelism == "collector" or num_workers == 1:
action_spec = make_env().action_spec
else:
Expand Down
3 changes: 2 additions & 1 deletion torchrl/collectors/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def exec_fun():
time.sleep(0.5)
continue
print(f"node: {node}")
cmd = f"sinfo -n {node} -O nodeaddr | tail -1"
# by default, sinfo will truncate the node name at char 20, we increase this to 200
cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1"
rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip()
print(f"IP: {rank0_ip}")
world_size = self.num_jobs + 1
Expand Down

0 comments on commit 8ca7a39

Please sign in to comment.