Description
Good day!
I'm trying to use the pprof memory profiler as described here:
https://jax.readthedocs.io/en/latest/device_memory_profiling.html
I'm training an infinite Myrtle NTK network on CIFAR, with the architecture taken from this Colab notebook:
https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/myrtle_kernel_with_neural_tangents.ipynb
import functools
import numpy as np

import jax.profiler
from neural_tangents import stax

def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.):
    # Number of conv blocks per stage for each supported depth.
    layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
    width = 1
    activation_fn = stax.Relu()
    layers = []
    conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME')
    layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][0]
    layers += [stax.AvgPool((2, 2), strides=(2, 2))]
    layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1]
    layers += [stax.AvgPool((2, 2), strides=(2, 2))]
    layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2]
    layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3
    layers += [stax.Flatten(), stax.Dense(10, W_std, b_std)]
    return stax.serial(*layers)
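For context, here is a simplified sketch of how the kernel computation itself looks (x_train and x_test are placeholders for the CIFAR arrays; nt.batch with batch_size=4 mirrors the Colab rather than my exact script):

import neural_tangents as nt

# MyrtleNetwork returns (init_fn, apply_fn, kernel_fn);
# kernel_fn computes the infinite-width NNGP/NTK kernels.
init_fn, apply_fn, kernel_fn = MyrtleNetwork(depth=10)

# Compute the kernel in batches of 4 to bound device memory.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=4, store_on_device=False)

# x_train / x_test: CIFAR images of shape (N, 32, 32, 3).
k_train_train = batched_kernel_fn(x_train, None, 'ntk')
k_test_train = batched_kernel_fn(x_test, x_train, 'ntk')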
from jax.lib import xla_bridge

# ... training kernel with batch size 4 ...

jax.profiler.save_device_memory_profile(output_fname + '.prof', 'gpu')
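As I understand the docs, save_device_memory_profile takes a snapshot of the buffers that are live on the device at the moment of the call, and JAX dispatches work asynchronously. A minimal sketch of taking the snapshot while the kernel computation should still be holding its buffers (batched_kernel_fn is from the sketch above; the filename is a placeholder):

# Run one batch and force the asynchronous dispatch to finish.
k = batched_kernel_fn(x_train[:64], None, 'ntk')
k.block_until_ready()
# Snapshot live device buffers right after the batch completes.
jax.profiler.save_device_memory_profile('memory_mid_training.prof', 'gpu')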
The problem is that when I check the profiler's output, the reported memory consumption looks very small:
flat flat% sum% cum cum%
806.39kB 87.80% 87.80% 806.39kB 87.80% backend_compile
112.01kB 12.20% 100% 112.01kB 12.20% _execute_compiled
0 0% 100% 918.40kB 100% <unknown>
0 0% 100% 112.01kB 12.20% <unknown>
0 0% 100% 28.81kB 3.14% <unknown>
0 0% 100% 22.65kB 2.47% <unknown>
0 0% 100% 22.33kB 2.43% <unknown>
0 0% 100% 20.62kB 2.25% <unknown>
0 0% 100% 6.48kB 0.71% <unknown>
0 0% 100% 6.48kB 0.71% <unknown>
0 0% 100% 9.59kB 1.04% PRNGKey
0 0% 100% 11.69kB 1.27% _call_with_frames_removed
0 0% 100% 11.69kB 1.27% _find_and_load
0 0% 100% 11.69kB 1.27% _find_and_load_unlocked
0 0% 100% 118.17kB 12.87% _flatten_batch_dimensions
0 0% 100% 118.17kB 12.87% _flatten_kernel
0 0% 100% 15.94kB 1.74% _gather
0 0% 100% 8.21kB 0.89% _index_to_gather
0 0% 100% 11.69kB 1.27% _load_unlocked
0 0% 100% 6.37kB 0.69% _normalize_index
0 0% 100% 21.34kB 2.32% _reduce_window_sum
0 0% 100% 120.30kB 13.10% _reshape
0 0% 100% 15.94kB 1.74% _rewriting_take
0 0% 100% 140.81kB 15.33% _scan
0 0% 100% 641.46kB 69.84% _xla_call_impl
0 0% 100% 806.39kB 87.80% _xla_callable_uncached
0 0% 100% 214.72kB 23.38% apply_fn
0 0% 100% 256.69kB 27.95% apply_fn_with_masking
0 0% 100% 256.69kB 27.95% apply_fun
0 0% 100% 276.95kB 30.16% apply_primitive
0 0% 100% 918.40kB 100% bind
0 0% 100% 276.95kB 30.16% bind_with_trace
0 0% 100% 5.25kB 0.57% broadcast
0 0% 100% 8.02kB 0.87% broadcast_in_dim
0 0% 100% 641.46kB 69.84% cache_miss
0 0% 100% 164.94kB 17.96% cached
0 0% 100% 641.46kB 69.84% call_bind
0 0% 100% 118.48kB 12.90% col_fn
0 0% 100% 806.39kB 87.80% compile
0 0% 100% 806.39kB 87.80% compile_or_get_cached
0 0% 100% 24.82kB 2.70% concatenate
0 0% 100% 69.81kB 7.60% conv_general_dilated
0 0% 100% 106.36kB 11.58% deferring_binary_op
0 0% 100% 7.08kB 0.77% dot_general
0 0% 100% 11.69kB 1.27% exec_module
0 0% 100% 118.48kB 12.90% f_pmapped
0 0% 100% 91.50kB 9.96% fn
0 0% 100% 806.39kB 87.80% from_xla_computation
0 0% 100% 6.17kB 0.67% full
0 0% 100% 7.73kB 0.84% gather
0 0% 100% 258.98kB 28.20% h
0 0% 100% 40.05kB 4.36% init_fun
0 0% 100% 641.45kB 69.84% memoized_fun
0 0% 100% 33.72kB 3.67% normal
0 0% 100% 33.72kB 3.67% ntk_init_fn
0 0% 100% 317.85kB 34.61% odeint
0 0% 100% 349.15kB 38.02% predict_fn
0 0% 100% 641.46kB 69.84% process_call
0 0% 100% 276.95kB 30.16% process_primitive
0 0% 100% 21.34kB 2.32% reduce_window
0 0% 100% 641.46kB 69.84% reraise_with_filtered_traceback
0 0% 100% 121.22kB 13.20% reshape
0 0% 100% 125.61kB 13.68% row_fn
0 0% 100% 9.59kB 1.04% seed_with_impl
0 0% 100% 258.98kB 28.20% serial_fn
0 0% 100% 258.98kB 28.20% serial_fn_x1
0 0% 100% 22.33kB 2.43% stack
0 0% 100% 7.08kB 0.77% tensordot
0 0% 100% 9.59kB 1.04% threefry_seed
0 0% 100% 906.71kB 98.73% train_kernel_network_with_report
0 0% 100% 28.81kB 3.14% tree_map
0 0% 100% 118.17kB 12.87% wrapped_fn
0 0% 100% 806.39kB 87.80% wrapper
0 0% 100% 164.94kB 17.96% xla_primitive_callable
The only way I've managed to obtain some data is by running

jax.profiler.start_trace('tensorboard')
# ... train ...
jax.profiler.stop_trace()

and sending the result to TensorBoard. There, only limited information is available, such as the total memory consumption graph (without the per-function breakdown of the listing above), and for this training run it shows about 2 GB of GPU memory used. But I cannot see which operations take so much memory.
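In case it's relevant, I can label regions of the trace with jax.profiler.TraceAnnotation so the corresponding events show up named in the TensorBoard profiler; a minimal sketch (the kernel batch here stands in for my actual training step):

jax.profiler.start_trace('tensorboard')
with jax.profiler.TraceAnnotation('kernel_batch'):
    k = batched_kernel_fn(x_train[:64], None, 'ntk')
    k.block_until_ready()  # make sure the traced work runs inside the annotation
jax.profiler.stop_trace()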
Why does pprof not see any of the internal memory usage, and report only a few kB used? How can I obtain a detailed memory profile for neural-tangents, like the pprof listing above?
Thank you!