
High memory usage for CPU inference on variable input shapes (10x compared to pytorch 1.1) #27971

Closed
lopuhin opened this issue Oct 15, 2019 · 25 comments
Labels
high priority module: cpu CPU specific problem (e.g., perf, algorithm) module: memory usage PyTorch is using more memory than it should, or it is leaking memory module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@lopuhin
Contributor

lopuhin commented Oct 15, 2019

🐛 Bug

In pytorch 1.3, when doing inference with resnet34 on CPU with variable input shapes, much more memory is used compared to pytorch 1.1 (both CPU-only builds on one core): 6 GB for pytorch 1.3 vs. ~0.5 GB for pytorch 1.1

To Reproduce

Steps to reproduce the behavior:

Run the following script https://gist.github.com/lopuhin/0d100ef7df01fdfc91d9685f6e01ff64 - it performs inference with resnet34 on images with fixed width and variable height, and reports speed and memory growth over the course of the benchmark.
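The gist itself is not reproduced here. Memory growth of the kind reported below can be tracked with the standard-library resource module. The sketch below is hypothetical and is not the gist's code. It assumes Linux, where ru_maxrss is reported in kilobytes (macOS reports bytes). run_once is a made-up callback that should perform one inference.

```python
import resource


def peak_rss_kb() -> int:
    """Peak resident set size of this process (kilobytes on Linux, bytes on macOS)."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss


def format_growth(n: int, growth_kb: int) -> str:
    """Format a progress line in the same style as the benchmark output below."""
    return f"n={n} memory growth (kb): {growth_kb:,}"


def track_growth(run_once, n_total: int, report_every: int = 100):
    """Call run_once(i) n_total times, recording peak-RSS growth every report_every calls.

    run_once is a hypothetical callback that performs one inference.
    Returns the initial peak RSS and the list of formatted progress lines.
    """
    initial = peak_rss_kb()
    lines = []
    for i in range(1, n_total + 1):
        run_once(i)
        if i % report_every == 0:
            lines.append(format_growth(i, peak_rss_kb() - initial))
    return initial, lines
```

Because ru_maxrss is a peak value, the reported growth can only stay flat or increase. A plateau therefore means memory stopped growing, as in the 1.1 run.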

Running under pytorch 1.1:

$ python3 pytorch_high_mem.py --n 500
torch 1.1.0
heights: mean=1004, p50=278 p95=5100 max=7680
n=100 memory growth (kb): 477,952
n=200 memory growth (kb): 503,948
n=300 memory growth (kb): 503,948
n=400 memory growth (kb): 518,652
time: mean=0.924 s, p50=0.271 s, p95=4.626 s
memory (kb): 174,552 initial, 518,652 growth

Running under pytorch 1.3:

$ python3 pytorch_high_mem.py --n 500
torch 1.3.0+cpu
heights: mean=1004, p50=278 p95=5100 max=7680
n=100 memory growth (kb): 2,624,296
n=200 memory growth (kb): 4,480,012
n=300 memory growth (kb): 5,579,568
n=400 memory growth (kb): 5,600,888
time: mean=0.676 s, p50=0.196 s, p95=3.825 s
memory (kb): 187,840 initial, 6,200,664 growth

Expected behavior

Expected behavior is low memory usage, as in pytorch 1.1. Alternatively, a way to control caching would help, e.g. something that disables caching, or something like torch.cuda.empty_cache() but for CPU. As I understand it, the high memory usage happens because allocations are cached. That makes sense for fixed shapes but does not work well for variable shapes. Binning shapes is possible as a workaround, but it has a noticeable performance penalty, and memory usage is still higher.
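Binning shapes, mentioned above as a workaround, can be as simple as rounding each height up to a fixed multiple and zero-padding the bottom of the image. The result is fewer distinct input shapes. This is a minimal sketch under stated assumptions: the bucket size of 320 and the helper names are hypothetical, and only the height varies, as in the benchmark.

```python
def bucket_height(height: int, multiple: int = 320) -> int:
    """Round height up to the next multiple of `multiple` (integer arithmetic only)."""
    return ((height + multiple - 1) // multiple) * multiple


def pad_for_bucket(height: int, multiple: int = 320):
    """Return (bucketed_height, pad) for an NCHW image of the given height.

    `pad` is laid out for torch.nn.functional.pad, which lists padding from the
    last dimension backwards: (left, right, top, bottom). Only the bottom of the
    height dimension is padded; the fixed width is left alone.
    """
    target = bucket_height(height, multiple)
    return target, (0, 0, 0, target - height)


# Example: five distinct heights collapse into four distinct bucketed shapes.
heights = [278, 300, 640, 650, 5100]
buckets = sorted({bucket_height(h) for h in heights})  # [320, 640, 960, 5120]
```

With a tensor x of shape (1, 3, H, W), the padded input would be torch.nn.functional.pad(x, pad). The extra rows cost compute, which is consistent with the performance penalty noted above. Zero padding may also slightly change outputs near the border.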

Environment

Environment under pytorch 1.1 (via collect_env.py script):

Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Debian GNU/Linux 9 (stretch)
GCC version: Could not collect
CMake version: Could not collect

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.16.3
[pip3] torch==1.1.0
[pip3] torchvision==0.3.0
[conda] Could not collect

pytorch installed with

pip install -U --no-cache-dir cython wheel pip http://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl http://download.pytorch.org/whl/cpu/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl

Environment under pytorch 1.3:

PyTorch version: 1.3.0+cpu
Is debug build: No
CUDA used to build PyTorch: None

OS: Debian GNU/Linux 9 (stretch)
GCC version: Could not collect
CMake version: Could not collect

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.16.3
[pip3] torch==1.3.0+cpu
[pip3] torchvision==0.4.1+cpu
[conda] Could not collect

pytorch installed with

pip install torch==1.3.0+cpu torchvision==0.4.1+cpu -f https://download.pytorch.org/whl/torch_stable.html

Additional context

This may be similar to oneapi-src/oneDNN#489 but here mkldnn is not used explicitly.

cc @VitalyFedyunin @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @ezyang @gchanan @zou3519

@zou3519 zou3519 added high priority module: memory usage PyTorch is using more memory than it should, or it is leaking memory module: cpu CPU specific problem (e.g., perf, algorithm) triage review labels Oct 15, 2019
@zou3519
Contributor

zou3519 commented Oct 15, 2019

10X memory usage compared to pytorch 1.1 is bad so I am marking this as high pri.

@ezyang ezyang added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Oct 16, 2019
@ezyang
Contributor

ezyang commented Oct 16, 2019

Paging the MKL-DNN folks as this is almost certainly MKLDNN related

@lopuhin
Contributor Author

lopuhin commented Oct 16, 2019

Thanks for the hint. Are there any environment variables or options that might influence the result? Edit: maybe #25186 could be useful here.

I just tried the benchmark on an AMD Ryzen CPU and got the same results.

@lopuhin
Contributor Author

lopuhin commented Oct 16, 2019

FWIW, ONNX Runtime looks almost unaffected by this issue, so as a workaround it's possible to use it for inference. Here are benchmark results on the same machine.

Model exported with (no other optimizations applied):

torch.onnx.export(model, torch.randn(1, 3, 920, 320), 'resnet34.onnx', verbose=True, input_names=['input'], output_names=['output'], dynamic_axes={'input': {2: 'height'}})

And results are:

$ python pytorch_high_mem.py --n 500 --onnx
torch 1.3.0+cpu
heights: mean=1004, p50=278 p95=5100 max=7680
n=100 memory growth (kb): 478,060
n=200 memory growth (kb): 670,640
n=300 memory growth (kb): 753,936
n=400 memory growth (kb): 782,776
time: mean=0.481 s, p50=0.134 s, p95=2.441 s
memory (kb): 286,724 initial, 821,696 growth

Even better, memory stops growing after about 800 iterations with sess_options.enable_cpu_mem_arena = False.
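A minimal sketch of the arena-free setup described above. It assumes the onnxruntime Python package and the resnet34.onnx file produced by the export call. The helper names are hypothetical. onnxruntime is imported inside the functions so that the module loads even where the package is absent.

```python
def make_cpu_session(model_path: str, use_arena: bool = False):
    """Create an ONNX Runtime CPU session, by default with the CPU memory arena disabled."""
    import onnxruntime as ort  # requires the onnxruntime package

    opts = ort.SessionOptions()
    # The arena keeps allocations around between runs; disabling it is the
    # setting reported above to stop memory growth with variable heights.
    opts.enable_cpu_mem_arena = use_arena
    return ort.InferenceSession(
        model_path, sess_options=opts, providers=["CPUExecutionProvider"]
    )


def run_image(session, image):
    """Run one float32 NCHW numpy array through the exported model."""
    # "input" matches input_names=['input'] in the export call above;
    # the model has a single output, so run() returns a one-element list.
    (logits,) = session.run(None, {"input": image})
    return logits
```

Usage would look like session = make_cpu_session("resnet34.onnx") followed by run_image(session, image) for each image. The image must be float32, because the model was exported from a float32 input.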

@ezyang
Contributor

ezyang commented Oct 16, 2019

I can reproduce this on master.

python test.py
torch 1.4.0a0+4f1f084
heights: mean=1031, p50=284 p95=5561 max=7680
n=100 memory growth (kb): 2,624,828
n=200 memory growth (kb): 4,481,904
n=300 memory growth (kb): 5,616,512
n=400 memory growth (kb): 5,616,512
n=500 memory growth (kb): 6,322,232
n=600 memory growth (kb): 6,389,196
n=700 memory growth (kb): 6,677,340

@ezyang
Contributor

ezyang commented Oct 16, 2019

You can get more information about MKLDNN by setting env var MKLDNN_VERBOSE=1. I get logs that look like:

mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw8c out:f32_nchw,num:1,1x64x50x80,0.181885
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nchw out:f32_nChw8c,num:1,1x64x50x80,0.178955
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_OIhw8i8o,num:1,64x64x3x3,0.0361328
mkldnn_verbose,exec,convolution,jit:avx2,forward_training,fsrc:nChw8c fwei:OIhw8i8o fbia:undef fdst:nChw8c,alg:convolution_direct,mb1_ic64oc64_ih50oh50kh3sh1dh0ph1_iw80ow80kw3sw1dw0pw1,3.58203
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw8c out:f32_nchw,num:1,1x64x50x80,0.185059
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nchw out:f32_nChw8c,num:1,1x64x50x80,0.221924
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_OIhw8i8o,num:1,128x64x3x3,0.0869141
mkldnn_verbose,exec,convolution,jit:avx2,forward_training,fsrc:nChw8c fwei:OIhw8i8o fbia:undef fdst:nChw8c,alg:convolution_direct,mb1_ic64oc128_ih50oh25kh3sh2dh0ph1_iw80ow40kw3sw2dw0pw1,1.9541
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw8c out:f32_nchw,num:1,1x128x25x40,0.101074
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nchw out:f32_nChw8c,num:1,1x128x25x40,0.0710449
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_OIhw8i8o,num:1,128x128x3x3,0.128174
mkldnn_verbose,exec,convolution,jit:avx2,forward_training,fsrc:nChw8c fwei:OIhw8i8o fbia:undef fdst:nChw8c,alg:convolution_direct,mb1_ic128oc128_ih25oh25kh3sh1dh0ph1_iw40ow40kw3sw1dw0pw1,3.62695
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw8c out:f32_nchw,num:1,1x128x25x40,0.0930176
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nchw out:f32_nChw8c,num:1,1x64x50x80,0.198975
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_OIhw8i8o,num:1,128x64x1x1,0.0109863
mkldnn_verbose,exec,convolution,jit_1x1:avx2,forward_training,fsrc:nChw8c fwei:OIhw8i8o fbia:undef fdst:nChw8c,alg:convolution_direct,mb1_ic64oc128_ih50oh25kh1sh2dh0ph0_iw80ow40kw1sw2dw0pw0,0.24585

(I don't really know what it means though XD)
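The verbose log does not report memory. However, the convolution descriptor field appears to encode the problem shape: ih50/iw80 read like input height and width. If so, every new image height would show up as a new descriptor. The sketch below counts distinct convolution descriptors in a captured log. It assumes every exec line follows the comma-separated layout shown above, and the function names are hypothetical.

```python
from collections import Counter


def parse_verbose_line(line: str):
    """Split one mkldnn_verbose 'exec' line into (kind, impl, descriptor, time).

    Assumes the layout seen in the log above: primitive kind is the third field,
    implementation the fourth, the problem descriptor the second-to-last field
    and the execution time the last. Returns None for any other line.
    """
    fields = line.strip().split(",")
    if len(fields) < 6 or fields[0] != "mkldnn_verbose" or fields[1] != "exec":
        return None
    return fields[2], fields[3], fields[-2], float(fields[-1])


def count_conv_shapes(lines):
    """Count how often each distinct convolution descriptor appears."""
    counts = Counter()
    for line in lines:
        parsed = parse_verbose_line(line)
        if parsed and parsed[0] == "convolution":
            counts[parsed[2]] += 1
    return counts
```

Run the benchmark with MKLDNN_VERBOSE=1 and save its output, then feed the lines to count_conv_shapes. A count that keeps climbing as new heights arrive hints at how many distinct convolution primitives are created. Non-verbose lines are skipped.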

@vpirogov

@ezyang, unfortunately the verbose log does not tell us anything about memory consumption.

@vpirogov

The observed behavior is likely the result of a caching mechanism implemented outside of the library.

@XiaobingSuper
Collaborator

@lopuhin, this is the same problem you mentioned in oneapi-src/oneDNN#489. Ideep caches MKLDNN primitives to avoid the cost of re-creating them, and it supports an environment variable named LRU_CACHE_CAPACITY to control the cache capacity. The default value is 1024. You can reduce memory use by setting a smaller value with export LRU_CACHE_CAPACITY=<number>. Thanks!
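The following toy model shows why such a cache grows with variable shapes. Every distinct shape creates an entry, and only the capacity limits how many entries stay alive. This is a hypothetical illustration, not ideep's implementation. The class and function names are made up, and the placeholder string stands in for an expensive primitive object.

```python
import os
from collections import OrderedDict


class PrimitiveCache:
    """Toy LRU cache keyed by input shape, standing in for a primitive cache."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.entries = OrderedDict()
        self.creations = 0  # how many times a "primitive" had to be built

    def get(self, shape):
        if shape in self.entries:
            self.entries.move_to_end(shape)  # mark as most recently used
            return self.entries[shape]
        self.creations += 1
        primitive = f"primitive{shape}"  # placeholder for an expensive object
        self.entries[shape] = primitive
        if len(self.entries) > self.capacity:
            self.entries.popitem(last=False)  # evict the least recently used entry
        return primitive


def capacity_from_env(default: int = 1024) -> int:
    """Read the capacity from LRU_CACHE_CAPACITY, falling back to 1024 as described."""
    return int(os.environ.get("LRU_CACHE_CAPACITY", default))
```

With capacity 1024 and a few hundred distinct heights, almost every shape ever seen stays cached, and so does its memory. With a small capacity, old shapes are evicted. The cost is rebuilding a primitive whenever an evicted shape returns.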

@lopuhin
Contributor Author

lopuhin commented Oct 17, 2019

Wow, this works perfectly and solves the issue, thank you @XiaobingSuper!

Benchmark results:

$ LRU_CACHE_CAPACITY=1 python pytorch_high_mem.py --n 500
torch 1.3.0+cpu
heights: mean=1004, p50=278 p95=5100 max=7680
n=100 memory growth (kb): 361,128
n=200 memory growth (kb): 397,024
n=300 memory growth (kb): 397,024
n=400 memory growth (kb): 397,024
time: mean=0.519 s, p50=0.142 s, p95=2.660 s
memory (kb): 191,356 initial, 397,024 growth

$ LRU_CACHE_CAPACITY=16 python pytorch_high_mem.py --n 500
torch 1.3.0+cpu
heights: mean=1004, p50=278 p95=5100 max=7680
n=100 memory growth (kb): 521,332
n=200 memory growth (kb): 604,804
n=300 memory growth (kb): 621,560
n=400 memory growth (kb): 675,048
time: mean=0.510 s, p50=0.143 s, p95=2.506 s
memory (kb): 191,496 initial, 675,048 growth

@ezyang
Contributor

ezyang commented Oct 17, 2019

Downgrading priority as a workaround is present. I'll keep the bug open in case anyone else notices high memory usage; we may want to reduce the default cache size (but hard to say without more reports.)

@jerryzh168 jerryzh168 added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module and removed triage review labels Oct 21, 2019
@ezyang
Contributor

ezyang commented Nov 18, 2019

Amplifying priority: #29809 is a duplicate report of this problem.

@ezyang
Contributor

ezyang commented Dec 3, 2019

Another duplicate report: #29893

@ssnl
Collaborator

ssnl commented Jan 25, 2020

duplicates #32037 #32596

@ssnl
Collaborator

ssnl commented Jan 25, 2020

time to reduce default cache size?

@ngimel ngimel added high priority triage review and removed triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jan 25, 2020
@ezyang
Contributor

ezyang commented Feb 3, 2020

Let's reduce the default cache size.

@ezyang
Contributor

ezyang commented Feb 3, 2020

@gchanan says the recent MKL-DNN release may have helped here.

@smessmer smessmer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 3, 2020
@WillLiGitHub

With pytorch 1.3.0, setting LRU_CACHE_CAPACITY=1 fixes the memory leak.

@Baranowski
Contributor

I cannot reproduce this with current master (1.6.0a0+96885f7)

print(*torch.__config__.show().split("\n"), sep="\n")
PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2019.0.4 Product Build 20190411 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.2.0 (Git Hash 70f8b879ea7a0c38caedb3320b7c85e8497ff50d)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_75,code=sm_75
  - CuDNN 7.6.5  (built against CUDA 10.0)
  - Magma 2.5.1
  - Build settings: BLAS=MKL, BUILD_TYPE=RelWithDebInfo, CXX_FLAGS=-D__STDC_FORMAT_MACROS -I/usr/local/cuda-10.1.243/include -L/usr/local/cuda-10.1.243/lib64 -L/home/wbaranowski/miniconda3/envs/pytorch-cuda-dev/lib -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=0, USE_NNPACK=0, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,

@rgommers
Collaborator

I can reproduce this with 1.3.0, 1.4.0 and 1.5.0 installed with conda install pytorch -c pytorch, and also when building v1.5.0 from source in a conda environment. In those cases setting LRU_CACHE_CAPACITY=1 indeed fixes things.

I cannot reproduce this with current master (1.6.0a0+fe44741) built from source in that same conda env: max memory usage is ~700 MB (vs. 6-8 GB in the other cases above).

The 1.5.0 binary and v1.5.0 source build both use:

$ python -c "import torch; print(*torch.__config__.show().split('\n'), sep='\n')"
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)

Full output for v1.5.0 build:

PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.5
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,

On current master, MKL-DNN has been upgraded to v1.2.0:

$ python -c "import torch; print(*torch.__config__.show().split('\n'), sep='\n')"
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.2.0 (Git Hash 70f8b879ea7a0c38caedb3320b7c85e8497ff50d)

Full output for master (1.6.0a0+22e3063) build:

PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.2.0 (Git Hash 70f8b879ea7a0c38caedb3320b7c85e8497ff50d)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DNDEBUG -DUSE_PYTORCH_QNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=0, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=0, USE_NNPACK=0, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF

The MKL-DNN upgrade from v0.21.1 to v1.2.0 happened in gh-32422.

Memory usage now is still a little higher than with PyTorch 1.1 (760 MB now for n=400, vs. 518 MB on 1.1), but that's probably expected, and it doesn't keep growing:

$ python high_mem.py 
torch 1.6.0a0+22e3063
heights: mean=1031, p50=284 p95=5561 max=7680
n=100 memory growth (kb): 574,124
n=200 memory growth (kb): 760,356
n=300 memory growth (kb): 760,432
n=400 memory growth (kb): 760,432
n=500 memory growth (kb): 809,008
n=600 memory growth (kb): 809,008
n=700 memory growth (kb): 809,008
n=800 memory growth (kb): 821,312
n=900 memory growth (kb): 821,312
time: mean=0.441 s, p50=0.120 s, p95=2.430 s
memory (kb): 204,652 initial, 821,312 growth

That upgrade also removed third_party/ideep/include/ideep/lru_cache.hpp completely, and LRU_CACHE_CAPACITY is no longer referenced anywhere in the code base. So it looks like there's nothing left to do here; closing.

@AloneGu

AloneGu commented Jun 16, 2020

same here for pytorch 1.3.0

fixed by

import os
# Set this before importing torch; the value may only be read once.
os.environ["LRU_CACHE_CAPACITY"] = "3"

@pinzhenx
Collaborator

@AloneGu The fix is on the master branch.
For PyTorch <= 1.5, you still have to set LRU_CACHE_CAPACITY manually.

@AloneGu

AloneGu commented Jun 17, 2020

@AloneGu The fix was on the master branch
For Pytorch <= 1.5, you still have to set LRU_CACHE_CAPACITY manually

got it, thx

@jonsneyers

For the record (since I recently found this issue searching for a solution to this particular problem), the relevant environment variable is now called ONEDNN_PRIMITIVE_CACHE_CAPACITY. See also: https://www.intel.com/content/www/us/en/develop/documentation/onednn-developer-guide-and-reference/top/advanced-topics/primitive-cache.html
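A small helper that sets both cache-capacity variable names discussed in this thread. It covers the old ideep variable (PyTorch <= 1.5) and the oneDNN one. This is a sketch: the helper name is hypothetical. It sets the variables before torch is imported, since the library may read them only once.

```python
import os

# Names discussed in this thread: LRU_CACHE_CAPACITY was read by ideep in
# PyTorch <= 1.5; ONEDNN_PRIMITIVE_CACHE_CAPACITY is the oneDNN variable.
CACHE_ENV_VARS = ("LRU_CACHE_CAPACITY", "ONEDNN_PRIMITIVE_CACHE_CAPACITY")


def limit_primitive_cache(capacity: int = 1) -> None:
    """Set every cache-capacity variable; call this before `import torch`."""
    for name in CACHE_ENV_VARS:
        os.environ[name] = str(capacity)
```

Call limit_primitive_cache(1) at the very top of the inference script, then import torch. Setting the variable in the shell, as in ONEDNN_PRIMITIVE_CACHE_CAPACITY=1 python script.py, is equivalent.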

@rbracco

rbracco commented Oct 18, 2023

So I had to go really deep on a CPU-inference memory issue for a model that has variable sized input (audio). Here's what I found, hope it helps:

What worked

  • Setting ONEDNN_PRIMITIVE_CACHE_CAPACITY to 1, either via os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1" or with ONEDNN_PRIMITIVE_CACHE_CAPACITY="1" python <inference-file>.py. This showed a dramatic improvement in memory usage with no sacrifice in speed (see table below).
  • Wrapping inference in a with torch.jit.optimized_execution(False): block showed a further large improvement in memory, also with no sacrifice in speed. This is pretty surprising because (A) there's zero documentation for this feature, and (B) it had the same impact on my .ckpt models as on my .pt models, which I wouldn't expect, since I think only the latter are scripted. Note: there appears to be a slight CPU/memory trade-off here. In production with limited resources, keeping it True allows 10-15% higher peak throughput at the expense of memory, but if you're here, memory is probably your bottleneck.
  • Wrapping inference in a with torch.backends.mkldnn.flags(enabled=False): block had the same impact on memory as setting ONEDNN_PRIMITIVE_CACHE_CAPACITY="1", but caused a 15% slowdown in CPU inference. Setting the ONEDNN cache size seems to be the more targeted approach.
  • Setting os.environ["LRU_CACHE_CAPACITY"] = "1" did nothing, confirming @jonsneyers' lifesaving post pointing to the new variable name.
  • If you think your high memory use might be due to variable-sized inputs, try passing torch.randn() with a plausible shape for your input (e.g. for audio, torch.randn(1, 96342)). If you run it 200 times with a different random tensor of the same shape and the memory issue disappears, it's probably the variable size. You can repeat with torch.randn(1, random.randint(50000, 100000)); if the memory issue returns, it's definitely due to variable size. Note: even after fixing, memory will jump around because bigger tensors use more memory. Once fixed, though, you should not see a significant difference in peak memory between a run with a single random tensor of shape 1x100,000 and a run with random tensors of shape 1xrandom.randint(50000, 100000).
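The two context managers and the fixed-vs-variable shape check from the list above can be combined into a small harness. This is a sketch under stated assumptions. run_model and make_input are hypothetical callbacks you supply, for example a model call and lambda n: torch.randn(1, n) for audio. Adding torch.no_grad() is an extra assumption that is not part of the comment. torch is imported lazily so the cache environment variables can be set first.

```python
import contextlib
import random
import resource


def peak_rss_kb() -> int:
    """Peak resident set size (kilobytes on Linux, bytes on macOS)."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss


@contextlib.contextmanager
def low_memory_inference(disable_jit_optimization=True, disable_mkldnn=False):
    """Enter the context managers discussed above; torch is imported lazily."""
    import torch  # deferred so cache env vars can be set before the first import

    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.no_grad())  # assumption: inference without autograd
        if disable_jit_optimization:
            stack.enter_context(torch.jit.optimized_execution(False))
        if disable_mkldnn:  # reported above to cost ~15% speed
            stack.enter_context(torch.backends.mkldnn.flags(enabled=False))
        yield


def shape_sensitivity_check(run_model, make_input, fixed_len, len_range, n=200, seed=0):
    """Return (fixed-shape growth, extra growth from variable shapes) in peak RSS.

    Runs n inferences at fixed_len, then n at random lengths from len_range.
    A large second number suggests the memory issue is driven by varying shapes.
    """
    rng = random.Random(seed)
    start = peak_rss_kb()
    for _ in range(n):
        run_model(make_input(fixed_len))
    after_fixed = peak_rss_kb()
    for _ in range(n):
        run_model(make_input(rng.randint(*len_range)))
    after_variable = peak_rss_kb()
    return after_fixed - start, after_variable - after_fixed
```

Because ru_maxrss is a peak value, the second number only counts memory beyond what the fixed-shape phase already reached. The whole check can be run inside with low_memory_inference(): to compare settings.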

What didn't work, but maybe would work for you

  • I was using a jitted model from a .pt file; I also tested memory usage with the .ckpt (non-scripted) version to rule out a TorchScript issue.
  • Setting torch.set_num_threads(1) and torch.set_num_interop_threads(1) slowed it way down but didn't impact memory.
  • Attempting to turn off the ONEDNN cache completely as described here: ONEDNN_ENABLE_PRIMITIVE_CACHE="OFF" did nothing. (I later realized this is because I set it as an environment variable, but the docs state it has to be set during the build process.)
  • Experimenting with padding variable input shapes to multiples of 320 to decrease the total variability.
  • Deploying to ONNX. I didn't try it because it's a huge pain, but maybe it would've worked.

Memory usage after inference on 500 items of varying sizes:

                                        default cache capacity   ONEDNN_PRIMITIVE_CACHE_CAPACITY="1"
default torch.jit.optimized_execution   4840 MB                  3997 MB
torch.jit.optimized_execution(False)    4446 MB                  3135 MB
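To put the table in relative terms, here is a small calculation derived only from its four numbers. It gives each configuration's reduction compared with the all-defaults run. The dictionary keys are a hypothetical encoding of the table's rows and columns.

```python
# Peak memory (MB) from the table, keyed by (jit_optimized, cache_capacity_1).
MEASURED_MB = {
    (True, False): 4840,
    (True, True): 3997,
    (False, False): 4446,
    (False, True): 3135,
}


def reduction_vs_default(mb: int, baseline: int = MEASURED_MB[(True, False)]) -> float:
    """Percentage reduction relative to the all-defaults configuration."""
    return 100.0 * (baseline - mb) / baseline


for config, mb in MEASURED_MB.items():
    print(config, f"{reduction_vs_default(mb):.1f}%")
```

By this arithmetic, the cache setting alone saves about 17.4%, disabling JIT optimization alone about 8.1%, and both together about 35.2%.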

This issue was closed.