Add Experts4bit for 4-bit quantization of fused MoE experts#1965
Add Experts4bit for 4-bit quantization of fused MoE experts#1965pjordanandrsn wants to merge 4 commits into
Conversation
|
Hi, thanks for the PR. I am a little concerned with how quickly it was opened after discussion. With that said I'll follow up soon, but likely we won't merge something for this until after v0.50.0 release. |
|
Thanks @matthewdouglas — fair concern. The asking-first part was real: nothing was written until the shape was pinned down, and the PR follows it — plain No rush from me; post-v0.50.0 was always the plan. Converting to draft so it reads as what it is — something concrete to react to when you pick the feature up. Happy to rework it toward whatever you land on, or for you to cherry-pick the useful parts. |
|
So looking forward to this; this will allow Unsloth training of Qwen 3.5/.6 and Gemma 4 sparse moes on reasonable hardware vs H100/200. |
|
Updated this branch with backward/gradient + QLoRA-training test coverage and a small OLMoE-1B-7B demo validating quant-correctness and the ~4.7 GB footprint on a 12 GB card (per your June note). Still the correctness-first per-expert loop — no grouped-GEMM. No rush given v0.50.0 isn't out yet; happy to leave this in draft until you pick it up. Separately I have a memory-lowering |
…ytes-foundation#1849) transformers v5 stores fused MoE experts as a single 3D nn.Parameter (e.g. OlmoeExperts, Qwen3MoeExperts), which the nn.Linear-based 4-bit walker skips. The experts stay in full precision and load_in_4bit barely shrinks the model (issue bitsandbytes-foundation#1849). Experts4bit holds gate_up_proj and down_proj packed in NF4/FP4 as plain nn.Parameter buffers, with per-expert absmax kept on the module itself. The forward pass dequantizes one expert at a time (a per-expert loop), mirroring the reference fused-experts forward. There is no Params4bit tensor-subclass machinery, so the module serializes through the default state_dict with no custom hooks. - from_float() quantizes existing bf16/fp16 expert stacks - enforces in_features % blocksize == 0 for clean per-expert blocking - double-quant (compress_statistics) and grouped-GEMM intentionally deferred for a first cut - tests: quant round-trip, forward vs. full-precision reference, state_dict round-trip, and validation guards
Prove Experts4bit works as a frozen 4-bit QLoRA base. New tests in tests/test_experts4bit.py cover the autograd contract: gradients reach the input activations, the frozen packed weights never receive a gradient, and the backward matches a full-precision reference forward. Add examples/experts4bit_qlora_demo.py with a small per-expert ExpertsLoRA wrapper over a frozen base, plus a test that a real optimizer step reduces loss while the 4-bit base stays bit-identical. The wrapper is a reference pattern (PEFT/Unsloth territory), intentionally not part of the bitsandbytes public API. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rametrized forward coverage - test_experts4bit_1849_regression_...: build a fused-3D nn.Parameter expert module (the transformers v5 layout the Linear4bit walker skips), assert Experts4bit.from_float actually 4-bit-quantizes it (uint8-packed, >3x smaller than fp16). Guards the exact silent-skip bitsandbytes-foundation#1849 reports. - test_experts4bit_shapes: forward correctness across a spread of (num_experts, hidden, intermediate). Co-Authored-By: Jordan Anderson <paul.jordan.anderson@gmail.com>
7a2b9fd to
2748d76
Compare
|
Rebased on current main; added a regression test for the exact #1849 silent-skip (fused 3-D module with no |
What
Adds
bitsandbytes.nn.Experts4bit, a module that stores fused Mixture-of-Expertsweights in 4-bit (NF4/FP4) precision.
Fixes the memory issue in #1849: transformers v5 stores MoE experts as a single 3D
nn.Parameter(e.g.OlmoeExperts,Qwen3MoeExperts—gate_up_proj[num_experts, 2*intermediate, hidden],down_proj[num_experts, hidden, intermediate]).The
nn.Linear-based 4-bit walker only swapsnn.Linear, so these fused experts areskipped, stay in full precision, and dominate the loaded footprint.
Design
This follows the approach @matthewdouglas outlined on the issue:
nn.Parameterfor the packed weights (notParams4bit), with per-expertabsmaxkept on the module as buffers. This avoids bendingParams4bit'stensor-subclass + device-movement machinery around a 3D stack, and the module
serializes through the default
state_dict— no custom save/load hooks.forward(mirrors the reference fused-experts forward inOlmoeExperts/FP8Experts): one expert's weight is dequantized, used, and freed at atime. This keeps the runtime working set small and leaves a clean path to a grouped-GEMM
kernel later.
in_features % blocksize == 0so per-expert quantization blocks tile eachexpert exactly and never straddle an expert boundary.
Relationship to
replace_parameter_4bit(#1720): that generic parametrization alsoquantizes arbitrary
nn.Parameters, but dequantizes the entire[num_experts, …]stackon every access.
Experts4bitis MoE-aware — it only touches the experts a batch actuallyroutes to — which is what enables the grouped-GEMM follow-up.
Intentionally deferred for this first cut (per the issue discussion): double-quant
(
compress_statistics), a grouped-GEMM forward, and the transformers-side walker wiring.API
Footprint & validation (measured on an RTX A2000 12 GB, sm_86)
For one real OLMoE-1B-7B layer (
num_experts=64, hidden=2048, intermediate=1024, NF4,blocksize 64, no double-quant), measured
Experts4bitvs. the bf16 stack:Experts4bit(192 MB packed + 24 MB absmax)3.56× smaller for the expert weights, which are the bulk of the model — combined with
the existing
Linear4bitpath on the non-expert layers this takes OLMoE-1B-7B from ~13 GBto ~3.5 GB (fits a single 12 GB card). A forward over the real-sized layer peaks at
1295 MB of VRAM: because experts are dequantized one at a time, the working set never
materializes the full bf16 stack — the property that makes the grouped-GEMM follow-up
worthwhile.
Testing
tests/test_experts4bit.py— 11 cases, all green on the CPU default backend:packed-weight / absmax shape + dtype assertions
forwardvs. a full-precision reference forward (gated + non-gated), float32 compute,rtol=atol=1e-4state_dictround-trip: bit-exact restore of packed weights + absmax, identical forwardafter reload
in_features % blocksize, invalidquant_type)On CUDA (A2000, bnb 0.49.2 / torch 2.4.1) the NF4 round-trip mean-abs error is 0.0073 and
the forward matches the full-precision reference exactly (max-abs 0.0).
Closes #1849.
cc @matthewdouglas @SunMarc