Loosen last dim contiguity for sdpa constraint to include last dim 0,1 (pytorch#139787)

Previously we required the last dim to have stride == 1. When the last dim's size is <= 1, that is also sufficient, because with at most one element along that dimension the stride is insignificant. Fix for pytorch#138317
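
A minimal standalone sketch (not part of the commit) of the situation the message describes: when the last dimension has size <= 1, any stride value describes the same memory layout, so requiring stride == 1 there is stricter than necessary.

```python
import torch

# Last dim has size 1 but stride 64 after the transpose; the stride carries no
# layout information because there is only one element along that dimension.
x = torch.randn(2, 4, 1, 64).transpose(-1, -2)
print(x.shape)            # torch.Size([2, 4, 64, 1])
print(x.stride())         # (256, 64, 1, 64) -- last-dim stride is not 1
print(x.is_contiguous())  # True: size-1 dims are ignored when checking contiguity
```

The old last-dim test (stride == 1) would flag this tensor's last dimension as non-dense even though no element addressing depends on that stride; this is the case the loosened check below accepts.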

Pull Request resolved: pytorch#139787
Approved by: https://github.com/drisspg
eellison authored and pobin6 committed Dec 5, 2024
1 parent 4691d64 commit b7874da
Showing 2 changed files with 72 additions and 3 deletions.
66 changes: 66 additions & 0 deletions test/inductor/test_cuda_repro.py
@@ -1242,6 +1242,72 @@ def outer_reduce(x):
         self.assertEqual(outer_reduce(a), out)
         self.assertTrue("for roffset" not in code)
 
+    def test_scaled_dot_product_efficient_attention_backward(self):
+        from torch import nn, Tensor
+
+        class SelfAttention(nn.Module):
+            def __init__(
+                self,
+                num_attention_heads: int = 12,
+                hidden_size: int = 768,
+                attention_probs_dropout_prob: float = 0.1,
+            ):
+                super().__init__()
+
+                self.num_attention_heads = num_attention_heads
+                self.attention_head_size = hidden_size // num_attention_heads
+
+                self.query = nn.Linear(hidden_size, hidden_size)
+                self.key = nn.Linear(hidden_size, hidden_size)
+                self.value = nn.Linear(hidden_size, hidden_size)
+
+                self.dropout_prob = attention_probs_dropout_prob
+
+            def transpose_for_scores(self, x: Tensor) -> Tensor:
+                new_x_shape = x.size()[:-1] + (
+                    self.num_attention_heads,
+                    self.attention_head_size,
+                )
+                return x.view(new_x_shape).permute(0, 2, 1, 3)
+
+            def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+                query_layer = self.transpose_for_scores(self.query(hidden_states))
+                key_layer = self.transpose_for_scores(self.key(hidden_states))
+                value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    attn_mask=attention_mask,
+                    dropout_p=self.dropout_prob if self.training else 0.0,
+                    is_causal=False,
+                )
+                return attn_output
+
+        device = torch.device("cuda")
+        num_attention_heads = 8
+        hidden_size = 512
+        attention_probs_dropout_prob = 0.0
+        model = SelfAttention(
+            num_attention_heads=num_attention_heads,
+            hidden_size=hidden_size,
+            attention_probs_dropout_prob=attention_probs_dropout_prob,
+        ).to(device)
+
+        model = torch.compile(model)
+
+        # runs without failure
+        batch_size = 8
+        length = 1
+        inputs_embeds = torch.randn(batch_size, length, hidden_size, device=device)
+        attention_mask = torch.ones(batch_size, 1, length, length, device=device)
+        attn_output = model(hidden_states=inputs_embeds, attention_mask=attention_mask)[
+            0
+        ]
+        loss = attn_output.mean()
+        loss.backward()
+
     def test_non_contiguous_unaligned_input_indices(self):
         from torch._inductor.compile_fx import remove_unaligned_input_idxs
 
9 changes: 6 additions & 3 deletions torch/_inductor/lowering.py
@@ -2349,9 +2349,12 @@ def is_aligned_realized_tensor(x):
             (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
             for i in range(len(x.get_stride()) - 1)
         )
-        return (
-            V.graph.sizevars.size_hint(x.get_stride()[-1])
-        ) == 1 and aligned_strides
+        # if the last dim size is <= 1, stride doesnt matter
+        aligned_last_dim = (
+            V.graph.sizevars.size_hint(x.get_stride()[-1]) == 1
+            or V.graph.sizevars.size_hint(x.get_size()[-1]) <= 1
+        )
+        return aligned_last_dim and aligned_strides
 
     try:
         arg.get_stride()
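
For illustration, the relaxed predicate can be restated over plain size/stride tuples. The sketch below is not Inductor code: `is_aligned_realized` and the `ALIGNMENT` value are stand-ins assumed for the example; it only mirrors the logic of the hunk above.

```python
ALIGNMENT = 16  # assumed alignment for the sketch


def is_aligned_realized(sizes: tuple, strides: tuple) -> bool:
    # Leading strides must be multiples of ALIGNMENT, unchanged from before.
    aligned_strides = all(stride % ALIGNMENT == 0 for stride in strides[:-1])
    # New: the last dim passes if its stride is 1 OR its size is <= 1, since a
    # size-<=1 dimension's stride never affects which elements are addressed.
    aligned_last_dim = strides[-1] == 1 or sizes[-1] <= 1
    return aligned_last_dim and aligned_strides


# Last dim of size 1 with stride 64: rejected by the old stride == 1 test, accepted now.
print(is_aligned_realized(sizes=(8, 64, 1), strides=(64, 16, 64)))    # True
# Last dim of size 64 with stride 2: still rejected.
print(is_aligned_realized(sizes=(8, 16, 64), strides=(1024, 64, 2)))  # False
```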
