fix unfused padding causal sdpa #3063

Open
hungryGeek16 wants to merge 2 commits into NVIDIA:main from hungryGeek16:fix-unfused-padding-causal-sdpa
Conversation

@hungryGeek16 hungryGeek16 commented May 31, 2026

Adds a targeted PyTorch SDPA fallback for unfused THD padding_causal self-attention so TransformerEngine does not materialize the full quadratic padding/causal mask. Includes a regression test that fails if get_full_mask is called on this path.
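To give a sense of the saving, the following back-of-the-envelope sketch compares the size of a fully expanded boolean mask of shape [b, 1, sq, sk] with a per-sequence length vector. It assumes one byte per mask element (as for a torch.bool tensor) and 8-byte integer lengths; the function names are illustrative and not part of TransformerEngine.

```python
def full_mask_bytes(batch_size: int, seqlen_q: int, seqlen_kv: int) -> int:
    """Bytes for a boolean mask of shape [b, 1, sq, sk], assuming 1 byte per element."""
    return batch_size * 1 * seqlen_q * seqlen_kv


def seqlens_bytes(batch_size: int, bytes_per_length: int = 8) -> int:
    """Bytes for one integer length per sequence (assumed 8-byte integers)."""
    return batch_size * bytes_per_length


if __name__ == "__main__":
    b, s = 8, 8192
    print(full_mask_bytes(b, s, s))  # grows quadratically with the sequence length
    print(seqlens_bytes(b))          # independent of the sequence length
```

The mask cost scales with sq * sk, while the length vector scales only with the batch size, which is the motivation for passing lengths to a per-sequence causal kernel instead.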

@github-actions github-actions Bot added the community-contribution label (PRs from external contributors outside the core maintainers, representing community-driven work) May 31, 2026
@greptile-apps
Contributor

greptile-apps Bot commented May 31, 2026

Greptile Summary

This PR introduces a targeted PyTorch SDPA fast path for UnfusedDotProductAttention when the layout is THD with padding_causal self-attention, avoiding materialisation of the full quadratic mask via get_full_mask. It also relocates the scale / get_full_mask call ordering and adds a regression test that monkeypatches get_full_mask to confirm the new path is taken.

  • _use_varlen_sdpa + _format_context + _forward_varlen_sdpa are introduced as three new helpers that route qualifying calls through a per-batch F.scaled_dot_product_attention loop, then reshape the output back to the caller's expected layout.
  • The regression test verifies numerics, backward-pass gradient existence, and that get_full_mask is never called on this path.
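The "fails if get_full_mask is called" check described above can be sketched with the standard library alone. This is a toy illustration of the tripwire pattern, not the PR's actual test: `backend`, `attention`, and `run_guarded` are hypothetical stand-ins, and the real test patches TransformerEngine's own `get_full_mask`.

```python
from types import SimpleNamespace
from unittest import mock

# Hypothetical stand-in for the module that owns get_full_mask.
backend = SimpleNamespace()


def get_full_mask(seq_len):
    """Stand-in for the quadratic mask builder (True = masked out)."""
    return [[k > q for k in range(seq_len)] for q in range(seq_len)]


backend.get_full_mask = get_full_mask


def attention(seq_len, use_fast_path):
    """Toy dispatcher: the fast path must never touch get_full_mask."""
    if use_fast_path:
        return "fast"
    backend.get_full_mask(seq_len)
    return "legacy"


def run_guarded(use_fast_path):
    """Call attention() with get_full_mask replaced by a tripwire that raises."""
    tripwire = mock.Mock(side_effect=AssertionError("get_full_mask was called"))
    with mock.patch.object(backend, "get_full_mask", tripwire):
        return attention(4, use_fast_path)
```

Calling `run_guarded(True)` returns normally, while `run_guarded(False)` raises `AssertionError`, so a regression that routes the fast-path case back through the mask builder would fail the test immediately.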

Confidence Score: 3/5

The new fast path can silently produce wrong attention output when triggered during inference with a KV cache, because it slices keys to query length instead of the full cache length.

The fast path in _forward_varlen_sdpa always uses seqlens_q to slice both query and key tensors. For inference formats like sbhd_2bshd or thd_2bshd — where Q is a short decode step and K/V hold the full context cache — seqlens_q will be 1 for each batch item, and the key is truncated to a single token. The model attends to nothing but the first cached token for every generated step, with no error or warning. FlashAttention.forward already asserts this invariant for the same reason.
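The truncation can be illustrated without tensors. In this minimal sketch, list indices stand in for positions in the KV cache, and `visible_keys` is a hypothetical helper that only models the slicing step, not the attention computation itself.

```python
from typing import List


def visible_keys(seqlen_q: int, seqlen_kv: int, slice_keys_to_query_len: bool) -> List[int]:
    """Return the cache positions that remain available to the query after slicing."""
    cache_positions = list(range(seqlen_kv))
    if slice_keys_to_query_len:
        # Mirrors key_layer[:seqlen_q, batch_id] in the quoted fast path.
        return cache_positions[:seqlen_q]
    # Expected behaviour for a decode step: the whole cache stays visible.
    return cache_positions


if __name__ == "__main__":
    print(visible_keys(1, 6, slice_keys_to_query_len=True))
    print(visible_keys(1, 6, slice_keys_to_query_len=False))
```

With a decode step of length 1 and a cache of length 6, slicing by the query length leaves only position 0, whereas the full cache should remain visible.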

transformer_engine/pytorch/attention/dot_product_attention/backends.py — specifically the _use_varlen_sdpa guard conditions and _forward_varlen_sdpa key/value slicing logic.

Important Files Changed

Filename: transformer_engine/pytorch/attention/dot_product_attention/backends.py
Overview: Adds _use_varlen_sdpa / _forward_varlen_sdpa fast path for padding_causal self-attention; missing max_seqlen_q == max_seqlen_kv guard allows the path to fire incorrectly for inference KV-cache scenarios, and self.softmax_scale is passed instead of the locally-computed scale variable (already flagged in previous threads).

Filename: tests/pytorch/attention/test_attention.py
Overview: Adds regression test for the unfused THD padding_causal fast path; correctly verifies numerics and that get_full_mask is not called, but does not exercise QK layer-scaling or inference KV-cache scenarios where the new path can misbehave.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[UnfusedDotProductAttention.forward] --> B{qkv_format?}
    B -- thd --> C[ConvertTHDtoBSHD + transpose to SBHD]
    B -- bshd --> D[transpose to SBHD]
    B -- sbhd_2bshd --> E[transpose K/V only]
    B -- thd_2bshd --> F[convert_thd_to_bshd + transpose]
    C --> G[get_padding_mask if attn_mask is None]
    D --> G
    E --> G
    F --> G
    G --> H{_use_varlen_sdpa?}
    H -- True --> I[_forward_varlen_sdpa]
    I --> I1[extract seqlens from attention_mask]
    I1 --> I2[per-batch F.scaled_dot_product_attention is_causal=True]
    I2 --> I3[_format_context to original layout]
    H -- False --> J[get_full_mask]
    J --> K[Legacy unfused BMM path]
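The "extract seqlens from attention_mask" and "is_causal=True" steps of the fast-path branch can be sketched in plain Python. Nested lists stand in for tensors, True marks a padded position (matching the logical_not() in the PR code), and both helper names are illustrative rather than taken from the PR.

```python
from typing import List

# mask[b][0][0][s] is True where position s of sequence b is padding,
# mirroring the [batch, 1, 1, seqlen] layout used by the padding mask.
PaddingMask = List[List[List[List[bool]]]]


def seqlens_from_padding_mask(mask: PaddingMask) -> List[int]:
    """Count the non-padded positions of each sequence."""
    return [sum(1 for is_pad in per_batch[0][0] if not is_pad) for per_batch in mask]


def causal_visible_keys(seqlen: int) -> List[List[int]]:
    """For each query position, list the key positions a causal mask allows."""
    return [list(range(q + 1)) for q in range(seqlen)]
```

Each sequence is then processed on its own unpadded prefix, so only a length per sequence is needed rather than a [b, 1, sq, sk] mask.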
Reviews (3): Last reviewed commit: "Merge branch 'main' into fix-unfused-pad..."

@cyanguwa
Collaborator

cyanguwa commented Jun 1, 2026

Thanks for the contribution @hungryGeek16, but it looks like your base branch may be out of date - could you rebase please? Thanks!

@jberchtold-nvidia jberchtold-nvidia removed their request for review June 1, 2026 15:56
@hungryGeek16 hungryGeek16 force-pushed the fix-unfused-padding-causal-sdpa branch from 51d298b to 82f0e0e June 8, 2026 18:23
@hungryGeek16
Author

@cyanguwa, I have rebased and resolved the conflicts. Let me know if this works. Thanks!

Comment on lines +418 to +463
def _forward_varlen_sdpa(
    self,
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    q_format: str,
    batch_size: int,
    max_seqlen_q: int,
    cu_seqlens_q: Optional[torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    scale: float,
) -> torch.Tensor:
    """Run causal self-attention without expanding padding masks to [b, 1, sq, sk]."""
    context_layer = torch.zeros(
        batch_size,
        query_layer.size(2),
        max_seqlen_q,
        value_layer.size(3),
        dtype=query_layer.dtype,
        device=query_layer.device,
    )

    if attention_mask is not None:
        seqlens_q = attention_mask.logical_not()[:, 0, 0, :].sum(dim=1)
    else:
        seqlens_q = torch.full(
            (batch_size,), max_seqlen_q, dtype=torch.int64, device=query_layer.device
        )

    dropout_p = self.attention_dropout.p if self.training else 0.0
    with self.attention_dropout_ctx():
        for batch_id in range(batch_size):
            seqlen_q = int(seqlens_q[batch_id].item())
            if seqlen_q == 0:
                continue
            query = query_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
            key = key_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
            value = value_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
            context_layer[batch_id, :, :seqlen_q, :] = F.scaled_dot_product_attention(
                query,
                key,
                value,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
            ).squeeze(0)
Contributor

P1 Fast path fires for inference KV-cache with mismatched Q/K seqlens

_use_varlen_sdpa has no guard on max_seqlen_q == max_seqlen_kv. For inference with a KV cache (e.g. qkv_format = "sbhd_2bshd" or "thd_2bshd"), max_seqlen_q is the current decode length (often 1) while max_seqlen_kv is the full cache length. When padding_causal is set for batched inference, get_padding_mask creates a [batch, 1, 1, max_seqlen_q] mask, seqlens_q will be all-1s, and key_layer[:1, batch_id] selects only the first cache token instead of the full KV context. Every generated token then attends to nothing but the first token — silently wrong, no error raised.

The fix is to add max_seqlen_q == max_seqlen_kv as a guard in _use_varlen_sdpa (forwarding those values from forward), or to reject the fast path inside forward before calling _use_varlen_sdpa when the two dimensions differ. FlashAttention.forward already asserts this invariant (line 1054) for the same reason.
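A minimal sketch of the suggested guard follows. The real signature of _use_varlen_sdpa is not shown in this thread, so `can_use_varlen_sdpa` and its parameters are hypothetical; the sketch only encodes the two conditions named in the review (the padding_causal mask type and equal query/key maximum lengths).

```python
def can_use_varlen_sdpa(attn_mask_type: str, max_seqlen_q: int, max_seqlen_kv: int) -> bool:
    """Hypothetical eligibility check for the per-sequence SDPA fast path."""
    if attn_mask_type != "padding_causal":
        return False
    # Reject KV-cache decoding, where the cache is longer than the query.
    return max_seqlen_q == max_seqlen_kv


if __name__ == "__main__":
    print(can_use_varlen_sdpa("padding_causal", 128, 128))  # training-style self-attention
    print(can_use_varlen_sdpa("padding_causal", 1, 128))    # decode step against a cache
```

When the check returns False, the caller would fall back to the existing unfused path, so the decode case keeps its current behaviour instead of silently truncating keys.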
