PYTORCH_ENABLE_MPS_FALLBACK does not appear to work for nn.Conv1d #134416

Closed
a2aaron opened this issue Aug 25, 2024 · 3 comments

Labels
module: mps (Related to Apple Metal Performance Shaders framework)
module: regression (It used to work, and now it doesn't)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

a2aaron commented Aug 25, 2024

🐛 Describe the bug

It looks like #129207 addressed an issue with the MPS implementation of nn.Conv1d: the implementation would silently return incorrect results when running a convolution with more than 65536 output channels. The fix for that issue was to temporarily have nn.Conv1d raise NotImplementedError in this situation and to suggest setting PYTORCH_ENABLE_MPS_FALLBACK=1 so that the operation could fall back to the CPU.

However, this fallback does not seem to work. Running the linked issue's minimal example (slightly trimmed down) raises an error regardless of whether PYTORCH_ENABLE_MPS_FALLBACK is set:

# in main.py
import torch
import torch.nn as nn
import os

print(os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK")) # Prints 1 if the variable is set, None otherwise

torch.manual_seed(0)
conv = nn.Conv1d(1, 65537, 3, padding=1)

x = torch.ones([1, 1, 3])
y_mps = conv.to("mps")(x.to("mps")) # Fails with NotImplementedError

Run with something like PYTORCH_ENABLE_MPS_FALLBACK=1 python3.11 main.py
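
To rule out a shell or quoting problem, it can help to confirm that the variable really reaches a child interpreter. The following is a minimal sketch using only the standard library; it does not import torch, and the names `child_code` and `seen_by_child` are made up for illustration.

```python
import os
import subprocess
import sys

# Copy the current environment and add the fallback variable,
# mimicking `PYTORCH_ENABLE_MPS_FALLBACK=1 python3.11 main.py`.
env = dict(os.environ, PYTORCH_ENABLE_MPS_FALLBACK="1")

# The child process only reports what it sees in its own environment.
child_code = 'import os; print(os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK"))'

result = subprocess.run(
    [sys.executable, "-c", child_code],
    env=env,
    capture_output=True,
    text=True,
    check=True,
)
seen_by_child = result.stdout.strip()
print(seen_by_child)
```

If the child reports the expected value, the variable is being passed correctly and the failure lies elsewhere.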

Attempting to run this code always fails with the following, even when PYTORCH_ENABLE_MPS_FALLBACK is set to 1:

Traceback (most recent call last):
  File "[path to main.py]", line 13, in <module>
    y_mps = conv.to("mps")(x.to("mps"))  # Fails with NotImplementedError
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 373, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 368, in _conv_forward
    return F.conv1d(
           ^^^^^^^^^
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
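
Since the environment variable has no effect on this error, one possible user-side workaround is to move the operation to the CPU by hand. The sketch below is an illustration only, not the fix adopted by the maintainers. The helper names (`needs_cpu_fallback`, `conv1d_on_cpu`, `conv1d_with_manual_fallback`) and the constant `MPS_MAX_OUT_CHANNELS` are hypothetical, the limit is taken from the error message above, and the code assumes the layer uses the default `padding_mode="zeros"`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Limit quoted in the NotImplementedError message (assumed, not read from torch).
MPS_MAX_OUT_CHANNELS = 65536


def needs_cpu_fallback(device_type: str, out_channels: int,
                       limit: int = MPS_MAX_OUT_CHANNELS) -> bool:
    """Return True when the input is on MPS and the layer exceeds the limit."""
    return device_type == "mps" and out_channels > limit


def conv1d_on_cpu(conv: nn.Conv1d, x: torch.Tensor) -> torch.Tensor:
    """Compute the convolution on the CPU and return it on x's device.

    Assumes padding_mode="zeros", so F.conv1d can be called directly.
    """
    bias = conv.bias.cpu() if conv.bias is not None else None
    y = F.conv1d(x.cpu(), conv.weight.cpu(), bias,
                 conv.stride, conv.padding, conv.dilation, conv.groups)
    return y.to(x.device)


def conv1d_with_manual_fallback(conv: nn.Conv1d, x: torch.Tensor) -> torch.Tensor:
    """Use the CPU path only when the MPS limit would be hit."""
    if needs_cpu_fallback(x.device.type, conv.out_channels):
        return conv1d_on_cpu(conv, x)
    return conv(x)
```

With the example above, `conv1d_with_manual_fallback(conv.to("mps"), x.to("mps"))` would take the CPU path. As the error message warns, this is slower than running natively on MPS because of the device transfers.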

Versions

Collecting environment information...
PyTorch version: 2.5.0.dev20240824
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.0
Libc version: N/A

Python version: 3.11.5 (main, Aug 24 2023, 15:09:32) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==2.1.0
[pip3] torch==2.5.0.dev20240824
[pip3] torchaudio==2.4.0.dev20240824
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.0.dev20240824
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Contributor

hvaara commented Aug 26, 2024

@pytorchbot label "module: mps"

pytorch-bot added the "module: mps" label on Aug 26, 2024
albanD added the "triaged" and "module: regression" labels on Aug 26, 2024
Contributor

hvaara commented Aug 26, 2024

I was unsure whether I should add the regression label, since it's technically not a regression per se. The previous implementation had a silent correctness issue. That issue was fixed by #129484, but the fix also introduced the behavior OP is now seeing.

I can take a look.

Collaborator

jhavukainen commented Aug 29, 2024

Yeah, my bad for not thinking this through. The message instructing users to use the fallback is not correct here, as the op is not failing due to a missing MPS implementation that we could fall back from. I'll open a PR to fix the message for now to avoid misleading people.
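
The distinction drawn in this comment can be illustrated with a toy model. The sketch below is not PyTorch's dispatcher, and the names `MPS_KERNELS`, `conv1d_mps`, `cpu_kernel`, and `dispatch` are hypothetical. It assumes, following the comment, that the fallback is consulted only when an op has no MPS implementation at all, so an error raised inside an existing MPS kernel propagates no matter how the environment variable is set.

```python
import os


def conv1d_mps(out_channels):
    # Stand-in for an MPS kernel that exists but rejects some inputs.
    if out_channels > 65536:
        raise NotImplementedError(
            "Output channels > 65536 not supported at the MPS device."
        )
    return f"conv1d on mps ({out_channels} channels)"


# Toy registry of ops that have an MPS implementation.
MPS_KERNELS = {"conv1d": conv1d_mps}


def cpu_kernel(op, *args):
    # Stand-in for the CPU implementation used as a fallback.
    return f"{op} on cpu"


def dispatch(op, *args, env=os.environ):
    kernel = MPS_KERNELS.get(op)
    if kernel is None:
        # The fallback variable matters only on this branch.
        if env.get("PYTORCH_ENABLE_MPS_FALLBACK") == "1":
            return cpu_kernel(op, *args)
        raise NotImplementedError(f"{op} is not implemented for MPS")
    # The kernel exists, so any error it raises reaches the caller.
    return kernel(*args)
```

In this model, `dispatch("conv1d", 65537, env={"PYTORCH_ENABLE_MPS_FALLBACK": "1"})` still raises, because the op is found in the registry and the fallback branch is never reached.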
