Skip to content

Add Experts4bit for 4-bit quantization of fused MoE experts#1965

Draft
pjordanandrsn wants to merge 4 commits into
bitsandbytes-foundation:mainfrom
pjordanandrsn:feature/experts-4bit
Draft

Add Experts4bit for 4-bit quantization of fused MoE experts#1965
pjordanandrsn wants to merge 4 commits into
bitsandbytes-foundation:mainfrom
pjordanandrsn:feature/experts-4bit

Conversation

@pjordanandrsn

Copy link
Copy Markdown

What

Adds bitsandbytes.nn.Experts4bit, a module that stores fused Mixture-of-Experts
weights in 4-bit (NF4/FP4) precision.

Fixes the memory issue in #1849: transformers v5 stores MoE experts as a single 3D
nn.Parameter (e.g. OlmoeExperts, Qwen3MoeExpertsgate_up_proj
[num_experts, 2*intermediate, hidden], down_proj [num_experts, hidden, intermediate]).
The nn.Linear-based 4-bit walker only swaps nn.Linear, so these fused experts are
skipped, stay in full precision, and dominate the loaded footprint.

Design

This follows the approach @matthewdouglas outlined on the issue:

  • Plain nn.Parameter for the packed weights (not Params4bit), with per-expert
    absmax kept on the module as buffers
    . This avoids bending Params4bit's
    tensor-subclass + device-movement machinery around a 3D stack, and the module
    serializes through the default state_dict — no custom save/load hooks.
  • Per-expert dequant loop in forward (mirrors the reference fused-experts forward in
    OlmoeExperts / FP8Experts): one expert's weight is dequantized, used, and freed at a
    time. This keeps the runtime working set small and leaves a clean path to a grouped-GEMM
    kernel later.
  • Enforces in_features % blocksize == 0 so per-expert quantization blocks tile each
    expert exactly and never straddle an expert boundary.

