[GRPO] Match the Liger dapo/cispo/vespo normalizer to the non-Liger path#5890
[GRPO] Match the Liger dapo/cispo/vespo normalizer to the non-Liger path#5890kashif wants to merge 2 commits into
Conversation
pass num_items_in_batch to the liger loss and skip the grad-accum division for these loss types so it matches the non-liger path
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: be793fbe91
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| old_per_token_logps=inputs.get("old_per_token_logps"), | ||
| ref_per_token_logps=inputs.get("ref_per_token_logps"), | ||
| vllm_is_ratio=inputs.get("importance_sampling_ratio"), | ||
| num_items_in_batch=inputs.get("num_items_in_batch"), |
There was a problem hiding this comment.
Require a Liger version that accepts the new kwarg
This unconditionally passes num_items_in_batch into LigerFusedLinearGRPOLoss.forward, but the repo still declares the optional dependency as liger-kernel>=0.8.0 in pyproject.toml, and Liger 0.8.0's GRPO forward signature only accepts through vllm_is_ratio. In environments satisfying the current TRL extra with Liger 0.8.0, any use_liger_kernel=True GRPO training run now fails immediately with TypeError: forward() got an unexpected keyword argument 'num_items_in_batch'; please either bump the minimum Liger version that added this argument or avoid passing it to older kernels.
Useful? React with 👍 / 👎.
Warning
Don't merge yet — depends on a new liger-kernel release.
num_items_in_batchis only in Ligermain(linkedin/Liger-Kernel#1202), not in a tagged release, so theliger-kernelpin needs bumping once that ships.The Liger and non-Liger GRPO paths normalize the dapo/cispo/vespo loss differently.
The non-Liger path divides by
num_items_in_batch / num_processes(the total tokens across the whole generation batch) and leaves gradient accumulation alone. The Liger path didn't passnum_items_in_batchat all, so it fell back to the current micro-batch's mask, and then divided again bygradient_accumulation_steps. So the two drift apart once you have grad accumulation with uneven completion lengths.This passes
num_items_in_batchthrough to the Liger loss and drops the extra grad-accum division for those three loss types. Everything else keeps dividing as before.It's the last bit of the parity work tracked in linkedin/Liger-Kernel#1082 (the IS correction, tool_mask, delta and bias-correction KL were already sorted).
Added a test that runs dapo with grad accumulation > 1 and checks the arg gets forwarded and the loss isn't divided twice.