fix(topk): fix UB and prevent vector load splitting in standalone_topk #3088
solos wants to merge 6 commits into
- Fix C++ UB by replacing union type-punning with `__builtin_memcpy`.
- Prevent NVCC from splitting 128-bit vector loads by using unconditional pointer selection for safe companion reads.
- Fix `len_cast_for_sync` underflow when `len_cast == 0`.
- Guard against OOB reads when input length is smaller than `sizeof(WideT)`.
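The first item can be illustrated with a minimal host-side sketch. It is not the kernel code from this PR: `Wide128` and `unpack_wide` are hypothetical names standing in for `WideT` and the unpacking step, and portable `std::memcpy` is used where the PR uses `__builtin_memcpy`.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical stand-in for the kernel's 128-bit WideT.
struct alignas(16) Wide128 {
  std::uint32_t w[4];
};

// Reinterpret one wide value as N elements of T by copying bytes.
// Reading an inactive union member is UB in C++; a byte copy is well defined
// for trivially copyable types, and compilers typically fold a fixed-size
// memcpy like this into plain register moves.
template <typename T, std::size_t N>
void unpack_wide(const Wide128& wide, T (&out)[N]) {
  static_assert(sizeof(T) * N == sizeof(Wide128), "T must tile the wide type exactly");
  static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
  std::memcpy(out, &wide, sizeof(Wide128));
}
```

With a 16-bit element type, `unpack_wide(wide, out)` fills an 8-element local array from one 128-bit value, which is the shape of the `local_array` copy described in this PR.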
Greptile Summary
This PR refactors the vectorized processing logic in
Confidence Score: 5/5
Safe to merge; all four fixes are correct, the previous pointer-ternary concern from the prior thread is fully addressed by integer index clamping, and no new issues are introduced.
The union-to-memcpy refactor is the canonical C++ type-pun idiom and NVCC optimizes it to zero overhead. The unconditional load with safe_i eliminates the predicated-load regression using well-defined integer arithmetic. The len_cast > 0 guard prevents unsigned underflow and OOB companion reads. The non-sync overload is unaffected because its loop condition already gates correctly when len_cast == 0.
No files require special attention.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["vectorized_process sync version"] --> B{"T smaller than WideT?"}
B -- no --> C["scalar loop: f(in[i], i, true)"]
B -- yes --> D["compute skip_cnt and len_cast"]
D --> E{"len_cast > 0? (NEW guard)"}
E -- no --> F["skip vector loop\navoids underflow and OOB"]
E -- yes --> G["len_cast_for_sync = round_up to sync_width"]
G --> H["loop: i = tid to len_cast_for_sync"]
H --> I["valid = i < len_cast\nsafe_i = valid ? i : 0\nNEW integer clamp"]
I --> J["wide_data = in_cast[safe_i]\nunconditional LDG.E.128"]
J --> K["__builtin_memcpy into local_array\nNEW replaces union type-pun"]
K --> L["f(local_array[j], real_i+j, valid)"]
L --> H
H -- done --> M["tail: scalar skip + remain\ntid < sync_width only"]
F --> M
Reviews (5): Last reviewed commit: "Merge branch 'main' into fix-topk"
Clamp the load index to 0 for invalid padding threads instead of selecting between two pointers. This eliminates language-level UB from OOB pointer formation without changing the emitted vectorized global load.
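The clamping idea in this commit message can be sketched on the host as follows. This is an illustration under assumptions, not the kernel code: `LaneLoad` and `load_clamped` are hypothetical names, and a plain `int` array stands in for `in_cast`.

```cpp
#include <cassert>
#include <cstddef>

// Result of one lane's load: the value read and whether the lane is real.
struct LaneLoad {
  int value;
  bool valid;
};

// Every lane performs the same unconditional load. Padding lanes
// (i >= len) are redirected to index 0, so no out-of-bounds address is
// ever formed and the load itself needs no predicate.
inline LaneLoad load_clamped(const int* in, std::size_t len, std::size_t i) {
  assert(len > 0);  // index 0 must exist, which is what a len_cast > 0 guard ensures
  const bool valid = i < len;
  const std::size_t safe_i = valid ? i : 0;  // integer select, not a pointer select
  return {in[safe_i], valid};
}
```

The caller is expected to ignore `value` whenever `valid` is false, in the same way the flowchart passes `valid` through to `f`.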
ptrendx left a comment
No. Using a union to achieve vectorized accesses is a very widespread practice in CUDA and is endorsed by the NVCC compiler experts. As for the second issue: in what case did you see the problem (shape of the inputs, etc.)? If len_cast is less than 0, then len_cast_for_sync will also be less than 0, so the loop will simply not run. I don't believe this PR actually changes the behavior of the kernel in any way.
Regarding the claim that the loop won't run when len_cast is negative: that's not always true, because C++ integer division truncates toward zero. For example, if len_cast = -1 and sync_width = 32, then len_cast - 1 = -2, and -2 / 32 is 0 (not -1), so len_cast_for_sync becomes (0 + 1) * 32 = 32. The loop body then still runs even though len_cast is negative, with valid always false. While the invalid iterations are logically guarded by valid, the expression &in_cast[i] (or the pointer arithmetic in in_cast[i]) is still evaluated for out-of-bounds indices, which is undefined behavior. The added if (len_cast > 0) guard prevents this edge case, so the PR does change the kernel's behavior, for safety.
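The arithmetic in this reply can be checked with a small host-side sketch. The rounding expression is reconstructed from the worked example above, and both function names are hypothetical; signed `int` is assumed, as in the comment.

```cpp
// Round len_cast up to a multiple of sync_width, as in the worked example.
// Signed division truncates toward zero, so non-positive inputs do not
// produce a non-positive result.
inline int round_up_for_sync(int len_cast, int sync_width) {
  return ((len_cast - 1) / sync_width + 1) * sync_width;
}

// Same computation behind a len_cast > 0 guard: when there is nothing to
// vectorize, the loop bound is 0 and the vector loop is skipped.
inline int guarded_len_for_sync(int len_cast, int sync_width) {
  return len_cast > 0 ? round_up_for_sync(len_cast, sync_width) : 0;
}
```

For `len_cast = -1` and `sync_width = 32` the unguarded form yields 32, and the same holds for `len_cast = 0` because `-1 / 32` is also 0. The guarded form yields 0 in both cases.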
Description
This PR refactors the vectorized processing logic in standalone_topk to eliminate C++ undefined behavior and ensure optimal GPU instruction emission. The changes address a critical issue where conditional vector loads were causing the compiler to split 128-bit memory transactions into scalar operations, degrading performance. Additionally, it fixes potential edge cases related to unsigned integer underflow and out-of-bounds memory access.
Fixes #
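Type of change
Changes
Please list the changes introduced in this PR:
- Fix C++ UB by replacing union type-punning with `__builtin_memcpy`.
- Prevent NVCC from splitting 128-bit vector loads by using unconditional pointer selection for safe companion reads.
- Fix `len_cast_for_sync` underflow when `len_cast == 0`.
- Guard against OOB reads when input length is smaller than `sizeof(WideT)`.
Checklist: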