Relationship to replace_parameter_4bit (#1720): that generic parametrization also
quantizes arbitrary nn.Parameters, but dequantizes the entire [num_experts, …] stack
on every access. Experts4bit is MoE-aware — it only touches the experts a batch actually
routes to — which is what enables the grouped-GEMM follow-up.

Intentionally deferred for this first cut (per the issue discussion): double-quant
(compress_statistics), a grouped-GEMM forward, and the transformers-side walker wiring.

API

from bitsandbytes.nn import Experts4bit

# Quantize an existing fp16/bf16 fused-expert stack:
experts = Experts4bit.from_float(gate_up_proj, down_proj, quant_type="nf4")
out = experts(hidden_states, top_k_index, top_k_weights)

# Or construct empty + load_state_dict (e.g. pre-quantized checkpoints):
experts = Experts4bit(num_experts, hidden_dim, intermediate_dim)
experts.load_state_dict(sd)

Footprint & validation (measured on an RTX A2000 12 GB, sm_86)

For one real OLMoE-1B-7B layer (num_experts=64, hidden=2048, intermediate=1024, NF4,
blocksize 64, no double-quant), measured Experts4bit vs. the bf16 stack:

per layer full model (×16 layers)
experts, bf16 (today) 768.0 MB 12.00 GB
experts, Experts4bit (192 MB packed + 24 MB absmax) 216.0 MB 3.38 GB

3.56× smaller for the expert weights, which are the bulk of the model — combined with
the existing Linear4bit path on the non-expert layers this takes OLMoE-1B-7B from ~13 GB
to ~3.5 GB (fits a single 12 GB card). A forward over the real-sized layer peaks at
1295 MB of VRAM: because experts are dequantized one at a time, the working set never
materializes the full bf16 stack — the property that makes the grouped-GEMM follow-up
worthwhile.

Testing

tests/test_experts4bit.py — 11 cases, all green on the CPU default backend:

  • quant round-trip per expert (NF4/FP4 × fp16/bf16/fp32) within 4-bit tolerance, with
    packed-weight / absmax shape + dtype assertions
  • forward vs. a full-precision reference forward (gated + non-gated), float32 compute,
    rtol=atol=1e-4
  • state_dict round-trip: bit-exact restore of packed weights + absmax, identical forward
    after reload
  • validation guards (in_features % blocksize, invalid quant_type)

On CUDA (A2000, bnb 0.49.2 / torch 2.4.1) the NF4 round-trip mean-abs error is 0.0073 and
the forward matches the full-precision reference exactly (max-abs 0.0).

Closes #1849.

cc @matthewdouglas @SunMarc

@matthewdouglas

Copy link
Copy Markdown
Member

Hi, thanks for the PR. I am a little concerned with how quickly it was opened after discussion. With that said I'll follow up soon, but likely we won't merge something for this until after v0.50.0 release.

@pjordanandrsn

Copy link
Copy Markdown
Author

Thanks @matthewdouglas — fair concern. The asking-first part was real: nothing was written until the shape was pinned down, and the PR follows it — plain nn.Parameter + per-expert absmax buffers on the module, compress_statistics deferred, in_features % blocksize == 0 enforced, per-expert dequant loop for the first cut. The footprint/VRAM numbers in the description are measured on my own A2000 12 GB; fitting fused-experts MoE models on a 12 GB card is the itch that started this.

No rush from me; post-v0.50.0 was always the plan. Converting to draft so it reads as what it is — something concrete to react to when you pick the feature up. Happy to rework it toward whatever you land on, or for you to cherry-pick the useful parts.

@pjordanandrsn pjordanandrsn marked this pull request as draft June 10, 2026 17:00
@David-AU-github

Copy link
Copy Markdown

So looking forward to this; this will allow Unsloth training of Qwen 3.5/.6 and Gemma 4 sparse moes on reasonable hardware vs H100/200.

@pjordanandrsn

Copy link
Copy Markdown
Author

Updated this branch with backward/gradient + QLoRA-training test coverage and a small OLMoE-1B-7B demo validating quant-correctness and the ~4.7 GB footprint on a 12 GB card (per your June note). Still the correctness-first per-expert loop — no grouped-GEMM.

No rush given v0.50.0 isn't out yet; happy to leave this in draft until you pick it up.

Separately I have a memory-lowering matmul_4bit forward (gated so it only engages on bitsandbytes ≥ 0.50, else falls back to the bit-exact dequantize path) — glad to add it here, keep it as a follow-up, or leave it for your MoE-kernel work, whichever you prefer.

pjordanandrsn and others added 4 commits July 2, 2026 10:42
…ytes-foundation#1849)

transformers v5 stores fused MoE experts as a single 3D nn.Parameter
(e.g. OlmoeExperts, Qwen3MoeExperts), which the nn.Linear-based 4-bit
walker skips. The experts stay in full precision and load_in_4bit barely
shrinks the model (issue bitsandbytes-foundation#1849).

Experts4bit holds gate_up_proj and down_proj packed in NF4/FP4 as plain
nn.Parameter buffers, with per-expert absmax kept on the module itself.
The forward pass dequantizes one expert at a time (a per-expert loop),
mirroring the reference fused-experts forward. There is no Params4bit
tensor-subclass machinery, so the module serializes through the default
state_dict with no custom hooks.

- from_float() quantizes existing bf16/fp16 expert stacks
- enforces in_features % blocksize == 0 for clean per-expert blocking
- double-quant (compress_statistics) and grouped-GEMM intentionally
  deferred for a first cut
- tests: quant round-trip, forward vs. full-precision reference,
  state_dict round-trip, and validation guards
Prove Experts4bit works as a frozen 4-bit QLoRA base. New tests in
tests/test_experts4bit.py cover the autograd contract: gradients reach the
input activations, the frozen packed weights never receive a gradient, and the
backward matches a full-precision reference forward.

Add examples/experts4bit_qlora_demo.py with a small per-expert ExpertsLoRA
wrapper over a frozen base, plus a test that a real optimizer step reduces loss
while the 4-bit base stays bit-identical. The wrapper is a reference pattern
(PEFT/Unsloth territory), intentionally not part of the bitsandbytes public API.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rametrized forward coverage

- test_experts4bit_1849_regression_...: build a fused-3D nn.Parameter expert module (the transformers
  v5 layout the Linear4bit walker skips), assert Experts4bit.from_float actually 4-bit-quantizes it
  (uint8-packed, >3x smaller than fp16). Guards the exact silent-skip bitsandbytes-foundation#1849 reports.
- test_experts4bit_shapes: forward correctness across a spread of (num_experts, hidden, intermediate).

Co-Authored-By: Jordan Anderson <paul.jordan.anderson@gmail.com>
@pjordanandrsn pjordanandrsn force-pushed the feature/experts-4bit branch from 7a2b9fd to 2748d76 Compare July 2, 2026 16:14
@pjordanandrsn

Copy link
Copy Markdown
Author

Rebased on current main; added a regression test for the exact #1849 silent-skip (fused 3-D module with no nn.Linear → asserts from_float actually 4-bit-quantizes it) plus shape-parametrized forward coverage — 38 tests, CPU+CUDA. Still draft, no rush.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Failed to quant MoE models with fused expert weights in transformers v5

3 participants