Inplace ReLU incompatible with backward hook #61519

Closed
@frgfm

Description

🐛 In-place ReLU issues with backward hooks

Hello there 👋

I encountered a bug with in-place ReLUs that I cannot work around. In one of my projects, I use backward hooks to get intermediate feature maps and gradients. This worked fine until a few versions ago, when densenet started emitting warnings caused by the in-place ReLU here: https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py#L217

I recently updated my torch/torchvision versions and the warning has now turned into an error. I was able to reproduce the bug on my end with a simple nn.Sequential that includes an in-place ReLU. I have no clue how to get around this; does anyone have an idea? The strange thing is that the gradient flows back properly; only the backward hook triggers the error.

I initially opened pytorch/vision#4164 thinking the issue was with densenet, but since I was able to reproduce this with a simple nn.Sequential, I think the root problem is actually in PyTorch!

Happy to help solve this, thanks in advance!

To Reproduce

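import torch
from torch import nn

# Build the model; the last module is an in-place ReLU
mod = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(inplace=True))

# Hook the module right before the in-place ReLU.
# Full backward hooks receive (module, grad_input, grad_output); both grads are tuples.
def hook_g(m, grad_input, grad_output):
    print(grad_input[0].shape)

mod[2].register_full_backward_hook(hook_g)

# Forward pass (the error is raised here, before any backward call)
out = mod(torch.rand((1, 3, 224, 224)))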

yields:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-d15e02b5edd3> in <module>
----> 1 out = mod(torch.rand((1, 3, 224, 224)))

~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    137     def forward(self, input):
    138         for module in self:
--> 139             input = module(input)
    140         return input
    141 

~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/activation.py in forward(self, input)
    100 
    101     def forward(self, input: Tensor) -> Tensor:
--> 102         return F.relu(input, inplace=self.inplace)
    103 
    104     def extra_repr(self) -> str:

~/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py in relu(input, inplace)
   1294         return handle_torch_function(relu, (input,), input, inplace=inplace)
   1295     if inplace:
-> 1296         result = torch.relu_(input)
   1297     else:
   1298         result = torch.relu(input)

RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can remove this warning by cloning the output of the custom Function.
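The last sentence of the error message suggests cloning the hooked module's output. A minimal sketch of that idea follows, assuming a hypothetical `Clone` helper module (not part of `torch.nn`) inserted between the hooked convolution and the in-place ReLU, and a smaller 32x32 input than in the report to keep it quick. It is one possible workaround, not a confirmed fix from the PyTorch maintainers.

```python
import torch
from torch import nn


class Clone(nn.Module):
    """Hypothetical helper: returns a copy of its input, so that a later
    in-place op modifies the copy rather than the hooked module's output."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone()


mod = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
    Clone(),  # the in-place ReLU below now operates on the clone
    nn.ReLU(inplace=True),
)

grad_shapes = []


def hook_g(module, grad_input, grad_output):
    # grad_output is a tuple with one entry per module output
    grad_shapes.append(tuple(grad_output[0].shape))


mod[2].register_full_backward_hook(hook_g)

out = mod(torch.rand(1, 3, 32, 32))
out.sum().backward()  # triggers the backward hook
print(grad_shapes)
```

The clone costs one extra activation-sized buffer, which gives back part of the memory that the in-place ReLU was meant to save.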

Expected behavior

This used to emit only a warning, but it now raises an error, and I'm unsure how to get around it.
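Another way around it, sketched here under assumptions, is to turn off the `inplace` flag on every `nn.ReLU` module before registering the hook. The helper name `disable_inplace_relu` is hypothetical, and this only covers ReLUs that exist as modules; in-place ops called functionally inside a model's `forward` would not be affected.

```python
import torch
from torch import nn


def disable_inplace_relu(model: nn.Module) -> nn.Module:
    """Hypothetical helper: switch every nn.ReLU submodule to out-of-place."""
    for m in model.modules():
        if isinstance(m, nn.ReLU):
            m.inplace = False
    return model


mod = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
    nn.ReLU(inplace=True),
)
disable_inplace_relu(mod)

captured = []


def hook_g(module, grad_input, grad_output):
    # Record the shape of the gradient w.r.t. the hooked module's output
    captured.append(tuple(grad_output[0].shape))


mod[2].register_full_backward_hook(hook_g)

out = mod(torch.rand(1, 3, 32, 32))
out.sum().backward()  # the hook fires during this call
print(captured)
```

The trade-off is the same as with cloning: each ReLU now allocates a new output tensor instead of overwriting its input.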

Environment

  • PyTorch / torchvision Version: 1.9.0 / 0.10.0
  • OS: Ubuntu 20.04.2 LTS (x86_64)
  • How you installed PyTorch / torchvision: conda
  • Python version: 3.8 (64-bit runtime)
  • CUDA/cuDNN version: 11.2.152 (cuDNN 8.2.0)

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @mruberry @jbschlosser

Labels: module: autograd, module: nn, triaged