Skip to content

AsyncGRPOTrainer: add model_init_kwargs support#5893

Open
rycerzes wants to merge 1 commit into
huggingface:mainfrom
rycerzes:feat/async-grpo-model-init-kwargs
Open

AsyncGRPOTrainer: add model_init_kwargs support#5893
rycerzes wants to merge 1 commit into
huggingface:mainfrom
rycerzes:feat/async-grpo-model-init-kwargs

Conversation

@rycerzes
Copy link
Copy Markdown
Contributor

@rycerzes rycerzes commented May 31, 2026

What does this PR do?

Adds model_init_kwargs support to AsyncGRPOTrainer, matching the existing behavior in GRPOTrainer.

Currently, AsyncGRPOTrainer hardcodes AutoModelForCausalLM.from_pretrained(model, device_map=None, dtype=torch.float32), which prevents passing custom kwargs like attn_implementation or dtype. This PR:

  • Adds model_init_kwargs field to AsyncGRPOConfig (with _VALID_DICT_FIELDS for dict serialization support)
  • Replaces direct AutoModelForCausalLM loading with create_model_from_path(model, **model_init_kwargs), which infers the correct architecture class from the model config (important for models like Qwen3.5 that use Qwen3_5ForConditionalGeneration instead of a standard CausalLM)
  • Enforces device_map=None for FSDP2 compatibility

Changes:

  • async_grpo_config.py: add model_init_kwargs field + docstring + _VALID_DICT_FIELDS
  • async_grpo_trainer.py: replace AutoModelForCausalLM with create_model_from_path, update imports

Tests:

  • test_model_init_kwargs: verifies that model_init_kwargs={"dtype": "bfloat16"} results in bfloat16 model parameters

Closes part of #5831

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

@qgallouedec


Note

Low Risk
Small, localized change to model instantiation with an existing shared helper and a focused unit test; no auth or rollout pipeline logic changes.

Overview
AsyncGRPOTrainer can now take optional model_init_kwargs on AsyncGRPOConfig, aligned with GRPOTrainer, so string model IDs accept from_pretrained options (e.g. dtype, attn_implementation) instead of a fixed AutoModelForCausalLM load with dtype=torch.float32.

Loading goes through create_model_from_path, which picks the architecture from the hub config (not only causal LM). device_map=None is always applied for FSDP2. A test asserts bfloat16 parameters when model_init_kwargs={"dtype": "bfloat16"}.

Reviewed by Cursor Bugbot for commit b925002. Bugbot is set up for automated code reviews on this repo. Configure here.

Add model_init_kwargs to AsyncGRPOConfig and use create_model_from_path
in the trainer, matching GRPOTrainer's model loading behavior. This
allows passing custom kwargs (e.g., attn_implementation, dtype) when
loading the model from a string.

Closes part of huggingface#5831
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant