[Feature] Train infer disaggregated#523
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a Ray-based distributed architecture for SpecForge, enabling both colocated and disaggregated (训推分离) training modes. The changes include new Ray-based worker groups for rollout and training, a centralized orchestrator, and support for NCCL-based GPU-to-GPU data transfer. My feedback highlights performance bottlenecks in the rollout dispatch logic, potential runtime errors in the DataCollator initialization, risks associated with clearing the global device mesh, the need for robust error handling when waiting for distributed workers, and unnecessary synchronization in the data transfer utility.
| for dp_idx in range(dp_size): | ||
| data_batch, actual_count = self._fetch_multi_local( | ||
| self._rollout_batch_size | ||
| ) | ||
| if data_batch is None: | ||
| break | ||
| per_dp_count = actual_count | ||
|
|
||
| send_ref = self.rollout_group.generate_and_send_single( | ||
| tp_idx, data_batch, [sp_leader_ranks[dp_idx]] | ||
| ) | ||
| send_refs.append(send_ref) |
There was a problem hiding this comment.
The current implementation performs dp_size separate forward passes on the target model per logical training step. Since the target model is typically much larger than the draft model, this creates a significant performance bottleneck, especially as the number of DP groups increases.
Consider batching all dp_size requests into a single forward pass on the RolloutWorkerGroup (with a total batch size of dp_size * rollout_batch_size), then sharding and sending the results to the respective TrainWorker groups. This would leverage GPU parallelism much more effectively for the target model inference.
| def __init__(self, sp_degree=None, ulysses_degree=None): | ||
| if sp_degree is not None: | ||
| self.sp_degree = sp_degree | ||
| else: | ||
| self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group()) | ||
| if ulysses_degree is not None: | ||
| self.ulysses_degree = ulysses_degree | ||
| else: | ||
| self.ulysses_degree = torch.distributed.get_world_size( | ||
| get_sp_ulysses_group() | ||
| ) |
There was a problem hiding this comment.
Calling torch.distributed.get_world_size() in the constructor of DataCollatorWithPadding will raise a RuntimeError if the collator is instantiated in a process where torch.distributed is not yet initialized (e.g., the driver process during dataset pre-building or in the orchestrator before workers are launched).
While the current RayOrchestrator passes these values explicitly, other utility functions like prepare_dp_dataloaders use the default constructor, which could lead to crashes if called outside a distributed context. Consider deferring the world size check until the first call to __call__ or providing safe defaults.
| _SP_RING_GROUP = PROCESS_GROUP.RING_PG if sp_size > 1 else my_draft_sp_group | ||
| _TP_DEVICE_MESH = dist.DeviceMesh.from_group(my_tp_group, device_type="cuda") | ||
| _DP_DEVICE_MESH = dist.DeviceMesh.from_group(my_dp_group, device_type="cuda") | ||
| _DEVICE_MESH = None # 2D mesh not available in subgroup mode |
There was a problem hiding this comment.
Setting _DEVICE_MESH = None in init_distributed_from_subgroup might cause failures in other parts of the codebase that rely on get_device_mesh(). While 1D meshes (_TP_DEVICE_MESH, _DP_DEVICE_MESH) are initialized, some FSDP configurations or monitoring tools in the existing codebase might expect the global 2D mesh to be present.
| if self._enable_perf: | ||
| t3 = time.perf_counter() | ||
|
|
||
| metrics = ray.get(train_refs[0]) |
There was a problem hiding this comment.
ray.get(train_refs[0]) only waits for the first worker (rank 0) to complete. If any other worker in the distributed group encounters an error or is significantly slower, the orchestrator may proceed to the next step prematurely or hang in subsequent collective operations, making debugging difficult.
It is safer to wait for all workers to ensure consistency and catch exceptions occurring on non-zero ranks.
| metrics = ray.get(train_refs[0]) | |
| metrics_list = ray.get(train_refs) | |
| metrics = metrics_list[0] |
| position_ids=_to(batch.position_ids), | ||
| ) | ||
| if needs_sync: | ||
| torch.cuda.synchronize() |
There was a problem hiding this comment.
torch.cuda.synchronize() is a heavy operation that stalls the CPU until all GPU tasks are finished, which can reduce the benefits of using non_blocking=True for overlapping transfers.
Since this is called immediately before the forward pass, you can rely on the default stream's serialization or use CUDA events for more fine-grained synchronization if multiple streams are involved.
|
need to add |
|
Hi @jiapingW — apologies, I only came across this PR today, otherwise I'd have chimed in earlier 😂. I'd independently been working on the same train/inference disaggregation problem in #573. The two take fairly different approaches: yours is built on Ray with orchestrated worker groups, while #573 is Ray-free and centers on a remote target-serving design with an async prefetch pipeline (configurable depth + multi-server round-robin scheduling) to overlap target inference with draft training — we measured up to ~2.37x speedup with dual-server prefetch. It also adds GPU-direct NCCL transfer with a wire-format fallback, TP broadcast, and all-to-all sharding (committed at https://github.com/moehanabi/SpecForge/tree/remote_train_sharded_nccl and has not merged to my pr now). |
Hi, it's a good job! Now we are developing train and infer disaggreation. We are refactoring the code to make it more maintainable. This feature will be completed in the next two days, and we welcome your further optimizations at that time. |
|
We also hope to decouple the system and improve the efficiency of online training by running the model via an SGL server instead of an SGL model instance. |
great work! hope I can see it soon! |
I saw many runtime-related pr such as #618 . Are they all about this "refactoring the code"? Are you working together for it? |
Motivation
Modifications
Related Issues
Accuracy Test
Benchmark & Profiling
Checklist