
feat: DSpark trainer (DFlash + Markov/confidence heads + L1 distillation) #613

Open
maocheng23 wants to merge 2 commits into sgl-project:main from maocheng23:dspark-trainer


@maocheng23 maocheng23 commented Jun 28, 2026

Collaborator

Summary

Adds DSpark draft-model training to SpecForge — a port of TorchSpec PR #129 ("DeepSpec - DSpark trainer support"). DSpark extends DFlash's block-diffusion drafter with EAGLE-style Markov (low-rank learned bigram bias) and confidence (per-position accept-rate) heads, trained with a combined objective:

loss = ce_alpha·CE + l1_alpha·L1_distill + confidence_alpha·BCE
  • CE: hard cross-entropy on the next tokens, computed on the Markov-biased logits.
  • L1 distillation: |softmax(draft) − softmax(target)|, where the target distribution is the frozen target LM head applied to the target's final hidden state at the aligned position.
  • Confidence BCE: per-position accept-rate (1 − 0.5·L1) prediction, used for adaptive block length at inference.
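
For reviewers who want to sanity-check the objective by hand, here is a minimal single-position sketch in plain Python. It is not the code in specforge/core/dspark.py. It assumes the L1 term is summed over the vocabulary (which is what makes 1 − 0.5·L1 land in [0, 1]), that the confidence head emits a logit passed through a sigmoid, and that all three alphas default to 1.0; the function names are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dspark_position_loss(draft_logits, target_logits, label, conf_logit,
                         ce_alpha=1.0, l1_alpha=1.0, confidence_alpha=1.0):
    """Combined loss for one position (alphas are placeholder defaults)."""
    p = softmax(draft_logits)   # draft distribution (already Markov-biased)
    q = softmax(target_logits)  # frozen target distribution

    ce = -math.log(p[label])                       # hard cross-entropy
    l1 = sum(abs(a - b) for a, b in zip(p, q))     # assumed: summed over vocab
    accept_rate = 1.0 - 0.5 * l1                   # BCE target, in [0, 1]

    conf = 1.0 / (1.0 + math.exp(-conf_logit))     # assumed sigmoid head
    bce = -(accept_rate * math.log(conf)
            + (1.0 - accept_rate) * math.log(1.0 - conf))

    loss = ce_alpha * ce + l1_alpha * l1 + confidence_alpha * bce
    return {"loss": loss, "ce": ce, "l1": l1, "bce": bce,
            "accept_rate": accept_rate}
```

With identical draft and target logits the L1 term is zero and the accept-rate target is 1. In general 1 − 0.5·L1 equals Σ min(p, q), the acceptance probability of standard speculative sampling, which is why it is a natural target for the confidence head.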

Markov/confidence formulation adapted from DeepSeek's DeepSpec (MIT).
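
One way to read "low-rank learned bigram bias" is a factorized V×V table: a per-previous-token embedding of rank r times an r×V output projection, added to the drafter's logits. The sketch below illustrates that reading only. The factorization, the init scale, and every name in it are assumptions, not the VanillaMarkov implementation.

```python
import random


def make_low_rank_bigram(vocab_size, rank, seed=0):
    """Two small factors standing in for a full vocab_size x vocab_size table."""
    rng = random.Random(seed)
    prev_embed = [[rng.gauss(0.0, 0.02) for _ in range(rank)]
                  for _ in range(vocab_size)]      # shape (V, r)
    out_proj = [[rng.gauss(0.0, 0.02) for _ in range(vocab_size)]
                for _ in range(rank)]              # shape (r, V)
    return prev_embed, out_proj


def markov_bias(prev_token, prev_embed, out_proj):
    """Bias over the next token given only the previous token id."""
    row = prev_embed[prev_token]
    vocab_size = len(out_proj[0])
    return [sum(row[k] * out_proj[k][v] for k in range(len(row)))
            for v in range(vocab_size)]


def biased_logits(base_logits, prev_token, prev_embed, out_proj):
    """Add the bigram bias to the drafter's logits for one position."""
    bias = markov_bias(prev_token, prev_embed, out_proj)
    return [b + d for b, d in zip(base_logits, bias)]
```

The point of the factorization is parameter count: 2·V·r values instead of V², which matters at vocabulary sizes in the hundreds of thousands.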

What's added

New files

  • specforge/modeling/draft/dspark.py: DSparkConfig (extends Qwen3Config), VanillaMarkov, AcceptRatePredictor, build_markov_head, DSparkDraftModel (subclass of DFlashDraftModel).
  • specforge/core/dspark.py: OnlineDSparkModel (subclass of OnlineDFlashModel): reuses DFlash anchor sampling / MASK-noise / block-mask machinery; adds Markov-biased logits, CE + L1 + confidence, within-block decay weighting, and a pooled global-mean loss.
  • scripts/train_dspark.py — training driver (mirrors train_dflash.py); plumbs last_hidden_states, logs per-component losses, adds --max-steps.
  • configs/qwen3-8b-dspark.json, examples/run_qwen3_8b_dspark_online.sh.
  • tests/test_utils/test_dspark.py — 11 CPU unit tests (head math, config, internal loss identity, all-slots label convention, grad flow, CE-only path).

Changed

  • specforge/modeling/target/dflash_target_model.py — surface the target's final hidden state as DFlashTargetOutput.last_hidden_states (HF + sglang backends). DFlash ignores it; DSpark's L1/confidence losses need it. CE-only training needs nothing new.
  • specforge/modeling/draft/__init__.py, specforge/core/__init__.py — exports.

Design notes

  • This is a port, not a cherry-pick. SpecForge structures DFlash differently from TorchSpec (no DFlashConfig/DFlashModel/DFlashTrainer class family; a monolithic OnlineDFlashModel + a flat training script; a different draft-model forward signature; DFlash drops anchor slot 0 while DSpark supervises all block slots with a cumprod eval mask). DSpark is re-expressed in SpecForge idioms — DSparkConfig extends Qwen3Config (not a DFlashConfig), the loss wrapper subclasses OnlineDFlashModel, and training is a script (no trainer class).
  • Pooled global loss — local numerators over an all-reduced denominator, scaled by world_size to cancel FSDP's mean-grad reduction (correct for ZeRO-2 / SHARD_GRAD_OP, a single shard group). Verified by matching loss scale across 1-node vs 2-node runs.
  • Also fixes a latent device/device_type NameError in train_dflash.py's main() (carried into the new script).
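
The arithmetic behind the pooled global loss note can be checked without any GPUs. The sketch below simulates the ranks as list entries: the sum over the list stands in for the all-reduce of the denominator, and the final mean stands in for FSDP's mean-grad reduction. The numbers are made up for illustration and the helper names are hypothetical.

```python
def pooled_loss_per_rank(local_nums, local_dens, eps=1e-6):
    """Per-rank loss: local numerator over the global denominator, times world_size."""
    world_size = len(local_nums)
    global_den = sum(local_dens) + eps        # stands in for all_reduce(SUM)
    return [num / global_den * world_size for num in local_nums]


def mean_reduce(values):
    """Stands in for the mean reduction applied across ranks."""
    return sum(values) / len(values)


# Illustrative per-rank weighted loss sums and weighted token counts.
local_nums = [12.0, 3.0, 0.5, 8.5]
local_dens = [40.0, 10.0, 2.0, 28.0]

effective = mean_reduce(pooled_loss_per_rank(local_nums, local_dens))
token_mean = sum(local_nums) / sum(local_dens)
naive = mean_reduce([n / d for n, d in zip(local_nums, local_dens)])
```

Here `effective` matches the true token-level mean, while `naive` (averaging each rank's own mean) over-weights ranks that hold few supervised tokens. Because the loss is linear in each local numerator, the same cancellation applies to the gradients.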

Testing

  • CPU unit tests: 11/11 pass (python -m unittest tests.test_utils.test_dspark).
  • GPU (2× H200 nodes, Qwen3-8B target, flex_attention):
    • 1-node 8-GPU and 2-node 16-GPU smoke: clean, all three loss components healthy, checkpoint saved. Loss scale matches across 1-node vs 2-node (validates the world_size correction).
    • 300-step real ShareGPT run (2-node): loss 3.09→2.49, draft acc 1.5%→16.2%, ce 11.46→6.20, l1 1.97→1.66.
    • Full 2-epoch run on the complete ShareGPT set (120,675 convs, 7,276 steps, 2-node): combined loss 3.09→1.31, draft greedy accuracy 1.5%→~52%, ce_loss→~2.0, l1_loss 1.97→0.67. confidence_loss settles/decreases (0.51→0.45) once the accept-rate target stabilizes — the earlier short-run rise was just the non-stationary target catching up. All three components healthy throughout; no NaN/OOM/instability. Checkpoints every ~1000 steps + per-epoch.

Notes / follow-ups

  • The HF target backend is fully validated for the L1/confidence path. The sglang last_hidden_states path is implemented but not yet GPU-verified — use HF, or run CE-only (--l1-loss-alpha 0 --no-confidence-head) on sglang until verified.
  • Heads use default init (backbone keeps post_init HF init) — faithful to the reference.
  • Optional next step: wire a DSparkTrainStrategy into the runtime path for parity with EAGLE3.

🤖 Generated with Claude Code

…ion)

Port of TorchSpec PR sgl-project#129 to SpecForge. Adds:
- specforge/modeling/draft/dspark.py: DSparkConfig, VanillaMarkov,
  AcceptRatePredictor, DSparkDraftModel (subclass of DFlashDraftModel)
- specforge/core/dspark.py: OnlineDSparkModel (subclass of OnlineDFlashModel)
  with Markov-biased logits + CE + L1 distribution distillation + confidence BCE
  and a pooled global-mean loss
- scripts/train_dspark.py: training driver (clone of train_dflash.py)
- configs/qwen3-8b-dspark.json, examples/run_qwen3_8b_dspark_online.sh
- last_hidden_states surfaced from the DFlash target backends (HF + sglang)
- tests/test_utils/test_dspark.py: 11 CPU unit tests

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces the DSpark draft model and its online training pipeline, which builds upon the DFlash backbone by adding EAGLE-style Markov and confidence heads along with L1 distribution distillation. The changes include new configurations, training scripts, model definitions, target model updates to surface final hidden states, and unit tests. The review feedback highlights two key issues: a bug in the dataset filtering step in train_dspark.py where calling .sum() on a Python list will cause a crash, and an incorrect loss scaling calculation in dspark.py when Tensor Parallelism is enabled, which should use the data-parallel group instead of the global world size.


Comment thread scripts/train_dspark.py
Comment on lines +327 to +329
train_eagle3_dataset = train_eagle3_dataset.filter(
    lambda x: x["loss_mask"].sum() >= min_loss_tokens
)

Contributor


high

The Hugging Face Dataset.filter method passes a dictionary of Python types (where sequence features like loss_mask are standard Python list objects) to the filter function when batched=False (the default). Calling .sum() on a Python list will raise an AttributeError: 'list' object has no attribute 'sum' and crash the training script. Use the built-in sum() function instead.

Suggested change
- train_eagle3_dataset = train_eagle3_dataset.filter(
-     lambda x: x["loss_mask"].sum() >= min_loss_tokens
- )
+ train_eagle3_dataset = train_eagle3_dataset.filter(
+     lambda x: sum(x["loss_mask"]) >= min_loss_tokens
+ )
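
The failure mode described in this comment can be reproduced without the datasets library. In the sketch below, plain dicts stand in for the rows a filter callback would receive when no tensor format is set; `keep_example` and the sample rows are hypothetical.

```python
def keep_example(example, min_loss_tokens):
    """Filter predicate using the built-in sum(), which accepts a plain list."""
    return sum(example["loss_mask"]) >= min_loss_tokens


# Stand-ins for dataset rows whose loss_mask is a Python list.
rows = [
    {"loss_mask": [0, 1, 1, 0]},
    {"loss_mask": [0, 0, 0, 0]},
    {"loss_mask": [1, 1, 1, 1]},
]
kept = [row for row in rows if keep_example(row, min_loss_tokens=2)]

# Calling the method form on a list raises AttributeError.
try:
    rows[0]["loss_mask"].sum()
    list_has_sum = True
except AttributeError:
    list_has_sum = False
```

If the dataset has a tensor format set, the rows would arrive as tensors and `.sum()` would work, so whether the original line crashes depends on how the dataset is formatted upstream. The built-in `sum()` should be safe in both cases for a 1-D mask.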

Comment thread specforge/core/dspark.py
Comment on lines +292 to +301
world_size = dist.get_world_size() if dist.is_initialized() else 1
global_den = local_den.detach().clone()
if world_size > 1:
    dist.all_reduce(global_den, op=dist.ReduceOp.SUM)
global_den = global_den + 1e-6
loss = (
    self.ce_loss_alpha * ce_num / global_den
    + self.l1_loss_alpha * l1_num / global_den
    + self.confidence_head_alpha * conf_num / global_den
) * world_size

Contributor


medium

When Tensor Parallelism (TP) is enabled (e.g., via --tp-size), dist.get_world_size() returns the total number of GPUs across both DP and TP groups. However, the loss denominator should only be summed across the data parallel (DP) group, and the loss scaling factor should be dp_size rather than world_size to ensure correct gradient scaling under FSDP. Using the global world_size will scale gradients incorrectly by a factor of tp_size.

Consider using get_dp_group() from specforge.distributed to perform the all_reduce and scale the loss correctly.

Suggested change
- world_size = dist.get_world_size() if dist.is_initialized() else 1
- global_den = local_den.detach().clone()
- if world_size > 1:
-     dist.all_reduce(global_den, op=dist.ReduceOp.SUM)
- global_den = global_den + 1e-6
- loss = (
-     self.ce_loss_alpha * ce_num / global_den
-     + self.l1_loss_alpha * l1_num / global_den
-     + self.confidence_head_alpha * conf_num / global_den
- ) * world_size
+ from specforge.distributed import get_dp_group
+ dp_group = get_dp_group() if dist.is_initialized() else None
+ dp_size = dist.get_world_size(dp_group) if dp_group is not None else 1
+ global_den = local_den.detach().clone()
+ if dp_size > 1:
+     dist.all_reduce(global_den, op=dist.ReduceOp.SUM, group=dp_group)
+ global_den = global_den + 1e-6
+ loss = (
+     self.ce_loss_alpha * ce_num / global_den
+     + self.l1_loss_alpha * l1_num / global_den
+     + self.confidence_head_alpha * conf_num / global_den
+ ) * dp_size


This sounds right.

curnane-lab pushed a commit to curnane-lab/SpecForge_npu that referenced this pull request Jun 29, 2026

@Dogacel Dogacel left a comment


I think overall LGTM, thank you!

Comment thread scripts/train_dspark.py
choices=["sglang", "hf"],
help="Backend for target model: 'sglang' (service) or 'hf' (local). "
"DSpark's L1/confidence losses need the target's final hidden state; "
"the 'hf' backend always surfaces it.",


What does this mean? Does SGLang not return the final hidden state? It should.

Comment thread scripts/train_dspark.py
Comment on lines +391 to +401
# Copy the modeling files next to the checkpoint so auto_map can
# resolve DSparkDraftModel (which subclasses DFlashDraftModel) on
# reload with trust_remote_code.
modeling_dir = os.path.join(
    os.path.dirname(__file__), "..", "specforge", "modeling", "draft"
)
for fname in ("dspark.py", "dflash.py"):
    src = os.path.join(modeling_dir, fname)
    if os.path.exists(src):
        shutil.copy(src, os.path.join(save_dir, fname))


Will this work? Aren't there other local imports that those files depend on?
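
One quick way to answer this question is to list the package-local imports in the copied files before relying on them. A minimal sketch using the standard-library ast module follows; `find_local_imports` is a hypothetical helper and the sample source is invented for illustration, not the actual contents of dspark.py.

```python
import ast


def find_local_imports(source, package="specforge"):
    """Return relative imports and imports from `package` found in `source`."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom):
            if node.level > 0:
                # Relative import such as `from .dflash import X`.
                found.append("." * node.level + (node.module or ""))
            elif node.module and (node.module == package
                                  or node.module.startswith(package + ".")):
                found.append(node.module)
        elif isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name == package or alias.name.startswith(package + "."):
                    found.append(alias.name)
    return sorted(set(found))


# Invented sample; parsing does not import anything, so torch need not be installed.
sample = '''
import torch
from transformers import Qwen3Config
from .dflash import DFlashDraftModel
from specforge.utils import something
'''
local_imports = find_local_imports(sample)
```

Any entry that is not one of the copied files (here, anything other than `.dflash`) would fail to resolve when the checkpoint is reloaded in an environment without the package installed.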

Comment thread scripts/train_dspark.py
if mode == "train" and optimizer is not None:
    logdict["train/lr"] = optimizer.get_learning_rate()

logdict[f"{mode}/loss"] = loss


Let's add "ce_loss", "l1_loss", and "confidence_loss"? I think they give us some insight into how training is going overall.
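
A sketch of what that could look like, written as a standalone function so it can be tried in isolation. `build_logdict`, its `components` argument, and the values in the usage line are hypothetical; only the `{mode}/loss` and `train/lr` keys come from the snippet under review.

```python
def build_logdict(mode, loss, components=None, lr=None):
    """Build the metrics dict, adding one key per loss component when provided."""
    logdict = {}
    if mode == "train" and lr is not None:
        logdict["train/lr"] = lr
    logdict[f"{mode}/loss"] = loss
    # components maps a name such as "ce_loss" to its scalar value.
    for name, value in (components or {}).items():
        logdict[f"{mode}/{name}"] = value
    return logdict


# Placeholder values, for illustration only.
example = build_logdict(
    "train",
    loss=1.0,
    components={"ce_loss": 0.5, "l1_loss": 0.3, "confidence_loss": 0.2},
    lr=1e-4,
)
```

Keeping the components under the same `{mode}/` prefix means train and eval curves for each term show up side by side in the logger.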

Comment thread specforge/core/dspark.py
)

return (
loss,


Report individual losses here.

@maocheng23 maocheng23 marked this pull request as ready for review June 29, 2026 16:28

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

2 participants