add OpInfo for torch.nn.functional.dropout (#62315)
Summary:
Addresses pytorch/functorch#78.

Pull Request resolved: #62315

Reviewed By: mruberry

Differential Revision: D30932765

Pulled By: zou3519

fbshipit-source-id: 481c67b59a966b4d640973d252b3e392d8db728e
pmeier authored and facebook-github-bot committed Sep 15, 2021
1 parent d6d286f commit 32c5da8
Showing 2 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/test_fx_experimental.py
@@ -1468,6 +1468,7 @@ def test_normalize_operator_exhaustive(self, device, dtype, op):
"igamma",
"igammac",
"index_put",
"nn.functional.dropout",
"polygamma",
"special.polygamma",
"repeat",
42 changes: 42 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -5331,6 +5341,16 @@ def _tensor(shape, dtype=dtype, low=None, high=None):

    return [SampleInput(tensor, args=args) for tensor, args in test_cases]

def sample_inputs_dropout(op_info, device, dtype, requires_grad, **kwargs):
    input = make_tensor((S,), device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        SampleInput(input),
        SampleInput(input, kwargs=dict(p=0.0)),
        SampleInput(input, kwargs=dict(p=1.0)),
        SampleInput(input, kwargs=dict(training=False)),
    ]

def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape, *, low, high):
        return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad)
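
For context, here is roughly how the OpInfo machinery consumes the samples returned by sample_inputs_dropout above: each SampleInput is unpacked into op(sample.input, *sample.args, **sample.kwargs). This is a minimal, hedged sketch that is not part of the commit; it assumes the helper is importable from the internal common_methods_invocations module and that SampleInput exposes input, args, and kwargs attributes.

import torch
from torch.testing._internal.common_methods_invocations import sample_inputs_dropout

# op_info is unused by this particular sample-input function, so None is fine here.
samples = sample_inputs_dropout(None, device="cpu", dtype=torch.float32, requires_grad=False)
for sample in samples:
    out = torch.nn.functional.dropout(sample.input, *sample.args, **sample.kwargs)
    assert out.shape == sample.input.shape  # dropout never changes the input shape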
@@ -5735,6 +5745,14 @@ def reference_mse_loss(input, target, reduction="mean"):
    return se


def wrapper_set_seed(op, input, *args, **kwargs):
    """Wrapper to set seed manually for some functions like dropout
    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
    """
    torch.manual_seed(42)
    return op(input, *args, **kwargs)


def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
    feature_size = np.prod(normalized_shape)
    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
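
The wrapper_set_seed helper above works because re-seeding the RNG immediately before each call makes the sampled dropout mask reproducible. A standalone sketch of that behaviour, not part of the diff:

import torch
import torch.nn.functional as F

x = torch.ones(16)
torch.manual_seed(42)
a = F.dropout(x, p=0.5)
torch.manual_seed(42)
b = F.dropout(x, p=0.5)
# Same seed -> same mask -> identical outputs, so comparisons across calls are stable.
assert torch.equal(a, b)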
@@ -9317,6 +9335,30 @@ def wrapper(x: np.ndarray, *args, **kwargs):
        decorators=(toleranceOverride({torch.float32: tol(atol=0, rtol=4e-6), }),),
        dtypes=all_types_and(torch.bool),
        safe_casts_outputs=True),
    OpInfo(
        "nn.functional.dropout",
        op=lambda input, *args, **kwargs:
            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs),
        ref=_NOTHING,
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
        skips=(
            # Probably because we have used lambda for the op here
            # AssertionError: JIT Test does not execute any logic
            SkipInfo('TestJit', 'test_variant_consistency_jit'),
            # inplace variant dispatches to dropout kernel, while on CUDA
            # the op dispatches to _fused_dropout (with a few more conditions)
            # hence, different values and this skip here
            SkipInfo('TestMathBits', 'test_neg_view', device_type='cuda'),
            # On CUDA, the op is dispatched (and a few more conditions) to
            # _fused_dropout, which doesn't support forward AD
            SkipInfo('TestGradients', 'test_forward_mode_AD', device_type='cuda'),),
        gradcheck_wrapper=wrapper_set_seed,
        supports_forward_ad=True,
        supports_out=False,
        sample_inputs_func=sample_inputs_dropout,
        inplace_variant=lambda input, *args, **kwargs:
            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)),
    OpInfo(
        "nn.functional.one_hot",
        ref=reference_one_hot,
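The gradcheck_wrapper=wrapper_set_seed argument in the new OpInfo matters because gradcheck re-evaluates the op many times at slightly perturbed inputs, and the dropout mask has to be identical on every evaluation for the numerical and analytical gradients to agree. A hedged, self-contained sketch of that idea follows; it is not part of the commit, and deterministic_dropout is a made-up helper mirroring what wrapper_set_seed does:

import torch

def deterministic_dropout(x):
    # Re-seed before every call so the same mask is drawn each time.
    torch.manual_seed(42)
    return torch.nn.functional.dropout(x, p=0.5)

x = torch.randn(8, dtype=torch.double, requires_grad=True)
# Passes because the mask is fixed, making dropout a simple piecewise-linear map.
torch.autograd.gradcheck(deterministic_dropout, (x,))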
