From e1cb084d1ebc206d43addd8f67e359f94b7955af Mon Sep 17 00:00:00 2001 From: solos Date: Fri, 5 Jun 2026 14:11:24 +0800 Subject: [PATCH 1/3] fix(topk): fix UB and prevent vector load splitting in standalone_topk - Fix C++ UB by replacing union type-punning with `__builtin_memcpy`. - Prevent NVCC from splitting 128-bit vector loads by using unconditional pointer selection for safe companion reads. - Fix `len_cast_for_sync` underflow when `len_cast == 0`. - Guard against OOB reads when input length is smaller than `sizeof(WideT)`. --- .../common/util/standalone_topk.cuh | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/transformer_engine/common/util/standalone_topk.cuh b/transformer_engine/common/util/standalone_topk.cuh index 3d19cbfcf2..c66c84261e 100644 --- a/transformer_engine/common/util/standalone_topk.cuh +++ b/transformer_engine/common/util/standalone_topk.cuh @@ -181,11 +181,6 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const } else { static_assert(sizeof(WideT) % sizeof(T) == 0); constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); - // TODO: it's UB - union { - WideT scalar; - T array[items_per_scalar]; // NOLINT(runtime/arrays) - } wide; int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) @@ -198,13 +193,15 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const const idxT len_cast = (len - skip_cnt) / items_per_scalar; for (idxT i = thread_rank; i < len_cast; i += num_threads) { - wide.scalar = in_cast[i]; + const WideT wide_data = in_cast[i]; + T local_array[items_per_scalar]; // NOLINT(runtime/arrays) + __builtin_memcpy(local_array, &wide_data, sizeof(WideT)); const idxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { - f(wide.array[j], real_i + j); + f(local_array[j], real_i + j); } - } + } static_assert(WARP_SIZE >= items_per_scalar); // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt @@ -236,10 +233,6 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width } else { static_assert(sizeof(WideT) % sizeof(T) == 0); constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); - union { - WideT scalar; - T array[items_per_scalar]; // NOLINT(runtime/arrays) - } wide; int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) @@ -251,18 +244,26 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width const WideT *in_cast = reinterpret_cast(in + skip_cnt); const idxT len_cast = (len - skip_cnt) / items_per_scalar; - const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; - for (idxT i = tid; i < len_cast_for_sync; i += stride) { - bool valid = i < len_cast; - if (valid) { - wide.scalar = in_cast[i]; - } - const idxT real_i = skip_cnt + i * items_per_scalar; + // Skip when no full vector chunk exists: avoids len_cast_for_sync underflow and + // OOB companion reads (in_cast[0] needs at least one valid WideT). + if (len_cast > 0) { + const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (idxT i = tid; i < len_cast_for_sync; i += stride) { + const bool valid = i < len_cast; + // Unconditional 128-bit vector load: invalid threads read in_cast[0] (cached, + // discarded via valid=false) so NVCC emits LDG.E.128 instead of predicated load. + // Safe because len_cast > 0 guarantees at least one valid WideT at in_cast[0]. + const WideT *load_ptr = valid ? &in_cast[i] : &in_cast[0]; + const WideT wide_data = *load_ptr; + T local_array[items_per_scalar]; // NOLINT(runtime/arrays) + __builtin_memcpy(local_array, &wide_data, sizeof(WideT)); + const idxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { - f(wide.array[j], real_i + j, valid); + f(local_array[j], real_i + j, valid); } } + } static_assert(WARP_SIZE >= items_per_scalar); // need at most one warp for skipped and remained elements, From a2171b4a02ac4eaa8d68ab7a4b18bdc09c711ffd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jun 2026 06:29:03 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/standalone_topk.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/util/standalone_topk.cuh b/transformer_engine/common/util/standalone_topk.cuh index c66c84261e..c8d6e19817 100644 --- a/transformer_engine/common/util/standalone_topk.cuh +++ b/transformer_engine/common/util/standalone_topk.cuh @@ -201,7 +201,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const for (int j = 0; j < items_per_scalar; ++j) { f(local_array[j], real_i + j); } - } + } static_assert(WARP_SIZE >= items_per_scalar); // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt @@ -259,11 +259,11 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width __builtin_memcpy(local_array, &wide_data, sizeof(WideT)); const idxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll - for (int j = 0; j < items_per_scalar; ++j) { - f(local_array[j], real_i + j, valid); + for (int j = 0; j < items_per_scalar; ++j) { + f(local_array[j], real_i + j, valid); + } } } - } static_assert(WARP_SIZE >= items_per_scalar); // need at most one warp for skipped and remained elements, From 4dcb9452f60d580b0094d934c0de716535824f46 Mon Sep 17 00:00:00 2001 From: solos Date: Fri, 5 Jun 2026 14:43:00 +0800 Subject: [PATCH 3/3] [Common] Use index clamping for safe-passenger load in standalone_topk Clamp the load index to 0 for invalid padding threads instead of selecting between two pointers. This eliminates language-level UB from OOB pointer formation without changing the emitted vectorized global load. --- transformer_engine/common/util/standalone_topk.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/util/standalone_topk.cuh b/transformer_engine/common/util/standalone_topk.cuh index c8d6e19817..e5274b2060 100644 --- a/transformer_engine/common/util/standalone_topk.cuh +++ b/transformer_engine/common/util/standalone_topk.cuh @@ -252,9 +252,9 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width const bool valid = i < len_cast; // Unconditional 128-bit vector load: invalid threads read in_cast[0] (cached, // discarded via valid=false) so NVCC emits LDG.E.128 instead of predicated load. - // Safe because len_cast > 0 guarantees at least one valid WideT at in_cast[0]. - const WideT *load_ptr = valid ? &in_cast[i] : &in_cast[0]; - const WideT wide_data = *load_ptr; + // Index clamping (not pointer ternary) avoids C++ UB from &in_cast[i] when i >= len_cast. + const idxT safe_i = valid ? i : static_cast(0); + const WideT wide_data = in_cast[safe_i]; T local_array[items_per_scalar]; // NOLINT(runtime/arrays) __builtin_memcpy(local_array, &wide_data, sizeof(WideT)); const idxT real_i = skip_cnt + i * items_per_scalar;