fix unfused padding causal sdpa#3063
Conversation
Greptile SummaryThis PR introduces a targeted PyTorch SDPA fast path for
Confidence Score: 3/5The new fast path can silently produce wrong attention output when triggered during inference with a KV cache, because it slices keys to query length instead of the full cache length. The fast path in _forward_varlen_sdpa always uses seqlens_q to slice both query and key tensors. For inference formats like sbhd_2bshd or thd_2bshd — where Q is a short decode step and K/V hold the full context cache — seqlens_q will be 1 for each batch item, and the key is truncated to a single token. The model attends to nothing but the first cached token for every generated step, with no error or warning. FlashAttention.forward already asserts this invariant for the same reason. transformer_engine/pytorch/attention/dot_product_attention/backends.py — specifically the _use_varlen_sdpa guard conditions and _forward_varlen_sdpa key/value slicing logic. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[UnfusedDotProductAttention.forward] --> B{qkv_format?}
B -- thd --> C[ConvertTHDtoBSHD + transpose to SBHD]
B -- bshd --> D[transpose to SBHD]
B -- sbhd_2bshd --> E[transpose K/V only]
B -- thd_2bshd --> F[convert_thd_to_bshd + transpose]
C --> G[get_padding_mask if attn_mask is None]
D --> G
E --> G
F --> G
G --> H{_use_varlen_sdpa?}
H -- True --> I[_forward_varlen_sdpa]
I --> I1[extract seqlens from attention_mask]
I1 --> I2[per-batch F.scaled_dot_product_attention is_causal=True]
I2 --> I3[_format_context to original layout]
H -- False --> J[get_full_mask]
J --> K[Legacy unfused BMM path]
Reviews (3): Last reviewed commit: "Merge branch 'main' into fix-unfused-pad..." | Re-trigger Greptile |
|
Thanks for the contribution @hungryGeek16, but it looks like your base branch may be out of date - could you rebase please? Thanks! |
51d298b to
82f0e0e
Compare
|
@cyanguwa , I have rebased and resolved conflicts, let me know if this works. Thanks! |
| def _forward_varlen_sdpa( | ||
| self, | ||
| query_layer: torch.Tensor, | ||
| key_layer: torch.Tensor, | ||
| value_layer: torch.Tensor, | ||
| q_format: str, | ||
| batch_size: int, | ||
| max_seqlen_q: int, | ||
| cu_seqlens_q: Optional[torch.Tensor], | ||
| attention_mask: Optional[torch.Tensor], | ||
| scale: float, | ||
| ) -> torch.Tensor: | ||
| """Run causal self-attention without expanding padding masks to [b, 1, sq, sk].""" | ||
| context_layer = torch.zeros( | ||
| batch_size, | ||
| query_layer.size(2), | ||
| max_seqlen_q, | ||
| value_layer.size(3), | ||
| dtype=query_layer.dtype, | ||
| device=query_layer.device, | ||
| ) | ||
|
|
||
| if attention_mask is not None: | ||
| seqlens_q = attention_mask.logical_not()[:, 0, 0, :].sum(dim=1) | ||
| else: | ||
| seqlens_q = torch.full( | ||
| (batch_size,), max_seqlen_q, dtype=torch.int64, device=query_layer.device | ||
| ) | ||
|
|
||
| dropout_p = self.attention_dropout.p if self.training else 0.0 | ||
| with self.attention_dropout_ctx(): | ||
| for batch_id in range(batch_size): | ||
| seqlen_q = int(seqlens_q[batch_id].item()) | ||
| if seqlen_q == 0: | ||
| continue | ||
| query = query_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0) | ||
| key = key_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0) | ||
| value = value_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0) | ||
| context_layer[batch_id, :, :seqlen_q, :] = F.scaled_dot_product_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| dropout_p=dropout_p, | ||
| is_causal=True, | ||
| scale=scale, | ||
| ).squeeze(0) |
There was a problem hiding this comment.
Fast path fires for inference KV-cache with mismatched Q/K seqlens
_use_varlen_sdpa has no guard on max_seqlen_q == max_seqlen_kv. For inference with a KV cache (e.g. qkv_format = "sbhd_2bshd" or "thd_2bshd"), max_seqlen_q is the current decode length (often 1) while max_seqlen_kv is the full cache length. When padding_causal is set for batched inference, get_padding_mask creates a [batch, 1, 1, max_seqlen_q] mask, seqlens_q will be all-1s, and key_layer[:1, batch_id] selects only the first cache token instead of the full KV context. Every generated token then attends to nothing but the first token — silently wrong, no error raised.
The fix is to add max_seqlen_q == max_seqlen_kv as a guard in _use_varlen_sdpa (forwarding those values from forward), or to reject the fast path inside forward before calling _use_varlen_sdpa when the two dimensions differ. FlashAttention.forward already asserts this invariant (line 1054) for the same reason.
Adds a targeted PyTorch SDPA fallback for unfused THD padding_causal self-attention so TransformerEngine does not materialize the full quadratic padding/causal mask. Includes a regression test that fails if get_full_mask is called on this path.