Question: I can't release GPU memory while running nt.empirical_ntk_fn #146
I have followed the solutions in RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm #8506
and High GPU memory during empirical NTK calculation #100,
but neither works for me.
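For reference, those solutions boil down to configuring XLA's GPU allocator before JAX initializes the backend; this is roughly what I tried (the flag values are just examples):

import os
# These must be set before importing jax / before the first JAX GPU operation.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # don't grab most of the GPU memory up front
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'   # or cap the preallocated fraction
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'  # allocate/free on demand (slow)
import jax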
The following code reproduces the problem:
# XLA reads this flag when the GPU backend is first initialized, so set it
# before importing jax (or at least before the first JAX computation).
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.8

import gc
import numpy as np
import cupy as cp
import torchvision.datasets as datasets
from jax import random, jit, grad
import jax.numpy as jnp
from jax.example_libraries import optimizers
import neural_tangents as nt
from neural_tangents import stax
from tqdm import tqdm

mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
class_5_idx = np.where(mnist_trainset.targets.numpy() == 5)
class_3_idx = np.where(mnist_trainset.targets.numpy() == 3)
class_5 = mnist_trainset.data[class_5_idx]
class_3 = mnist_trainset.data[class_3_idx]

M = 200     # hidden width
W = H = 28  # image size
C = 1       # channels
P = 200     # number of training points
eta = 0.1   # learning rate

inputs = np.vstack((class_5[:P//2], class_3[:P//2]))
idx = np.arange(P)
np.random.shuffle(idx)
inputs = inputs[idx].reshape(-1, W, H, C).astype('float32')
ys = np.concatenate((np.ones(P//2), np.zeros(P//2))).reshape(-1, 1)
ys = ys[idx].astype('float32')
# kernel for the full network
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(10, (5, 5), (1, 1), 'SAME'), stax.Relu(),  # Conv(out_chan, filter_shape, strides, padding)
    stax.Flatten(),
    stax.Dense(M), stax.Relu(),
    stax.Dense(1)
)
# kernel for the first layer only (Conv -> Relu -> Flatten)
init_fn_a, apply_fn_a, kernel_fn_a = stax.serial(
    stax.Conv(10, (5, 5), (1, 1), 'SAME'), stax.Relu(),
    stax.Flatten()
)
output_shape, params = init_fn(random.PRNGKey(np.random.randint(1e6)), inputs.shape)
eNTK = nt.empirical_ntk_fn(apply_fn, vmap_axes=0, trace_axes=(), implementation=2)
opt_init, opt_update, get_params = optimizers.sgd(eta)
opt_state = opt_init(params)

_eNTK_a = nt.empirical_ntk_fn(apply_fn_a, vmap_axes=0, trace_axes=(), implementation=2)
eNTK_a = jit(_eNTK_a)
loss = jit(lambda params, x, y: 0.5 * jnp.mean((apply_fn(params, x) - y) ** 2))
grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))
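# Note: with trace_axes=() the eNTK keeps both output axes, and the first-layer
# sub-network ends in Flatten with 10 * 28 * 28 = 7840 outputs, so every 2x2
# block below has shape (2, 2, 7840, 7840) -- roughly 1 GB in float32 before
# the np.sum. I suspect this alone accounts for much of the memory pressure.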
nsteps = 1000
K_a = np.zeros((P, P))
for i in tqdm(range(nsteps)):
    opt_state = opt_update(i, grad_loss(opt_state, inputs, ys), opt_state)
    # Fill the first-layer eNTK in 2x2 blocks.
    for j in range(P//2):
        for k in range(P//2):
            a = np.sum(eNTK_a(inputs[2*j:2*(j+1)], inputs[2*k:2*(k+1)],
                              get_params(opt_state)[:3]),
                       axis=(2, 3))
            K_a[2*j:2*(j+1), 2*k:2*(k+1)] = a
            print(mempool.used_bytes())
            print(mempool.total_bytes())
            print(pinned_mempool.n_free_blocks())
            a = cp.array(a)
            print(mempool.used_bytes())
            print(mempool.total_bytes())
            print(pinned_mempool.n_free_blocks())
            del a
            print(mempool.used_bytes())
            print(mempool.total_bytes())
            print(pinned_mempool.n_free_blocks())
            mempool.free_all_blocks()
            pinned_mempool.free_all_blocks()
            gc.collect()
            break  # inspect a single block only
        break
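One thing I realized while debugging: cupy's memory pool only tracks cupy's own allocations, so memory held by JAX/XLA never shows up in mempool.used_bytes(), and freeing cupy's blocks cannot release it. Below is a rough sketch of what I am checking now (pynvml is an extra dependency; gpu_used_mb and apply_sum are my own helpers): measuring device memory through the driver instead, and using the fact that summing the eNTK over both output axes should equal the eNTK of the summed scalar output, which avoids ever materializing the (2, 2, 7840, 7840) blocks.

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def gpu_used_mb():
    # Device-wide usage as the driver sees it (includes XLA's pool,
    # unlike cupy's mempool counters).
    return pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20

# sum_{o1,o2} NTK[i, j, o1, o2] equals the NTK of g(x) = sum_o f_o(x), so the
# summed kernel can be computed from a scalar-output network directly.
apply_sum = lambda params, x: jnp.sum(apply_fn_a(params, x), axis=-1, keepdims=True)
eNTK_sum = jit(nt.empirical_ntk_fn(apply_sum, vmap_axes=0, trace_axes=(), implementation=2))

print(gpu_used_mb())
block = eNTK_sum(inputs[0:2], inputs[0:2], get_params(opt_state)[:3])  # shape (2, 2, 1, 1)
block = np.asarray(block)  # copy to host and drop the device reference
print(gpu_used_mb())       # XLA typically keeps freed buffers pooled for reuse, so this may not drop

Even with this, the driver-level number barely moves after del, which is what makes me think the memory is stuck inside XLA's pool rather than leaked by my code.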