
fix(topk): fix UB and prevent vector load splitting in standalone_topk#3088

Open: solos wants to merge 6 commits into NVIDIA:main from solos:fix-topk


solos commented Jun 5, 2026

Description

This PR refactors the vectorized processing logic in standalone_topk to eliminate C++ undefined behavior and ensure optimal GPU instruction emission. The changes address a critical issue where conditional vector loads were causing the compiler to split 128-bit memory transactions into scalar operations, degrading performance. Additionally, it fixes potential edge cases related to unsigned integer underflow and out-of-bounds memory access.

Fixes #

Fix C++ UB by replacing union type-punning with __builtin_memcpy.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Fix C++ UB by replacing union type-punning with __builtin_memcpy.
  • Prevent NVCC from splitting 128-bit vector loads by using unconditional pointer selection for safe companion reads.
  • Fix len_cast_for_sync underflow when len_cast == 0.
  • Guard against OOB reads when input length is smaller than sizeof(WideT).
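The first item in the list above can be illustrated with a small host-side sketch. It is not the kernel code: `Wide128`, `unpack`, and `load_and_unpack` are hypothetical names, `T` is assumed to be `float`, and `std::memcpy` stands in for the `__builtin_memcpy` builtin the PR uses. The point is only that a fixed-size byte copy reinterprets a wide word without reading an inactive union member.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-ins for the kernel's types: Wide128 plays the role of
// WideT (one 128-bit word) and T is the element type packed inside it.
struct alignas(16) Wide128 { std::uint32_t w[4]; };
using T = float;
constexpr int kItemsPerWide = sizeof(Wide128) / sizeof(T);

// Copy the bytes of one wide word into a local array of T.
// A fixed-size memcpy is well-defined for trivially copyable types, and
// compilers typically lower it to plain register moves.
inline void unpack(const Wide128& wide, T (&out)[kItemsPerWide]) {
  static_assert(sizeof(out) == sizeof(wide), "sizes must match");
  std::memcpy(out, &wide, sizeof(wide));
}

// Load the idx-th wide word from a 16-byte-aligned buffer of T and unpack it.
inline void load_and_unpack(const T* aligned_in, int idx, T (&out)[kItemsPerWide]) {
  assert(reinterpret_cast<std::uintptr_t>(aligned_in) % alignof(Wide128) == 0);
  Wide128 wide;
  std::memcpy(&wide, aligned_in + idx * kItemsPerWide, sizeof(wide));
  unpack(wide, out);
}
```

Calling `load_and_unpack(buf, 1, out)` on an aligned buffer of eight floats fills `out` with elements 4 through 7. Whether a given compiler emits a single 128-bit load for the equivalent device code has to be checked in the generated SASS; this sketch does not show that.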

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

github-actions Bot added the community-contribution label (PRs from external contributors outside the core maintainers, representing community-driven work) on Jun 5, 2026

greptile-apps Bot (Contributor) commented Jun 5, 2026

Greptile Summary

This PR refactors the vectorized processing logic in standalone_topk.cuh to eliminate C++ undefined behavior and ensure the GPU emits optimal 128-bit (LDG.E.128) vector load instructions. All four stated fixes are sound and correctly implemented.

  • Union type-punning removed: Both vectorized_process overloads replace the union { WideT scalar; T array[]; } idiom with __builtin_memcpy, which is the well-defined C++ way to reinterpret memory; NVCC optimizes this to zero-overhead register moves at compile time.
  • Conditional load eliminated: The sync-version loop switches from if (valid) { wide.scalar = in_cast[i]; } to unconditional in_cast[safe_i] with integer-clamped index (valid ? i : 0), removing the predicated load that caused NVCC to split the 128-bit transaction into scalar ops; this also supersedes the pointer-ternary approach flagged in a prior review thread.
  • Underflow + OOB guard: The entire vectorized block in the sync version is wrapped in if (len_cast > 0), preventing unsigned underflow in len_cast - 1 and ensuring in_cast[0] (the companion read target) is within the allocated buffer.
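The index-clamping pattern described in the bullets above can be modeled on the host as follows. This is a sketch under assumptions: `clamped_loop` and `sum_valid` are hypothetical helpers, a plain sequential loop stands in for the thread-strided device loop, and the names `valid`, `safe_i`, `len_cast`, and `len_cast_for_sync` are borrowed from the summary.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side model of the sync loop: every index in
// [0, len_cast_for_sync) performs a load, but only indices below len_cast
// are valid. Padding indices re-read element 0 instead of skipping the load.
template <typename WideT, typename F>
void clamped_loop(const std::vector<WideT>& in_cast, int len_cast_for_sync, F f) {
  const int len_cast = static_cast<int>(in_cast.size());
  if (len_cast <= 0) return;  // guard: in_cast[0] must exist for companion reads
  assert(len_cast_for_sync >= len_cast);
  for (int i = 0; i < len_cast_for_sync; ++i) {
    const bool valid = i < len_cast;
    const int safe_i = valid ? i : 0;    // clamp the index, not the pointer
    const WideT wide = in_cast[safe_i];  // unconditional load, always in bounds
    f(wide, i, valid);                   // the callback decides what to do with padding
  }
}

// Sum only the valid elements and count how many loads were padding reads.
inline int sum_valid(const std::vector<int>& data, int padded_len, int* padding_loads) {
  int sum = 0;
  *padding_loads = 0;
  clamped_loop(data, padded_len, [&](int v, int, bool valid) {
    if (valid) sum += v; else ++*padding_loads;
  });
  return sum;
}
```

With three elements padded to eight, `sum_valid` returns the sum of the three real values and reports five padding reads; with an empty input the guard returns before any load is issued.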

Confidence Score: 5/5

Safe to merge; all four fixes are correct, the previous pointer-ternary concern from the prior thread is fully addressed by integer index clamping, and no new issues are introduced.

The union-to-memcpy refactor is the canonical C++ type-pun idiom and NVCC optimizes it to zero overhead. The unconditional load with safe_i eliminates the predicated-load regression using well-defined integer arithmetic. The len_cast > 0 guard prevents unsigned underflow and OOB companion reads. The non-sync overload is unaffected because its loop condition already gates correctly when len_cast == 0.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/common/util/standalone_topk.cuh Union type-punning replaced with __builtin_memcpy; conditional 128-bit load replaced with unconditional index-clamped load; len_cast underflow and OOB companion-read guarded with len_cast > 0 check. All changes are correct and well-commented.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["vectorized_process sync version"] --> B{"T smaller than WideT?"}
    B -- no --> C["scalar loop: f(in[i], i, true)"]
    B -- yes --> D["compute skip_cnt and len_cast"]
    D --> E{"len_cast > 0? (NEW guard)"}
    E -- no --> F["skip vector loop\navoids underflow and OOB"]
    E -- yes --> G["len_cast_for_sync = round_up to sync_width"]
    G --> H["loop: i = tid to len_cast_for_sync"]
    H --> I["valid = i < len_cast\nsafe_i = valid ? i : 0\nNEW integer clamp"]
    I --> J["wide_data = in_cast[safe_i]\nunconditional LDG.E.128"]
    J --> K["__builtin_memcpy into local_array\nNEW replaces union type-pun"]
    K --> L["f(local_array[j], real_i+j, valid)"]
    L --> H
    H -- done --> M["tail: scalar skip + remain\ntid < sync_width only"]
    F --> M

Reviews (5): Last reviewed commit: "Merge branch 'main' into fix-topk"

Comment thread transformer_engine/common/util/standalone_topk.cuh Outdated
Clamp the load index to 0 for invalid padding threads instead of
selecting between two pointers. This eliminates language-level UB from
OOB pointer formation without changing the emitted vectorized global load.
ptrendx (Member) left a comment

No. Using the union to achieve the vectorized accesses is a very widespread practice in CUDA and is endorsed by the NVCC compiler experts. As for the second issue: what is the case where you saw the issue (shape of the inputs, etc.)? If len_cast is less than 0, then len_cast_for_sync is also going to be less than 0, so the loop will just not run. I don't believe this PR actually changes the behavior of the kernel in any way.

solos closed this Jun 6, 2026
solos reopened this Jun 6, 2026

solos (Author) commented Jun 6, 2026

> No. Using the union to achieve the vectorized accesses is a very widespread practice in CUDA and is endorsed by the NVCC compiler experts. As for the second issue: what is the case where you saw the issue (shape of the inputs, etc.)? If len_cast is less than 0, then len_cast_for_sync is also going to be less than 0, so the loop will just not run. I don't believe this PR actually changes the behavior of the kernel in any way.

Regarding the claim that the loop won't run when len_cast is negative: that's not always true, because C++ integer division truncates toward zero. For example, if len_cast = -1 and sync_width = 32, then len_cast - 1 = -2, and -2 / 32 is 0 (not -1), so len_cast_for_sync becomes (0 + 1) * 32 = 32. The loop then runs even though len_cast is negative, with valid always false. While the invalid iterations are logically guarded by valid, the expression &in_cast[i] (or the pointer arithmetic in in_cast[i]) is still evaluated for out-of-bounds indices, which is undefined behavior. The added if (len_cast > 0) guard prevents this edge case, so the PR does change the kernel's behavior, for safety.
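The arithmetic in the reply above can be checked at compile time. The sketch below assumes the round-up expression has the form `((len_cast - 1) / sync_width + 1) * sync_width`, as implied by the worked example. The function names are hypothetical, and the thread does not state whether the kernel's index type is signed or unsigned, so both cases are shown with 32-bit types.

```cpp
#include <cstdint>

// Round-up expression from the worked example, with a signed index type.
constexpr std::int32_t round_up_signed(std::int32_t len_cast, std::int32_t sync_width) {
  return ((len_cast - 1) / sync_width + 1) * sync_width;
}

// The same expression with an unsigned index type (assumption: 32-bit).
constexpr std::uint32_t round_up_unsigned(std::uint32_t len_cast, std::uint32_t sync_width) {
  return ((len_cast - 1u) / sync_width + 1u) * sync_width;
}

// Signed: division truncates toward zero, so small non-positive lengths
// still round "up" to one full sync_width.
static_assert(round_up_signed(-1, 32) == 32, "-2 / 32 == 0, then (0 + 1) * 32");
static_assert(round_up_signed(0, 32) == 32, "-1 / 32 == 0, then (0 + 1) * 32");
static_assert(round_up_signed(33, 32) == 64, "ordinary round-up still works");

// Unsigned: 0 - 1 wraps to 2^32 - 1; for a power-of-two sync_width the final
// multiplication wraps again and the result comes back to 0.
static_assert(round_up_unsigned(0u, 32u) == 0u, "wraps twice for sync_width == 32");
```

So with a signed index type, a zero or slightly negative `len_cast` yields a non-empty padded loop, which is the case the `len_cast > 0` guard rules out. With an unsigned 32-bit type and a power-of-two `sync_width`, the intermediate value wraps but the padded length ends up at 0.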
