Introduce Mega-C++ to reduce CPU overhead#3099
Conversation
for more information, see https://pre-commit.ci
|
/te-ci pytorch L1 |
| m.def("te_general_grouped_gemm_for_discrete_out", | ||
| &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, | ||
| "Grouped GEMM for discrete output list"); | ||
| m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward, |
There was a problem hiding this comment.
We should expose these functions within the tex.grouped_mlp_experimental submodule:
TransformerEngine/transformer_engine/pytorch/csrc/extensions/pybind.cpp
Lines 647 to 650 in 3fffa55
There was a problem hiding this comment.
It would make more sense to organize:
csrc/
├── extensions/
│ ├── grouped_mlp_experimental/
│ │ ├── megacpp.cpp
│ │ └── grouped_mlp_experimental.cpp
│ ├── pybind.cpp
│ └── ...
If we implement more mega-C++ impls in the future, I don't see a reason why they would be more similar to each other than to the block they are fusing.
| name: str | ||
| is_scaled: bool | ||
| is_gated: bool | ||
| glu_interleave_size: int |
There was a problem hiding this comment.
Is it worth supporting GLU interleaving in the mega-C++ path? The only benefit is to support the fused GEMM+GLU kernel, and otherwise the unnecessary memory-bound kernel means perf is a lost cause. If we can simplify our optimized code paths, then it's worth it.
There was a problem hiding this comment.
The only benefit is to support the fused GEMM+GLU kernel
I do hope in the future we can launch CuteDSL fused kernels in C++ with some TVM-FFI tricks, otherwise we are forced to choose either better kernel fusions or less CPU overhead. Currently the CuteDSL fusion path is very CPU bounded for small models and we rely on CUDA graph and paged stashing for it to work well
| # Explicit env opt-in gives megacpp first chance. Unsupported recipes intentionally | ||
| # return the ops unchanged so lower-priority recipe-specific fusers remain the | ||
| # fallback path. | ||
| register_forward_fusion(fuse_forward_megacpp_ops, prepend=True) |
There was a problem hiding this comment.
The GEMM+act fusions provide better GPU perf, so I think they should take higher priority than mega-C++. Basically, I see mega-C++ as "we can't do any better on GPU than the unfused impl, but at least we can make the CPU overhead very small".
There was a problem hiding this comment.
Current order is follows:
- check env var
- env var = 1, then check supported recipe for mega-C++, so bf16 is supported, not mxfp8 / nvfp4
- then for mxfp8, nvfp4, mega-C++ does fallback and check for the next fusion.
The reasoning is that, I do not want the compromise of either better fusion or less host bound, so for future mxfp8 support, we can do the following two things:
- directly do cuteDSL integration directly with tvm-ffi and do cublas as a backup plan
- maybe add a new value to NVTE_MEGACPP_GROUPED_LINEAR=forced, so for users who cannot enable cuda graph for some reason, they can enforce C++ when they know that their training is more host bound
Description
Assistant: GPT5.5 codex
Get rid of CPU overhead whenever CUDA Graph is not applicable. Guarded by NVTE_MEGACPP_GROUPED_LINEAR.
Drop-in replace grouped MLP, ie. FC1 - act - FC2. Target BF16 grouped gemm with cublas grouped gemm backend.
In the future, we can extend to mxfp8 / nvfp4 with cublas backend or even cuteDSL grouped gemm and call
cute.jitin C++: NVIDIA/cutlass#3289Recommend CUDA >= 13.2.1
TODO:
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: