
[MPS] Memory leak in nn.Linear #132332

Open

hvaara opened this issue Jul 31, 2024 · 15 comments
Labels
high priority
module: memory usage (PyTorch is using more memory than it should, or it is leaking memory)
module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@hvaara
Contributor

hvaara commented Jul 31, 2024

🐛 Describe the bug

Under certain circumstances, nn.Linear leaks memory on MPS. The exact failure mode and the condition that leads to the leak are unclear at the moment. I'll give an update when I have more information.

Possibly related to #125217.

Steps to reproduce

import torch
import torch.nn as nn

N, C, H, W = 64, 32, 256, 256
iters = 5

model = nn.Linear(H, W, device='mps')
input = torch.rand(N, C, H, W, device='mps')

# Warm up
model(input)
torch.mps.empty_cache()

# Begin test
driver_before = torch.mps.driver_allocated_memory()
model(input)
for _ in range(iters):
    model(input)
torch.mps.empty_cache()
driver_after = torch.mps.driver_allocated_memory()

predicted_leak = iters * 4 * N * C * H * W  # one float32 output tensor (4 bytes/element) per iteration
print(f"{driver_before = }")
print(f"{driver_after = }")
print(f"Detected {driver_after - driver_before} bytes (predicted: {predicted_leak}) leak of GPU memory")

Example output

driver_before = 1102823424
driver_after = 3787177984
Detected 2684354560 bytes (predicted: 2684354560) leak of GPU memory
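
For context, here is a quick check (added for illustration, not part of the original report) of the arithmetic behind the predicted figure: it is simply one float32 output tensor per iteration that is never freed.

# Arithmetic behind the predicted leak: each forward pass allocates one float32
# output tensor of shape (N, C, H, W), i.e. 4 bytes per element.
N, C, H, W, iters = 64, 32, 256, 256, 5
per_iter = 4 * N * C * H * W   # 536_870_912 bytes (512 MiB) per iteration
print(iters * per_iter)        # 2_684_354_560 bytes, matching the detected value above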

Versions

PyTorch version: 2.5.0a0+git5406e46
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A

Python version: 3.8.19 | packaged by conda-forge | (default, Mar 20 2024, 12:49:57) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Max

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] optree==0.11.0
[pip3] torch==2.5.0a0+git5406e46
[pip3] torchvision==0.20.0a0+61bd547
[conda] numpy 1.24.4 pypi_0 pypi
[conda] optree 0.11.0 pypi_0 pypi
[conda] torch 2.5.0a0+git5406e46 dev_0
[conda] torchvision 0.20.0a0+61bd547 dev_0

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

@malfet added the module: mps and module: memory usage labels on Jul 31, 2024
@malfet
Contributor

malfet commented Jul 31, 2024

I believe there is already an issue about this.
@jhavukainen, you were looking into it, weren't you?

@hvaara
Contributor Author

hvaara commented Jul 31, 2024

Possibly related to #125217, which saw a similar issue in nn.MaxPool2d. There was also a recent fix for an issue someone reported when running Llama (fine-tuning?), which could be related. I'll see if I can dig up that issue ID. @malfet, were you referring to either of these, or to another issue?

@hvaara
Contributor Author

hvaara commented Jul 31, 2024

I have a regression test prepared. I'll get a PR out and link to this bug.

@hvaara
Contributor Author

hvaara commented Aug 1, 2024

The PR for the regression test is #132355.

We need to investigate whether the memory leak is in PyTorch or in MPS. If it's the former, the fix should be implemented in PyTorch rather than in MPS, and the regression test needs to be updated.

Awaiting confirmation that this is a live bug even with the changes in macOS from #125217 (ref #125217 (comment)).

We had anecdotal evidence that the memory leak in nn.Linear goes away when profiling is enabled. The assumption was that it has to do with the SyncType changing when profiling is enabled. I'll try to reproduce this finding. I don't know if this also happened for nn.MaxPool2d or if we ever tested it. If I can repro with nn.Linear, I can also test nn.MaxPool2d.
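
A minimal sketch (added here for illustration, not something already run in this thread) of how the profiling anecdote could be checked, assuming torch.mps.profiler.start()/stop() (OS Signpost tracing) is the profiling hook in question; whether tracing actually changes the sync behavior is exactly what needs confirming.

import torch
import torch.nn as nn

N, C, H, W = 64, 32, 256, 256
iters = 5

model = nn.Linear(H, W, device='mps')
input = torch.rand(N, C, H, W, device='mps')

model(input)  # warm up
torch.mps.empty_cache()

# Enable MPS profiling (OS Signpost tracing); the hypothesis is that this changes
# how/when command buffers are committed.
torch.mps.profiler.start(mode='interval', wait_until_completed=False)
driver_before = torch.mps.driver_allocated_memory()
for _ in range(iters):
    model(input)
torch.mps.empty_cache()
driver_after = torch.mps.driver_allocated_memory()
torch.mps.profiler.stop()

print(f"Delta with profiling enabled: {driver_after - driver_before} bytes")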

@kulinseth
Collaborator

> The PR for the regression test is #132355.
>
> We need to investigate whether the memory leak is in PyTorch or in MPS. If it's the former, the fix should be implemented in PyTorch rather than in MPS, and the regression test needs to be updated.
>
> Awaiting confirmation that this is a live bug even with the changes in macOS from #125217 (ref #125217 (comment)).
>
> We had anecdotal evidence that the memory leak in nn.Linear goes away when profiling is enabled. The assumption was that it has to do with the SyncType changing when profiling is enabled. I'll try to reproduce this finding. I don't know if this also happened for nn.MaxPool2d or if we ever tested it. If I can repro with nn.Linear, I can also test nn.MaxPool2d.

We have repro'ed the MaxPool2d issue. That was a bug in the refcount of one of the MaxPool2dGrad intermediate tensors, which is fixed in the MPSGraph framework. The fix will be in upcoming releases.

@jhavukainen
Collaborator

Hi @hvaara! Thanks for the repro case. I can verify that, as you say, this is a live bug even with the changes that target the MaxPool2d memory leak, so this should be something different. Based on an initial look, it does show some similarities to the earlier problem in the sense that it's not recognized as a leak as such; instead, something in MPSGraph seems to be holding onto the memory and thinking it's doing the right thing.

@soulitzer added the triaged label on Aug 5, 2024
@hvaara
Contributor Author

hvaara commented Aug 5, 2024

@jhavukainen Thanks a lot for testing the repro case for nn.Linear with the fix for nn.MaxPool2d!

> We had anecdotal evidence that the memory leak in nn.Linear goes away when profiling is enabled. The assumption was that it has to do with the SyncType changing when profiling is enabled. I'll try to reproduce this finding. I don't know if this also happened for nn.MaxPool2d or if we ever tested it. If I can repro with nn.Linear, I can also test nn.MaxPool2d.

FYI, switching from COMMIT_ADAPTIVE to COMMIT_AND_WAIT in

mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);

makes the leaking issue in nn.Linear go away.
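
A related check from the Python side (a sketch added for illustration, not from the thread): torch.mps.synchronize() commits the pending command buffer and waits for it, which roughly corresponds to what COMMIT_AND_WAIT does per dispatch, so measuring after an explicit synchronize should show whether the extra memory is only held by uncommitted work.

import torch
import torch.nn as nn

N, C, H, W = 64, 32, 256, 256
model = nn.Linear(H, W, device='mps')
input = torch.rand(N, C, H, W, device='mps')

model(input)  # warm up
torch.mps.empty_cache()

driver_before = torch.mps.driver_allocated_memory()
for _ in range(5):
    model(input)
torch.mps.synchronize()  # wait for all queued MPS work, i.e. force the pending commit
torch.mps.empty_cache()
driver_after = torch.mps.driver_allocated_memory()
print(f"Delta after explicit synchronize: {driver_after - driver_before} bytes")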

@jhavukainen
Collaborator

@hvaara Oh, that's a great find, thanks! I'll convene with the MPSGraph experts to see if it rings any bells on what the root cause might be, but this definitely narrows it down.

@kulinseth
Collaborator

> @hvaara Oh, that's a great find, thanks! I'll convene with the MPSGraph experts to see if it rings any bells on what the root cause might be, but this definitely narrows it down.

@jhavukainen This is a difference between COMMIT_ADAPTIVE and COMMIT_AND_WAIT; what does this have to do with the MPSGraph piece?

@jhavukainen
Collaborator

> @hvaara Oh, that's a great find, thanks! I'll convene with the MPSGraph experts to see if it rings any bells on what the root cause might be, but this definitely narrows it down.
>
> @jhavukainen This is a difference between COMMIT_ADAPTIVE and COMMIT_AND_WAIT; what does this have to do with the MPSGraph piece?

@kulinseth I added that comment before we discussed the implications of COMMIT_ADAPTIVE and COMMIT_AND_WAIT having different behavior here. So yes, you are correct: as we concluded in our chat, it does not seem to be related to MPSGraph, but rather to how we keep encoding to the command buffer on the PyTorch side until we hit the low-watermark value and flush.

@jhavukainen
Collaborator

OK @hvaara, thanks for your patience! Here are the results from my deep dive into the traces of Metal resources generated while the code snippet runs, and my current understanding of what's going on:

  • The memory grows by the size of the linear layer's output tensor on each iteration, i.e. by 4 * N * C * H * W bytes. This makes sense, since this is the output tensor we allocate on L50 of the linear op. [1]
  • When executeMPSGraph is called with COMMIT_ADAPTIVE as the SyncType, the operation is encoded to the active command buffer in MPSStream, but unless the memory pressure is too high the command buffer is not committed; instead we keep encoding operations to the same command buffer. [2]
  • Since each nn.Linear iteration needs to maintain its own output tensor until the computation is completed and the Python object goes out of scope to be garbage collected, we see the memory increase while the command buffer keeps accumulating encoded operations. If we instead use COMMIT_AND_WAIT as the SyncType, the command buffer is committed after each encode, and the memory buffers required for the outputs are released once the computation returns. This presumably lets the garbage collector see that the object is no longer needed, since we don't assign the outputs of the computation to anything in the Python script.
  • To validate this, I adjusted the original script slightly to allow the adaptive commit to run its course and the computations to return, so that the underlying objects can be released and we can see the memory get freed up. On my local setup it takes about 5-6 s of waiting for the objects to get cleaned up. In real code I assume this would happen more often, since other operations in between can require the command buffer to be synchronized in order to perform updates based on the results.

So in summary, it seems to me that the memory is managed as intended in this case. In its current form COMMIT_ADAPTIVE is a bit opaque to the user, since it can make the memory appear to behave erratically: the commit is automated to happen only once the local device hits the low-memory watermark, which depends on the device you are running on. Additionally, there's the interplay of when the underlying Python objects get garbage collected, which is what finally releases the memory buffers assigned to them. Let me know if this sounds reasonable or if you think there's still something we missed here, or if this is causing a concrete issue on your side that would warrant changes to how the adaptive commit currently works.

Here's also the adjusted script that I used to check that eventually the underlying memory is freed as the objects are garbage collected:

import torch
import torch.nn as nn
from time import sleep

N, C, H, W = 64, 32, 256, 256
iters = 5

model = nn.Linear(H, W, device='mps')
input = torch.rand(N, C, H, W, device='mps')

# Warm up
model(input)
torch.mps.empty_cache()

# Begin test
driver_before = torch.mps.driver_allocated_memory()
model(input)

for _ in range(iters):
    model(input)
torch.mps.empty_cache()

sleep(6)  # give the deferred (adaptive) commit time to run so the output buffers can be released
driver_after = torch.mps.driver_allocated_memory()

predicted_leak = iters * 4 * N * C * H * W
print(f"{driver_before = }")
print(f"{driver_after = }")
print(f"Detected {driver_after - driver_before} bytes (predicted: {predicted_leak}) leak of GPU memory")

leak_per_iter = (driver_after - driver_before) / iters
print(f"{leak_per_iter} per iteration")
size_of_array = (4 * N * C * H * W)
print(f"Size of output array: {size_of_array}")
# driver_before = 1101463552
# driver_after = 545701888
# Detected -555761664 bytes (predicted: 2684354560) leak of GPU memory
# -111152332.8 per iteration
# Size of output array: 536870912
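
As extra instrumentation (added here for illustration, not part of the original comment), torch.mps.current_allocated_memory() reports only the bytes occupied by live tensors, so printing it alongside driver_allocated_memory() is one way to distinguish memory held by pending command buffers and the caching allocator from genuinely leaked tensors:

import torch
import torch.nn as nn

def report(tag):
    # current_allocated_memory: bytes occupied by live tensors
    # driver_allocated_memory: total bytes the Metal driver has allocated for the process
    print(f"{tag}: current={torch.mps.current_allocated_memory()} "
          f"driver={torch.mps.driver_allocated_memory()}")

N, C, H, W = 64, 32, 256, 256
model = nn.Linear(H, W, device='mps')
input = torch.rand(N, C, H, W, device='mps')
model(input)  # warm up
torch.mps.empty_cache()

report("before")
for _ in range(5):
    model(input)  # outputs are not kept, so no extra live tensors should remain
torch.mps.empty_cache()
report("after (no sync)")
torch.mps.synchronize()  # force the pending command buffer to commit and complete
torch.mps.empty_cache()
report("after synchronize")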

Footnotes

  [1] https://github.com/pytorch/pytorch/blob/f5e704a6f25939478f770f8980c344ab461f0113/aten/src/ATen/native/mps/operations/Linear.mm#L50

  [2] https://github.com/pytorch/pytorch/blob/f5e704a6f25939478f770f8980c344ab461f0113/aten/src/ATen/mps/MPSStream.mm#L71

@jhavukainen
Collaborator

I'll close this for now, since based on what I'm seeing this is not a memory leak in nn.Linear. @hvaara, please don't hesitate to reopen if it looks like this is not the case from your point of view, or to open a feature request if there is a need for additional controls limiting how much memory the COMMIT_ADAPTIVE approach can use in your application.

@kulinseth
Collaborator

@hvaara, can you please comment on whether this addresses the issue?

@hvaara
Contributor Author

hvaara commented Aug 15, 2024

Thanks a lot for investigating everybody, and for the detailed notes from your deep dive @jhavukainen! Highly appreciate it!

There are a couple of things that are still not clear to me. I'll prepare a notebook with some examples to better illustrate what I mean.

@jhavukainen
Collaborator

No problem! I'm happy to take a look once you have the notebook with examples ready.
