forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_rb_distributed.py
142 lines (117 loc) · 4.17 KB
/
test_rb_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import sys
import time
import pytest
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter
# Number of attempts _construct_buffer makes to create the remote buffer.
RETRY_COUNT: int = 3
# Back-off delay (seconds) used by _construct_buffer between connection attempts.
RETRY_BACKOFF: int = 3
class ReplayBufferNode(RemoteTensorDictReplayBuffer):
    """Remote TensorDict replay buffer used as the RPC-hosted node in these tests.

    Wires a ``LazyMemmapStorage`` of ``capacity`` entries on the CPU device
    (``scratch_dir`` is forwarded to the storage) together with a
    ``RandomSampler`` and a ``RoundRobinWriter``. Sampled data is returned
    as-is: the collate function is the identity.

    Args:
        capacity: maximum number of entries held by the storage.
        scratch_dir: optional directory forwarded to ``LazyMemmapStorage``.
    """

    def __init__(self, capacity: int, scratch_dir=None):
        backing_storage = LazyMemmapStorage(
            max_size=capacity,
            scratch_dir=scratch_dir,
            device=torch.device("cpu"),
        )

        def _passthrough(batch):
            # Identity collate: hand the sampled structure back unchanged.
            return batch

        super().__init__(
            storage=backing_storage,
            sampler=RandomSampler(),
            writer=RoundRobinWriter(),
            collate_fn=_passthrough,
        )
def construct_buffer_test(rank, name, world_size):
    """Check that the trainer can create a replay buffer on the ``"BUFFER"`` worker.

    Only the ``"TRAINER"`` rank does any work; every other rank returns at once.
    The handle returned by the remote construction must be exactly a ``PyRRef``.
    """
    if name != "TRAINER":
        return
    remote_buffer = _construct_buffer("BUFFER")
    assert type(remote_buffer) is torch._C._distributed_rpc.PyRRef
def add_to_buffer_remotely_test(rank, name, world_size):
    """Check that a remote ``add`` on a fresh buffer returns the integer ``0``.

    Only the ``"TRAINER"`` rank does any work; every other rank returns at once.
    """
    if name != "TRAINER":
        return
    remote_buffer = _construct_buffer("BUFFER")
    write_result, _unused_td = _add_random_tensor_dict_to_buffer(remote_buffer)
    # The first write into a newly built buffer is expected to report 0.
    assert type(write_result) is int
    assert write_result == 0
def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, world_size):
    """Check that sampling a one-entry buffer returns the TensorDict that was written.

    Only the ``"TRAINER"`` rank does any work; every other rank returns at once.
    With exactly one entry stored, a sample of size 1 must contain that entry's
    ``"a"`` values.
    """
    if name != "TRAINER":
        return
    remote_buffer = _construct_buffer("BUFFER")
    _, written = _add_random_tensor_dict_to_buffer(remote_buffer)
    drawn = _sample_from_buffer(remote_buffer, 1)
    assert type(drawn) is type(written) is TensorDict
    assert (drawn["a"] == written["a"]).all()
def _run_worker(rank, name, world_size, func):
    """Run one rank end-to-end in the current process: init RPC, run ``func``, shut down.

    Keeping all three steps in one process guarantees that ``func`` runs on the
    worker that registered under ``name`` and that every rank calls ``shutdown``.
    """
    init_rpc(rank, name, world_size)
    try:
        func(rank, name, world_size)
    finally:
        # A graceful RPC shutdown blocks until all ranks reach it, so it must run
        # even when ``func`` raised; otherwise the peer rank would wait forever.
        shutdown()


