bench_vmap.py
import os
# uncomment to force CPU-only execution (hide any GPU from JAX)
# os.environ['CUDA_VISIBLE_DEVICES'] = ''
import time
import jax, jax.numpy as np
import vbjax as vb

def net(x, p):
    # network of MPR neural mass nodes with uniform all-to-all coupling
    r, v = x
    k, _, mpr_p = p
    c = k*r.sum(), k*v.sum()
    return vb.mpr_dfun(x, c, mpr_p)

def noise(_, p):
    # additive noise with amplitude sigma
    _, sigma, _ = p
    return sigma

_, loop = vb.make_sde(0.01, net, noise)  # SDE integration loop, dt=0.01
n_nodes = 164
rv0 = vb.randn(2, n_nodes)               # initial state (r, v) per node
zs = vb.randn(1000, *rv0.shape)          # noise samples for 1000 time steps

def run(pars, mpr_p=vb.mpr_default_theta):
    k, sig, eta = pars                   # explored pars
    p = k, sig, mpr_p._replace(eta=eta)  # set mpr
    xs = loop(rv0, zs, p)                # run sim
    std = xs[400:, 0].std()              # eval metric
    return std

run_batches = jax.jit(jax.vmap(run, in_axes=1))  # batch over parameter columns

def bench_cpu():
    # pmap over host devices, vmap over each device's parameter columns
    run_batches_cores = jax.pmap(jax.vmap(run, in_axes=1), in_axes=1)
    for cores in [8]*10:  # [2, 4, 6, 8, 16]:
        for n in [4]:  # [2, 4, 8, 16]:
            log_ks, etas = np.mgrid[-9.0:0.0:1j*n, -5.0:-6.0:36j]
            pars = np.c_[np.exp(log_ks.ravel()), np.ones(log_ks.size)*0.2, etas.ravel()].T.copy()
            pars = pars.reshape((3, cores, -1))
            tic = time.time()
            for i in range(50):
                result = run_batches_cores(pars)
                result.block_until_ready()
            toc = time.time()
            iters = 50*log_ks.size*zs.shape[0]  # renamed to avoid shadowing builtin iter
            print(f'{cores} {n} {iters/1e6/(toc-tic):0.2f} Miter/s')
        print()

def bench_gpu():
    for n in [32]*20:  # [2, 4, 8, 16, 32, 48, 64]:
        log_ks, etas = np.mgrid[-9.0:0.0:1j*n, -5.0:-6.0:32j]
        pars = np.c_[np.exp(log_ks.ravel()), np.ones(log_ks.size)*0.2, etas.ravel()].T.copy()
        tic = time.time()
        for i in range(50):
            result = run_batches(pars)
            result.block_until_ready()
        toc = time.time()
        iters = 50*log_ks.size*zs.shape[0]
        print(f'{n} {iters/1e6/(toc-tic):0.2f} Miter/s')

if __name__ == '__main__':
    bench_cpu()
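
A minimal usage sketch (not part of the benchmark) may help clarify the expected input layout: run_batches maps run over axis 1 of a (3, N) array whose rows are coupling k, noise sigma, and excitability eta, returning one metric per column. The grid values below are illustrative only, and the sketch assumes the definitions above are already in scope.

ks   = np.exp(np.linspace(-9.0, 0.0, 8))  # 8 coupling values
sigs = np.full(8, 0.2)                    # fixed noise amplitude
etas = np.full(8, -5.0)                   # fixed excitability
pars = np.stack([ks, sigs, etas])         # shape (3, 8); vmap batches over axis 1
stds = run_batches(pars)                  # one std metric per parameter column
print(stds.shape)                         # (8,)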
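
Note that bench_cpu relies on jax.pmap seeing `cores` separate devices, while JAX by default exposes the host CPU as a single device. One way to get multiple virtual CPU devices is the XLA_FLAGS host-platform option, which must be set before jax is first imported; the device count of 8 here is an assumption matching the hard-coded [8]*10 above.

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # before importing jax
import jax
print(jax.device_count())  # should report 8 CPU devices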