Quantization support for GroupedTensor: FP8 per-tensor#3102
Conversation
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds FP8 per-tensor (delayed) scaling support for
Confidence Score: 3/5The kernel changes touch the hot path for every grouped FP8 quantization and are not safe to merge in their current form — both a silent stack overflow and a shared-memory correctness hazard need to be resolved first. The
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["nvte_group_quantize / nvte_group_dequantize"] --> B["group_quantize_fwd_helper / group_dequantize_helper"]
B --> C{scaling_mode?}
C -- NVTE_DELAYED_TENSOR_SCALING --> D["fp8::group_quantize / fp8::group_dequantize"]
C -- NVTE_MXFP8_1D_SCALING --> E["mxfp8::group_quantize / mxfp8::group_dequantize"]
D --> F["VectorizedUnaryKernelLauncher\n(offsets, first_dims, last_dims,\nnum_tensors, scale_numel, ...)"]
F --> G["unary_kernel (GPU)\nfor each element:\n find_tensor_id (binary search)\n apply per-tensor scale\n accumulate block_max[tensor_id]"]
G --> H["per-tensor amax reduction\nreduce_max loop over num_tensors\natomicMaxFloat(amax[t])"]
G --> I["scale_inv write\nat offsets[tensor_id] only"]
style H fill:#f96,stroke:#c00
style G fill:#f96,stroke:#c00
|
| } | ||
| const int warp_id = threadIdx.x / THREADS_PER_WARP; | ||
|
|
||
| float block_max[64] = {0.0f}; |
There was a problem hiding this comment.
Fixed-size
block_max array overflows when num_tensors > 64
block_max is indexed by tensor_id which is bounded by num_tensors (a runtime kernel parameter), yet the array is fixed at size 64. When a grouped tensor has more than 64 sub-tensors — common in large MoE models — any thread whose tensor_id >= 64 writes past the end of the array, corrupting other stack/local variables (including warp_id, loop variables, max, etc.) and producing silently wrong quantized outputs or GPU faults. The same defect exists in the unary_grad_kernel at the corresponding location.
| if (offsets != nullptr || num_tensors > 1) { | ||
| for (size_t t = 0; t < num_tensors; ++t) { | ||
| float t_max = block_max[t]; | ||
| t_max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(t_max, warp_id); | ||
| if (threadIdx.x == 0 && t_max > 0.0f) { | ||
| size_t amax_idx = (amax_numel == num_tensors) ? t : 0; | ||
| atomicMaxFloat(&amax[amax_idx], t_max); | ||
| } | ||
| } |
There was a problem hiding this comment.
Shared-memory race in per-tensor amax loop
reduce_max uses a __shared__ float staging[num_warps] array and a single __syncthreads() that ensures visibility before warp 0 reads staging. However, reduce_max does NOT call __syncthreads() after warp 0's read before returning. Calling it in a loop means that warp 1 (and other non-zero warps) can reach the staging[warpid] = my_warp_max write for iteration t+1 before warp 0 finishes reading staging[1] for iteration t. Without an explicit barrier between iterations, the CUDA memory model does not guarantee ordering, so warp 0 can read a partially updated staging[1] and compute an incorrect per-tensor amax. A __syncthreads() is needed after each call to reduce_max in this loop (and the identical loop in unary_grad_kernel).
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes #2449
Type of change
Changes
Please list the changes introduced in this PR:
Kernels: Extended unary_kernel and unary_grad_kernel in vectorized_pointwise.h to dynamically support per-tensor scale, scale_inv, and amax for grouped tensors.
Alignment: Aligned the random padding in test_common.cu to a constant 64 elements to guarantee matching element offsets between input and output grouped tensors.
Verification: Corrected the FP8 cast validation loop in test_cast_fp8_grouped.cu to compare raw quantized values directly, resolving false test failures caused by rounding errors.
Checklist: