forked from zju3dv/EasyVolcap
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpin_memory_release_tests.py
117 lines (98 loc) · 3.87 KB
/
pin_memory_release_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gc
import torch
from torch import nn
from torch.nn import functional as F
from easyvolcap.utils.test_utils import my_tests, assert_allclose, grad_check
from easyvolcap.utils.console_utils import *
def test_pinning_parameters():
    """Grow a large tree of parameters in place, then round-trip every one through CUDA.

    A ``model`` with ``sampler`` / ``network`` / ``renderer`` children is built, where
    ``network.tfgs.forest`` is an N_LEVELS x N_NODES grid of ``nn.ParameterDict`` objects,
    each holding N_PARAMS parameters that start with shape (0, 4, 2**19). ``sampler`` and
    ``renderer`` only keep ``network`` inside a plain list, so nn.Module does not register
    it as their child.

    Every parameter then gets one row appended via ``torch.cat`` from a shared random
    source of shape (1, 4, 2**19, N_LEVELS, N_NODES, N_PARAMS) -- roughly 7.5 GB if the
    default dtype is float32 -- after which the source is deleted and ``gc.collect()`` runs.
    Finally each ``param.data`` is replaced by a non-blocking CUDA -> CPU round trip and the
    function stops in ``breakpoint()`` so host memory can be inspected by hand.

    Variants explored earlier (not kept as code): assigning ``param.data`` directly or via
    ``param.set_``, rebuilding ``nn.Parameter``, ``pin_memory()`` instead of the round
    trip, and explicit ``torch.cuda.empty_cache()`` / ``host_empty_cache()`` /
    ``model.to('cpu', non_blocking=True)`` calls.

    NOTE(review): nothing synchronizes the CUDA stream before ``breakpoint()``, so the
    non-blocking copies may still be in flight when the tensors are inspected.
    """
    N_LEVELS = 1
    N_NODES = 100
    N_PARAMS = 9

    # levels -> nodes -> {"0", ..., str(N_PARAMS - 1)}: empty (0, 4, 2**19) parameters
    tfgs = nn.Module()
    tfgs.forest = nn.ParameterList([
        nn.ParameterList([
            nn.ParameterDict({str(key): torch.empty(0, 4, 2**19) for key in range(N_PARAMS)})
            for _ in range(N_NODES)
        ])
        for _ in tqdm(range(N_LEVELS))
    ])

    network = nn.Module()
    network.tfgs = tfgs

    # Plain lists are stored as ordinary attributes, not as registered submodules
    sampler = nn.Module()
    sampler.unregistered = [network]
    renderer = nn.Module()
    renderer.unregistered = [network]

    model = nn.Module()
    model.sampler = sampler
    model.network = network
    model.renderer = renderer

    source = torch.rand(1, 4, 2**19, N_LEVELS, N_NODES, N_PARAMS)
    total = len(list(model.named_parameters()))

    # Phase 1: append one row per parameter, sliced from the shared source tensor
    progress = tqdm(total=total)
    with torch.no_grad():
        for lvl in range(N_LEVELS):
            for node in range(N_NODES):
                bucket = model.network.tfgs.forest[lvl][node]
                for key in range(N_PARAMS):
                    param = getattr(bucket, str(key))
                    param.data = torch.cat([param.data, source[:, :, :, lvl, node, key].contiguous()], dim=0)
                    progress.update()
    progress.close()

    # Drop the only external reference to the source before measuring
    del source
    gc.collect()

    # Phase 2: bounce every parameter's storage to the GPU and back without blocking
    progress = tqdm(total=total)
    for param in model.parameters():
        param.data = param.data.to('cuda', non_blocking=True).to('cpu', non_blocking=True)
        progress.update()
    progress.close()

    breakpoint()
def test_minimal_pinning():
    """Reduced repro: grow an empty tensor by concatenation, free the source, then pin it.

    ``module.copied`` starts as an empty tensor of shape (0, 2**16, 4, N_LEVELS, N_NODES,
    N_PARAMS) stored as a plain attribute (not a parameter or buffer). Its ``.data`` is
    replaced by the concatenation with a random (1, ...) tensor -- on the order of 8.5 GB
    if the default dtype is float32. The random source is then deleted and ``gc.collect()``
    runs, after which the storage is moved into pinned memory with ``pin_memory()``.
    Execution stops in ``breakpoint()`` for manual inspection of host memory.
    """
    N_LEVELS = 9
    N_NODES = 100
    N_PARAMS = 9
    tail_shape = (2**16, 4, N_LEVELS, N_NODES, N_PARAMS)

    data = torch.rand(1, *tail_shape)

    module = nn.Module()
    module.copied = torch.empty(0, *tail_shape)  # plain tensor attribute, not registered
    module.copied.data = torch.cat([module.copied.data, data], dim=0)

    # Release the source so only the concatenated copy remains referenced
    del data
    gc.collect()

    # pin_memory() returns a pinned copy; rebinding .data swaps the tensor's storage to it
    module.copied.data = module.copied.data.pin_memory()

    breakpoint()
def test_pin_memory_release_behavior():
    """Repeatedly reallocate a slowly growing pinned host buffer, flushing caches first.

    For each of REPEAT iterations, inside ``torch.cuda.stream`` on a dedicated side stream,
    the loop runs ``gc.collect()``, ``torch.cuda.empty_cache()`` and ``host_empty_cache()``
    (from ``easyvolcap.utils.host_utils``; presumably flushes the host-side caching
    allocator -- confirm). It then allocates an uninitialized pinned CPU tensor of shape
    (N, N, N + i) and stores it under the ``edf`` key, replacing the previous one.

    The new buffer is allocated while ``edf`` still references the previous buffer, so two
    buffers are alive at once during each allocation. Each buffer is about 4 GiB at
    float32.

    Variants explored earlier (not kept as code): ``torch.xpu.empty_cache()``, filling the
    buffer from a CUDA tensor with ``copy_`` or ``.to('cpu', non_blocking=True)``, and
    ``torch.cuda.synchronize()`` / ``gc.collect()`` / ``breakpoint()`` after allocation.
    """
    REPEAT = 128
    N = 1024
    holder = dotdict()
    side_stream = torch.cuda.Stream()
    from easyvolcap.utils.host_utils import host_empty_cache

    for i in tqdm(range(REPEAT)):
        depth = N + i  # grow the last dimension by one each round
        with torch.cuda.stream(side_stream):
            # Flush Python garbage, the CUDA caching allocator, then the host-side cache
            gc.collect()
            torch.cuda.empty_cache()
            host_empty_cache()
            holder.edf = torch.empty(N, N, depth, device='cpu', pin_memory=True)
if __name__ == '__main__':
    # Pass this module's namespace to my_tests (from easyvolcap.utils.test_utils);
    # presumably it discovers and runs the test_* functions defined above -- confirm.
    module_namespace = globals()
    my_tests(module_namespace)