@pytest.mark.skipif(
    sys.platform == "win32",
    reason="Distributed package support on Windows is a prototype feature and is subject to changes.",
)
@pytest.mark.parametrize("names", [["BUFFER", "TRAINER"]])
@pytest.mark.parametrize(
    "func",
    [
        construct_buffer_test,
        add_to_buffer_remotely_test,
        sample_from_buffer_remotely_returns_correct_tensordict_test,
    ],
)
def test_funcs(names, func):
    """Run ``func`` on every rank listed in ``names``, one dedicated process per rank.

    Rank ``i`` is registered under ``names[i]``. Each process initialises RPC,
    runs ``func(rank, name, world_size)`` and then shuts RPC down. The test fails
    if any rank exits with a non-zero code (e.g. a failed assertion in ``func``).
    """
    world_size = len(names)
    # A dedicated process per rank, rather than a Pool, pins each rank to one process.
    # With a Pool, the ``func`` task for a rank could run in a different worker than
    # the one that called ``init_rpc`` for it. The old ``apply_async(shutdown)`` also
    # reached only one arbitrary worker and was never awaited before the pool was
    # terminated.
    procs = [
        mp.Process(target=_run_worker, args=(rank, name, world_size, func))
        for rank, name in enumerate(names)
    ]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
    failed = {
        name: proc.exitcode for name, proc in zip(names, procs) if proc.exitcode != 0
    }
    assert not failed, f"worker(s) exited with non-zero code: {failed}"
def init_rpc(rank, name, world_size):
    """Initialise the TensorPipe RPC agent for this process.

    Registers the process as worker ``name`` with the given ``rank`` in a group of
    ``world_size`` workers, rendezvousing at ``tcp://localhost:10030`` with 16 worker
    threads.

    NOTE(review): ``MASTER_ADDR``/``MASTER_PORT`` are exported, but an explicit
    ``tcp://`` init_method is passed, so they look unused here. Confirm before removing.
    """
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "29500"})
    backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method="tcp://localhost:10030",
        num_worker_threads=16,
    )
    rpc.init_rpc(
        name,
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=backend_options,
    )
def shutdown():
    """Shut down this process's RPC agent (graceful, which is also the library default)."""
    rpc.shutdown(graceful=True)
def _construct_buffer(target):
    """Create a ``ReplayBufferNode(1000)`` on RPC worker ``target`` and return its RRef.

    ``rpc.remote`` is attempted up to ``RETRY_COUNT`` times. Each failure is logged,
    and the function sleeps ``RETRY_BACKOFF`` seconds before the next attempt.

    Raises:
        RuntimeError: if every attempt raised. The last underlying error is chained
            as ``__cause__``.
    """
    last_error = None
    for attempt in range(RETRY_COUNT):
        try:
            return rpc.remote(target, ReplayBufferNode, args=(1000,))
        except Exception as e:
            last_error = e
            torchrl_logger.info(f"Failed to connect: {e}")
            # Only back off when another attempt follows; sleeping after the final
            # failure would just delay the error.
            if attempt + 1 < RETRY_COUNT:
                time.sleep(RETRY_BACKOFF)
    raise RuntimeError("Unable to connect to replay buffer") from last_error
def _add_random_tensor_dict_to_buffer(buffer):
    """Write a random one-entry TensorDict into the remote buffer.

    The TensorDict has an empty batch size and a single key ``"a"`` holding one
    random integer in ``[0, 100)``. ``ReplayBufferNode.add`` is invoked
    synchronously on the RRef's owner.

    Returns:
        ``(add_result, written_tensordict)``.
    """
    payload = TensorDict({"a": torch.randint(100, (1,))}, [])
    owner = buffer.owner()
    add_result = rpc.rpc_sync(owner, ReplayBufferNode.add, args=(buffer, payload))
    return add_result, payload
def _sample_from_buffer(buffer, batch_size):
    """Synchronously call ``ReplayBufferNode.sample(batch_size)`` on the RRef's owner."""
    owner = buffer.owner()
    call_args = (buffer, batch_size)
    return rpc.rpc_sync(owner, ReplayBufferNode.sample, args=call_args)
if __name__ == "__main__":
    # Forward any command-line arguments unknown to this parser straight to pytest.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])