feat: DSpark trainer (DFlash + Markov/confidence heads + L1 distillation)#613
feat: DSpark trainer (DFlash + Markov/confidence heads + L1 distillation)#613maocheng23 wants to merge 2 commits into
Conversation
…ion) Port of TorchSpec PR sgl-project#129 to SpecForge. Adds: - specforge/modeling/draft/dspark.py: DSparkConfig, VanillaMarkov, AcceptRatePredictor, DSparkDraftModel (subclass of DFlashDraftModel) - specforge/core/dspark.py: OnlineDSparkModel (subclass of OnlineDFlashModel) with Markov-biased logits + CE + L1 distribution distillation + confidence BCE and a pooled global-mean loss - scripts/train_dspark.py: training driver (clone of train_dflash.py) - configs/qwen3-8b-dspark.json, examples/run_qwen3_8b_dspark_online.sh - last_hidden_states surfaced from the DFlash target backends (HF + sglang) - tests/test_utils/test_dspark.py: 11 CPU unit tests Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces the DSpark draft model and its online training pipeline, which builds upon the DFlash backbone by adding EAGLE-style Markov and confidence heads along with L1 distribution distillation. The changes include new configurations, training scripts, model definitions, target model updates to surface final hidden states, and unit tests. The review feedback highlights two key issues: a bug in the dataset filtering step in train_dspark.py where calling .sum() on a Python list will cause a crash, and an incorrect loss scaling calculation in dspark.py when Tensor Parallelism is enabled, which should use the data-parallel group instead of the global world size.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| train_eagle3_dataset = train_eagle3_dataset.filter( | ||
| lambda x: x["loss_mask"].sum() >= min_loss_tokens | ||
| ) |
There was a problem hiding this comment.
The Hugging Face Dataset.filter method passes a dictionary of Python types (where sequence features like loss_mask are standard Python list objects) to the filter function when batched=False (the default). Calling .sum() on a Python list will raise an AttributeError: 'list' object has no attribute 'sum' and crash the training script. Use the built-in sum() function instead.
| train_eagle3_dataset = train_eagle3_dataset.filter( | |
| lambda x: x["loss_mask"].sum() >= min_loss_tokens | |
| ) | |
| train_eagle3_dataset = train_eagle3_dataset.filter( | |
| lambda x: sum(x["loss_mask"]) >= min_loss_tokens | |
| ) |
| world_size = dist.get_world_size() if dist.is_initialized() else 1 | ||
| global_den = local_den.detach().clone() | ||
| if world_size > 1: | ||
| dist.all_reduce(global_den, op=dist.ReduceOp.SUM) | ||
| global_den = global_den + 1e-6 | ||
| loss = ( | ||
| self.ce_loss_alpha * ce_num / global_den | ||
| + self.l1_loss_alpha * l1_num / global_den | ||
| + self.confidence_head_alpha * conf_num / global_den | ||
| ) * world_size |
There was a problem hiding this comment.
When Tensor Parallelism (TP) is enabled (e.g., via --tp-size), dist.get_world_size() returns the total number of GPUs across both DP and TP groups. However, the loss denominator should only be summed across the data parallel (DP) group, and the loss scaling factor should be dp_size rather than world_size to ensure correct gradient scaling under FSDP. Using the global world_size will scale gradients incorrectly by a factor of tp_size.
Consider using get_dp_group() from specforge.distributed to perform the all_reduce and scale the loss correctly.
| world_size = dist.get_world_size() if dist.is_initialized() else 1 | |
| global_den = local_den.detach().clone() | |
| if world_size > 1: | |
| dist.all_reduce(global_den, op=dist.ReduceOp.SUM) | |
| global_den = global_den + 1e-6 | |
| loss = ( | |
| self.ce_loss_alpha * ce_num / global_den | |
| + self.l1_loss_alpha * l1_num / global_den | |
| + self.confidence_head_alpha * conf_num / global_den | |
| ) * world_size | |
| from specforge.distributed import get_dp_group | |
| dp_group = get_dp_group() if dist.is_initialized() else None | |
| dp_size = dist.get_world_size(dp_group) if dp_group is not None else 1 | |
| global_den = local_den.detach().clone() | |
| if dp_size > 1: | |
| dist.all_reduce(global_den, op=dist.ReduceOp.SUM, group=dp_group) | |
| global_den = global_den + 1e-6 | |
| loss = ( | |
| self.ce_loss_alpha * ce_num / global_den | |
| + self.l1_loss_alpha * l1_num / global_den | |
| + self.confidence_head_alpha * conf_num / global_den | |
| ) * dp_size |
| choices=["sglang", "hf"], | ||
| help="Backend for target model: 'sglang' (service) or 'hf' (local). " | ||
| "DSpark's L1/confidence losses need the target's final hidden state; " | ||
| "the 'hf' backend always surfaces it.", |
There was a problem hiding this comment.
What does this mean? Does SGLang not return the final hidden state? It should have.
| # Copy the modeling files next to the checkpoint so auto_map can | ||
| # resolve DSparkDraftModel (which subclasses DFlashDraftModel) on | ||
| # reload with trust_remote_code. | ||
| modeling_dir = os.path.join( | ||
| os.path.dirname(__file__), "..", "specforge", "modeling", "draft" | ||
| ) | ||
| for fname in ("dspark.py", "dflash.py"): | ||
| src = os.path.join(modeling_dir, fname) | ||
| if os.path.exists(src): | ||
| shutil.copy(src, os.path.join(save_dir, fname)) | ||
|
|
There was a problem hiding this comment.
Will this work? Isn't there other local imports that those files depend on?
| if mode == "train" and optimizer is not None: | ||
| logdict["train/lr"] = optimizer.get_learning_rate() | ||
|
|
||
| logdict[f"{mode}/loss"] = loss |
There was a problem hiding this comment.
Let's add "ce_loss", "l1_loss", "confidence_loss" ? I think they give us some insight about the training overall.
| world_size = dist.get_world_size() if dist.is_initialized() else 1 | ||
| global_den = local_den.detach().clone() | ||
| if world_size > 1: | ||
| dist.all_reduce(global_den, op=dist.ReduceOp.SUM) | ||
| global_den = global_den + 1e-6 | ||
| loss = ( | ||
| self.ce_loss_alpha * ce_num / global_den | ||
| + self.l1_loss_alpha * l1_num / global_den | ||
| + self.confidence_head_alpha * conf_num / global_den | ||
| ) * world_size |
| ) | ||
|
|
||
| return ( | ||
| loss, |
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Summary
Adds DSpark draft-model training to SpecForge — a port of TorchSpec PR #129 ("DeepSpec - DSpark trainer support"). DSpark extends DFlash's block-diffusion drafter with EAGLE-style Markov (low-rank learned bigram bias) and confidence (per-position accept-rate) heads, trained with a combined objective:
|softmax(draft) − softmax(target)|, where the target distribution is the frozen target LM head applied to the target's final hidden state at the aligned position.1 − 0.5·L1) prediction, used for adaptive block length at inference.Markov/confidence formulation adapted from DeepSeek's DeepSpec (MIT).
What's added
New files
specforge/modeling/draft/dspark.py—DSparkConfig(extendsQwen3Config),VanillaMarkov,AcceptRatePredictor,build_markov_head,DSparkDraftModel(subclass ofDFlashDraftModel).specforge/core/dspark.py—OnlineDSparkModel(subclass ofOnlineDFlashModel): reuses DFlash anchor sampling / MASK-noise / block-mask machinery; adds Markov-biased logits, CE + L1 + confidence, within-block decay weighting, and a pooled global-mean loss.scripts/train_dspark.py— training driver (mirrorstrain_dflash.py); plumbslast_hidden_states, logs per-component losses, adds--max-steps.configs/qwen3-8b-dspark.json,examples/run_qwen3_8b_dspark_online.sh.tests/test_utils/test_dspark.py— 11 CPU unit tests (head math, config, internal loss identity, all-slots label convention, grad flow, CE-only path).Changed
specforge/modeling/target/dflash_target_model.py— surface the target's final hidden state asDFlashTargetOutput.last_hidden_states(HF + sglang backends). DFlash ignores it; DSpark's L1/confidence losses need it. CE-only training needs nothing new.specforge/modeling/draft/__init__.py,specforge/core/__init__.py— exports.Design notes
DFlashConfig/DFlashModel/DFlashTrainerclass family; a monolithicOnlineDFlashModel+ a flat training script; a different draft-modelforwardsignature; DFlash drops anchor slot 0 while DSpark supervises all block slots with a cumprod eval mask). DSpark is re-expressed in SpecForge idioms —DSparkConfigextendsQwen3Config(not aDFlashConfig), the loss wrapper subclassesOnlineDFlashModel, and training is a script (no trainer class).world_sizeto cancel FSDP's mean-grad reduction (correct for ZeRO-2 /SHARD_GRAD_OP, a single shard group). Verified by matching loss scale across 1-node vs 2-node runs.device/device_typeNameErrorintrain_dflash.py'smain()(carried into the new script).Testing
python -m unittest tests.test_utils.test_dspark).world_sizecorrection).confidence_losssettles/decreases (0.51→0.45) once the accept-rate target stabilizes — the earlier short-run rise was just the non-stationary target catching up. All three components healthy throughout; no NaN/OOM/instability. Checkpoints every ~1000 steps + per-epoch.Notes / follow-ups
last_hidden_statespath is implemented but not yet GPU-verified — use HF, or run CE-only (--l1-loss-alpha 0 --no-confidence-head) on sglang until verified.post_initHF init) — faithful to the reference.DSparkTrainStrategyinto the runtime path for parity with EAGLE3.🤖 Generated with Claude Code