Memory profiling with JAX pprof #147

Open
@celidos

Description

Good day!

I'm trying to use the pprof memory profiler as described here:
https://jax.readthedocs.io/en/latest/device_memory_profiling.html

I'm trying to train the infinite-width Myrtle NTK network on CIFAR, with the architecture taken from this Colab notebook:
https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/myrtle_kernel_with_neural_tangents.ipynb

import functools

import jax.profiler
import numpy as np
from neural_tangents import stax

def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.):
  layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
  width = 1
  activation_fn = stax.Relu()
  layers = []
  conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME')
  
  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][0]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))]
  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))]
  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3

  layers += [stax.Flatten(), stax.Dense(10, W_std, b_std)]

  return stax.serial(*layers)

from jax.lib import xla_bridge

... training kernel with batch size 4...

jax.profiler.save_device_memory_profile(output_fname + '.prof', 'gpu')

The problem is that when I check the profiler output, the reported memory consumption looks very small:

      flat  flat%   sum%        cum   cum%
  806.39kB 87.80% 87.80%   806.39kB 87.80%  backend_compile
  112.01kB 12.20%   100%   112.01kB 12.20%  _execute_compiled
         0     0%   100%   918.40kB   100%  <unknown>
         0     0%   100%   112.01kB 12.20%  <unknown>
         0     0%   100%    28.81kB  3.14%  <unknown>
         0     0%   100%    22.65kB  2.47%  <unknown>
         0     0%   100%    22.33kB  2.43%  <unknown>
         0     0%   100%    20.62kB  2.25%  <unknown>
         0     0%   100%     6.48kB  0.71%  <unknown>
         0     0%   100%     6.48kB  0.71%  <unknown>
         0     0%   100%     9.59kB  1.04%  PRNGKey
         0     0%   100%    11.69kB  1.27%  _call_with_frames_removed
         0     0%   100%    11.69kB  1.27%  _find_and_load
         0     0%   100%    11.69kB  1.27%  _find_and_load_unlocked
         0     0%   100%   118.17kB 12.87%  _flatten_batch_dimensions
         0     0%   100%   118.17kB 12.87%  _flatten_kernel
         0     0%   100%    15.94kB  1.74%  _gather
         0     0%   100%     8.21kB  0.89%  _index_to_gather
         0     0%   100%    11.69kB  1.27%  _load_unlocked
         0     0%   100%     6.37kB  0.69%  _normalize_index
         0     0%   100%    21.34kB  2.32%  _reduce_window_sum
         0     0%   100%   120.30kB 13.10%  _reshape
         0     0%   100%    15.94kB  1.74%  _rewriting_take
         0     0%   100%   140.81kB 15.33%  _scan
         0     0%   100%   641.46kB 69.84%  _xla_call_impl
         0     0%   100%   806.39kB 87.80%  _xla_callable_uncached
         0     0%   100%   214.72kB 23.38%  apply_fn
         0     0%   100%   256.69kB 27.95%  apply_fn_with_masking
         0     0%   100%   256.69kB 27.95%  apply_fun
         0     0%   100%   276.95kB 30.16%  apply_primitive
         0     0%   100%   918.40kB   100%  bind
         0     0%   100%   276.95kB 30.16%  bind_with_trace
         0     0%   100%     5.25kB  0.57%  broadcast
         0     0%   100%     8.02kB  0.87%  broadcast_in_dim
         0     0%   100%   641.46kB 69.84%  cache_miss
         0     0%   100%   164.94kB 17.96%  cached
         0     0%   100%   641.46kB 69.84%  call_bind
         0     0%   100%   118.48kB 12.90%  col_fn
         0     0%   100%   806.39kB 87.80%  compile
         0     0%   100%   806.39kB 87.80%  compile_or_get_cached
         0     0%   100%    24.82kB  2.70%  concatenate
         0     0%   100%    69.81kB  7.60%  conv_general_dilated
         0     0%   100%   106.36kB 11.58%  deferring_binary_op
         0     0%   100%     7.08kB  0.77%  dot_general
         0     0%   100%    11.69kB  1.27%  exec_module
         0     0%   100%   118.48kB 12.90%  f_pmapped
         0     0%   100%    91.50kB  9.96%  fn
         0     0%   100%   806.39kB 87.80%  from_xla_computation
         0     0%   100%     6.17kB  0.67%  full
         0     0%   100%     7.73kB  0.84%  gather
         0     0%   100%   258.98kB 28.20%  h
         0     0%   100%    40.05kB  4.36%  init_fun
         0     0%   100%   641.45kB 69.84%  memoized_fun
         0     0%   100%    33.72kB  3.67%  normal
         0     0%   100%    33.72kB  3.67%  ntk_init_fn
         0     0%   100%   317.85kB 34.61%  odeint
         0     0%   100%   349.15kB 38.02%  predict_fn
         0     0%   100%   641.46kB 69.84%  process_call
         0     0%   100%   276.95kB 30.16%  process_primitive
         0     0%   100%    21.34kB  2.32%  reduce_window
         0     0%   100%   641.46kB 69.84%  reraise_with_filtered_traceback
         0     0%   100%   121.22kB 13.20%  reshape
         0     0%   100%   125.61kB 13.68%  row_fn
         0     0%   100%     9.59kB  1.04%  seed_with_impl
         0     0%   100%   258.98kB 28.20%  serial_fn
         0     0%   100%   258.98kB 28.20%  serial_fn_x1
         0     0%   100%    22.33kB  2.43%  stack
         0     0%   100%     7.08kB  0.77%  tensordot
         0     0%   100%     9.59kB  1.04%  threefry_seed
         0     0%   100%   906.71kB 98.73%  train_kernel_network_with_report
         0     0%   100%    28.81kB  3.14%  tree_map
         0     0%   100%   118.17kB 12.87%  wrapped_fn
         0     0%   100%   806.39kB 87.80%  wrapper
         0     0%   100%   164.94kB 17.96%  xla_primitive_callable

The only way I've managed to obtain some data is by using

jax.profiler.start_trace('tensorboard')

... train ...

jax.profiler.stop_trace()

and sending this information to TensorBoard. TensorBoard shows only limited information, such as a graph of total memory consumption (without a per-function breakdown like in the listing above); for this training run it shows about 2 GB of GPU memory in use. It does not show which operations take so much memory.
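To put the two measurements side by side, here is a small sketch of the arithmetic. The two flat values are copied from the pprof listing above; the 2 GB figure is the approximate value read off the TensorBoard graph. Treating kB as 1024 bytes and GB as 1024^3 bytes is an assumption about the units both tools use.

```python
# Flat (self) allocations from the pprof listing above, in kB.
flat_kb = {
    "backend_compile": 806.39,
    "_execute_compiled": 112.01,
}

# Sum of the flat column; this should match the 100% cumulative row (918.40kB).
pprof_total_kb = sum(flat_kb.values())

# Approximate GPU memory shown by TensorBoard for the same run.
# Assumption: 1 GB = 1024 * 1024 kB.
tensorboard_gb = 2.0
tensorboard_kb = tensorboard_gb * 1024 * 1024

# How many times larger the TensorBoard figure is than the pprof total.
ratio = tensorboard_kb / pprof_total_kb

print(f"pprof total:       {pprof_total_kb:.2f} kB")
print(f"TensorBoard total: {tensorboard_kb:.0f} kB (approximate)")
print(f"TensorBoard / pprof ratio: {ratio:.0f}x")
```

The gap is more than three orders of magnitude, so it cannot be a rounding or unit-label difference between the two tools.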

Why does pprof not see any internal memory usage and report only a few hundred kB of memory used? How can I obtain a detailed memory profile for Neural Tangents, like the one pprof produces?

Thank you!
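For reference, a minimal sketch of the snapshot pattern from the linked JAX documentation. It assumes, as that page describes, that a device memory profile is a snapshot of the buffers and executables that are live at the moment it is taken. Under that assumption, memory that is only used temporarily during training may not appear in a profile saved after training has finished. The function name, array size and file name are hypothetical, and the matrix product is only a stand-in for the kernel computation in this issue.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler


def snapshot_with_live_array(n=512):
    """Save a device memory profile while a result array is still referenced."""
    # Hypothetical stand-in for the real kernel computation.
    x = jnp.ones((n, n), dtype=jnp.float32)
    y = jnp.dot(x, x)

    # Wait for the asynchronous computation, so the result buffer exists.
    y.block_until_ready()

    # Take the snapshot while `x` and `y` are still alive in this scope.
    # No backend argument is passed, so the default backend is used.
    path = os.path.join(tempfile.mkdtemp(), "memory_live.prof")
    jax.profiler.save_device_memory_profile(path)

    # Size of the result buffer in bytes, for comparison with the profile.
    result_bytes = y.size * y.dtype.itemsize
    return path, result_bytes


profile_path, result_bytes = snapshot_with_live_array()
print("profile written to:", profile_path)
print("result buffer size in bytes:", result_bytes)
```

The resulting file can be opened with pprof in the same way as the profile in this issue. Calling the save function at several points inside the training loop, rather than once at the end, is one way to check whether the snapshot timing explains the small numbers.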
