Skip to content

[Feature] Train infer disaggregated#523

Open
jiapingW wants to merge 3 commits into
mainfrom
train_infer_disaggre
Open

[Feature] Train infer disaggregated#523
jiapingW wants to merge 3 commits into
mainfrom
train_infer_disaggre

Conversation

@jiapingW

@jiapingW jiapingW commented Apr 2, 2026

Copy link
Copy Markdown
Collaborator

Motivation

Modifications

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a Ray-based distributed architecture for SpecForge, enabling both colocated and disaggregated (训推分离) training modes. The changes include new Ray-based worker groups for rollout and training, a centralized orchestrator, and support for NCCL-based GPU-to-GPU data transfer. My feedback highlights performance bottlenecks in the rollout dispatch logic, potential runtime errors in the DataCollator initialization, risks associated with clearing the global device mesh, the need for robust error handling when waiting for distributed workers, and unnecessary synchronization in the data transfer utility.

Comment on lines +173 to +184
for dp_idx in range(dp_size):
data_batch, actual_count = self._fetch_multi_local(
self._rollout_batch_size
)
if data_batch is None:
break
per_dp_count = actual_count

send_ref = self.rollout_group.generate_and_send_single(
tp_idx, data_batch, [sp_leader_ranks[dp_idx]]
)
send_refs.append(send_ref)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation performs dp_size separate forward passes on the target model per logical training step. Since the target model is typically much larger than the draft model, this creates a significant performance bottleneck, especially as the number of DP groups increases.

Consider batching all dp_size requests into a single forward pass on the RolloutWorkerGroup (with a total batch size of dp_size * rollout_batch_size), then sharding and sending the results to the respective TrainWorker groups. This would leverage GPU parallelism much more effectively for the target model inference.

Comment thread specforge/data/utils.py
Comment on lines +37 to +47
def __init__(self, sp_degree=None, ulysses_degree=None):
if sp_degree is not None:
self.sp_degree = sp_degree
else:
self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group())
if ulysses_degree is not None:
self.ulysses_degree = ulysses_degree
else:
self.ulysses_degree = torch.distributed.get_world_size(
get_sp_ulysses_group()
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling torch.distributed.get_world_size() in the constructor of DataCollatorWithPadding will raise a RuntimeError if the collator is instantiated in a process where torch.distributed is not yet initialized (e.g., the driver process during dataset pre-building or in the orchestrator before workers are launched).

While the current RayOrchestrator passes these values explicitly, other utility functions like prepare_dp_dataloaders use the default constructor, which could lead to crashes if called outside a distributed context. Consider deferring the world size check until the first call to __call__ or providing safe defaults.

Comment thread specforge/distributed.py
_SP_RING_GROUP = PROCESS_GROUP.RING_PG if sp_size > 1 else my_draft_sp_group
_TP_DEVICE_MESH = dist.DeviceMesh.from_group(my_tp_group, device_type="cuda")
_DP_DEVICE_MESH = dist.DeviceMesh.from_group(my_dp_group, device_type="cuda")
_DEVICE_MESH = None # 2D mesh not available in subgroup mode

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Setting _DEVICE_MESH = None in init_distributed_from_subgroup might cause failures in other parts of the codebase that rely on get_device_mesh(). While 1D meshes (_TP_DEVICE_MESH, _DP_DEVICE_MESH) are initialized, some FSDP configurations or monitoring tools in the existing codebase might expect the global 2D mesh to be present.

if self._enable_perf:
t3 = time.perf_counter()

metrics = ray.get(train_refs[0])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

ray.get(train_refs[0]) only waits for the first worker (rank 0) to complete. If any other worker in the distributed group encounters an error or is significantly slower, the orchestrator may proceed to the next step prematurely or hang in subsequent collective operations, making debugging difficult.

It is safer to wait for all workers to ensure consistency and catch exceptions occurring on non-zero ranks.

Suggested change
metrics = ray.get(train_refs[0])
metrics_list = ray.get(train_refs)
metrics = metrics_list[0]

position_ids=_to(batch.position_ids),
)
if needs_sync:
torch.cuda.synchronize()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

torch.cuda.synchronize() is a heavy operation that stalls the CPU until all GPU tasks are finished, which can reduce the benefits of using non_blocking=True for overlapping transfers.

Since this is called immediately before the forward pass, you can rely on the default stream's serialization or use CUDA events for more fine-grained synchronization if multiple streams are involved.

@FrankLeeeee

Copy link
Copy Markdown
Collaborator

need to add ray to pyproject.toml.

@moehanabi

Copy link
Copy Markdown
Contributor

Hi @jiapingW — apologies, I only came across this PR today, otherwise I'd have chimed in earlier 😂. I'd independently been working on the same train/inference disaggregation problem in #573.

The two take fairly different approaches: yours is built on Ray with orchestrated worker groups, while #573 is Ray-free and centers on a remote target-serving design with an async prefetch pipeline (configurable depth + multi-server round-robin scheduling) to overlap target inference with draft training — we measured up to ~2.37x speedup with dual-server prefetch. It also adds GPU-direct NCCL transfer with a wire-format fallback, TP broadcast, and all-to-all sharding (committed at https://github.com/moehanabi/SpecForge/tree/remote_train_sharded_nccl and has not merged to my pr now).

@jiapingW

Copy link
Copy Markdown
Collaborator Author

Hi @jiapingW — apologies, I only came across this PR today, otherwise I'd have chimed in earlier 😂. I'd independently been working on the same train/inference disaggregation problem in #573.

The two take fairly different approaches: yours is built on Ray with orchestrated worker groups, while #573 is Ray-free and centers on a remote target-serving design with an async prefetch pipeline (configurable depth + multi-server round-robin scheduling) to overlap target inference with draft training — we measured up to ~2.37x speedup with dual-server prefetch. It also adds GPU-direct NCCL transfer with a wire-format fallback, TP broadcast, and all-to-all sharding (committed at https://github.com/moehanabi/SpecForge/tree/remote_train_sharded_nccl and has not merged to my pr now).

Hi, it's a good job! Now we are developing train and infer disaggreation. We are refactoring the code to make it more maintainable. This feature will be completed in the next two days, and we welcome your further optimizations at that time.

@jiapingW

Copy link
Copy Markdown
Collaborator Author

We also hope to decouple the system and improve the efficiency of online training by running the model via an SGL server instead of an SGL model instance.

@moehanabi

Copy link
Copy Markdown
Contributor

We also hope to decouple the system and improve the efficiency of online training by running the model via an SGL server instead of an SGL model instance.

great work! hope I can see it soon!

@moehanabi

moehanabi commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

We also hope to decouple the system and improve the efficiency of online training by running the model via an SGL server instead of an SGL model instance.

I saw many runtime-related pr such as #618 . Are they all about this "refactoring the code"? Are you working together for it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants