diff --git a/tests/rl/test_rollout_logic.py b/tests/rl/test_rollout_logic.py index cb3f73a47..8788029c0 100644 --- a/tests/rl/test_rollout_logic.py +++ b/tests/rl/test_rollout_logic.py @@ -23,6 +23,7 @@ from xtuner.v1.rl.agent_loop import AgentLoopConfig from xtuner.v1.rl.rollout.controller import RolloutController from xtuner.v1.rl.rollout.health_manager import RolloutHealthManager +from xtuner.v1.rl.rollout.rollout_topology import RolloutTopology from xtuner.v1.rl.rollout.proxy_manager import RolloutProxyManager from xtuner.v1.rl.rollout.worker_registry import RolloutWorkerRegistry, WorkerLifecycleState, WorkerSnapshot from xtuner.v1.rl.rollout.sglang import SGLangWorker @@ -122,6 +123,45 @@ def test_trainer_auto_enables_rollout_proxy_when_agent_loop_requires_it(self): self.assertTrue(trainer._rollout_config.enable_proxy) +class TestRolloutTopologyAPI(unittest.TestCase): + def test_rollout_topology_resolves_engine_dist_init_addr_when_created(self): + rank_to_dist_init_addr = {0: "host0:25000", 1: "host1:25004"} + dist_init_addr_owner_rank = 0 + engine = RolloutTopology.engine( + engine_ranks=(0, 1), + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=( + RolloutTopology.server_process( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + ), + RolloutTopology.server_process( + worker_rank=1, + placement_group_bundle_idxs=(1,), + accepts_rollout_requests=False, + ), + ), + ) + + topology = RolloutTopology( + engines=(engine,), + training_engine_mesh=((0, 1),), + ) + + launch_specs = topology.server_launch_specs() + self.assertEqual(tuple(spec.worker_rank for spec in launch_specs), (0, 1)) + rank_0_launch_spec, rank_1_launch_spec = launch_specs + self.assertEqual(rank_0_launch_spec.dist_init_addr, "host0:25000") + self.assertEqual(rank_1_launch_spec.dist_init_addr, "host0:25000") + self.assertEqual(rank_0_launch_spec.engine_rank, 0) + self.assertEqual(rank_1_launch_spec.engine_rank, 1) + self.assertEqual(rank_1_launch_spec.placement_group_bundle_idxs, (1,)) + self.assertTrue(topology.is_request_entrypoint_rank(0)) + self.assertFalse(topology.is_request_entrypoint_rank(1)) + self.assertEqual(topology.lifecycle_group_for_server_rank(1), (0, 1)) + + class TestRolloutController(unittest.IsolatedAsyncioTestCase): def _state(self, uid: int, session_id: int) -> RolloutState: return RolloutState( @@ -142,6 +182,28 @@ def _build_controller(self, router): controller.logger = MagicMock() return controller + def _build_registry(self, ranks): + rollout_topology = RolloutTopology( + engines=tuple( + RolloutTopology.engine( + engine_ranks=(rank,), + dist_init_addr=f"addr{rank}", + server_processes=( + RolloutTopology.server_process( + worker_rank=rank, + placement_group_bundle_idxs=(rank,), + ), + ), + ) + for rank in ranks + ), + training_engine_mesh=tuple((rank,) for rank in ranks), + ) + return RolloutWorkerRegistry( + rollout_topology=rollout_topology, + rollout_config=SimpleNamespace(), + ) + async def test_generate_fails_fast_when_no_active_worker(self): # router 找不到 active worker 时,controller 应直接把原样本标成 FAILED,避免请求悬挂。 state = self._state(uid=1, session_id=123) @@ -175,20 +237,18 @@ async def test_generate_routes_to_active_worker(self): def test_register_active_workers_to_proxy_delegates_active_session_urls(self): controller = RolloutController.__new__(RolloutController) - controller.registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0], [1]], rollout_config=SimpleNamespace()) + controller.registry = self._build_registry((0, 1)) controller.registry.register_started_server( rank=0, actor=object(), server_url="http://worker-0", session_url="http://session-0", - is_request_entrypoint=True, ) controller.registry.register_started_server( rank=1, actor=object(), server_url="http://worker-1", session_url="http://session-1", - is_request_entrypoint=True, ) controller.registry.mark_unhealthy_ranks({1}) controller.proxy_manager = MagicMock() @@ -199,7 +259,7 @@ def test_register_active_workers_to_proxy_delegates_active_session_urls(self): def test_register_active_workers_to_proxy_noops_without_proxy_manager(self): controller = RolloutController.__new__(RolloutController) - controller.registry = RolloutWorkerRegistry(engine_rank_mesh_array=[], rollout_config=SimpleNamespace()) + controller.registry = self._build_registry(()) controller.proxy_manager = None controller.register_active_workers_to_proxy() @@ -360,34 +420,65 @@ class TestRolloutWorkerRegistry(unittest.TestCase): def _worker_by_rank(self, registry, rank): return next(worker for worker in registry.all_workers() if worker.rank == rank) - def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): + def _runtime_topology( + self, + *, + engine_ranks=(0,), + server_processes=None, + training_engine_mesh=None, + ): + if server_processes is None: + server_processes = ( + RolloutTopology.server_process( + worker_rank=engine_ranks[0], + placement_group_bundle_idxs=tuple(range(len(engine_ranks))), + accepts_rollout_requests=True, + ), + ) + dist_init_addr_owner_rank = server_processes[0].worker_rank + return RolloutTopology( + engines=( + RolloutTopology.engine( + engine_ranks=tuple(engine_ranks), + dist_init_addr=f"addr{dist_init_addr_owner_rank}", + server_processes=tuple(server_processes), + ), + ), + training_engine_mesh=tuple(training_engine_mesh or (tuple(engine_ranks),)), + ) + + def test_registry_filters_entrypoints_and_tracks_lifecycle(self): config = SimpleNamespace() - registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0, 1]], rollout_config=config) + runtime_topology = self._runtime_topology( + engine_ranks=(0, 1), + server_processes=( + RolloutTopology.server_process( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + ), + RolloutTopology.server_process( + worker_rank=1, + placement_group_bundle_idxs=(1,), + accepts_rollout_requests=False, + ), + ), + training_engine_mesh=((0, 1),), + ) + registry = RolloutWorkerRegistry(rollout_topology=runtime_topology, rollout_config=config) registry.register_started_server( rank=0, actor=object(), server_url="http://worker-0", session_url="http://session-0", - lifecycle_group_ranks=(0, 1), - is_request_entrypoint=True, ) registry.register_started_server( rank=1, actor=object(), server_url="http://worker-1", session_url=None, - lifecycle_group_ranks=(0, 1), - is_request_entrypoint=False, ) - metadata = registry.training_metadata_snapshot() - - self.assertEqual(metadata["engine_rank_mesh_array"], [[0, 1]]) - self.assertIs(metadata["rollout_config"], config) - self.assertEqual(metadata["server_url_dict"], {0: "http://worker-0"}) - self.assertEqual(metadata["worker_server_urls_status"], {"http://worker-0": True}) - self.assertEqual(metadata["worker_session_url_dict"], {0: "http://session-0"}) - self.assertEqual(metadata["worker_session_urls_status"], {"http://session-0": True}) active_entrypoint = registry.active_entrypoints()[0] self.assertIsInstance(active_entrypoint, WorkerSnapshot) self.assertEqual(active_entrypoint.rank, 0) @@ -395,11 +486,8 @@ def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): active_entrypoint.lifecycle_state = WorkerLifecycleState.INACTIVE unhealthy_groups = registry.mark_unhealthy_ranks({0}) - metadata = registry.training_metadata_snapshot() self.assertEqual(unhealthy_groups[0].ranks, (0, 1)) - self.assertEqual(metadata["worker_server_urls_status"], {"http://worker-0": False}) - self.assertEqual(metadata["worker_session_urls_status"], {"http://session-0": False}) self.assertEqual(tuple(worker.rank for worker in registry.inactive_workers()), (0, 1)) self.assertEqual(registry.active_entrypoints(), ()) claimed_groups = registry.claim_inactive_groups_for_recovery() @@ -408,13 +496,46 @@ def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): registry.set_group_recovery_result(claimed_groups[0], recovered=False) self.assertEqual(self._worker_by_rank(registry, 0).lifecycle_state, WorkerLifecycleState.INACTIVE) + class TestSessionRouter(unittest.IsolatedAsyncioTestCase): async def test_sticky_session_reselects_when_previous_entrypoint_is_inactive(self): actor_0 = object() actor_1 = object() - registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0], [1]], rollout_config=SimpleNamespace()) - registry.register_started_server(rank=0, actor=actor_0, server_url="http://worker-0") - registry.register_started_server(rank=1, actor=actor_1, server_url="http://worker-1") + rollout_topology = RolloutTopology( + engines=( + RolloutTopology.engine( + engine_ranks=(0,), + dist_init_addr="addr0", + server_processes=( + RolloutTopology.server_process(worker_rank=0, placement_group_bundle_idxs=(0,)), + ), + ), + RolloutTopology.engine( + engine_ranks=(1,), + dist_init_addr="addr1", + server_processes=( + RolloutTopology.server_process(worker_rank=1, placement_group_bundle_idxs=(1,)), + ), + ), + ), + training_engine_mesh=((0,), (1,)), + ) + registry = RolloutWorkerRegistry( + rollout_topology=rollout_topology, + rollout_config=SimpleNamespace(), + ) + registry.register_started_server( + rank=0, + actor=actor_0, + server_url="http://worker-0", + session_url="http://session-0", + ) + registry.register_started_server( + rank=1, + actor=actor_1, + server_url="http://worker-1", + session_url="http://session-1", + ) router = SessionRouter(registry, max_idle_seconds=None) self.assertIs(await router.get_worker(7), actor_0) @@ -667,19 +788,36 @@ def _worker_by_rank(self, registry, rank): return next(worker for worker in registry.all_workers() if worker.rank == rank) def _build_registry(self, workers_info): + engines = [] + for rank in sorted(workers_info): + engines.append( + RolloutTopology.engine( + engine_ranks=(rank,), + dist_init_addr=f"addr{rank}", + server_processes=( + RolloutTopology.server_process( + worker_rank=rank, + placement_group_bundle_idxs=(rank,), + accepts_rollout_requests=True, + ), + ), + ) + ) + rollout_topology = RolloutTopology( + engines=tuple(engines), + training_engine_mesh=tuple((rank,) for rank in sorted(workers_info)), + ) registry = RolloutWorkerRegistry( - engine_rank_mesh_array=[sorted(workers_info)], + rollout_topology=rollout_topology, rollout_config=SimpleNamespace(), ) for rank, worker_info in workers_info.items(): - lifecycle_group_ranks = worker_info.lifecycle_group_ranks or (rank,) registry.register_started_server( rank=rank, actor=worker_info.actor, server_url=worker_info.url, - session_url=worker_info.session_url, - lifecycle_group_ranks=lifecycle_group_ranks, - is_request_entrypoint=worker_info.is_request_entrypoint, + session_url=worker_info.session_url or f"http://session-{rank}", + lifecycle_state=worker_info.lifecycle_state, ) if worker_info.lifecycle_state is WorkerLifecycleState.INACTIVE: registry.mark_unhealthy_ranks({rank}) @@ -710,7 +848,7 @@ def _build_manager( def test_marks_worker_inactive_after_consecutive_health_failures(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") workers_info = {0: worker_info} inactive_groups = [] listener = SimpleNamespace( @@ -738,7 +876,7 @@ def test_marks_worker_inactive_after_consecutive_health_failures(self): def test_inactive_listener_runs_under_operation_lock(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") lock_acquired_by_listener = [] manager, _ = self._build_manager({0: worker_info}, failure_threshold=1) @@ -764,6 +902,7 @@ def test_inactive_worker_is_not_cleaned_up_again(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) workers_info = { 0: WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -779,7 +918,7 @@ def test_inactive_worker_is_not_cleaned_up_again(self): def test_health_check_threshold_zero_disables_periodic_health_check(self): # threshold <= 0 表示关闭周期健康监测,不应把 active worker 直接判 inactive。 actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, registry = self._build_manager({0: worker_info}, failure_threshold=0) checked_count = manager._check_and_deactivate_failed_worker_groups() @@ -790,7 +929,7 @@ def test_health_check_threshold_zero_disables_periodic_health_check(self): def test_fail_fast_health_check_still_runs_when_periodic_health_check_is_disabled(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, registry = self._build_manager({0: worker_info}, failure_threshold=0) checked_count = manager._check_and_deactivate_failed_worker_groups(fail_fast=True) @@ -801,7 +940,7 @@ def test_fail_fast_health_check_still_runs_when_periodic_health_check_is_disable def test_health_check_uses_configured_timeout(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, _ = self._build_manager({0: worker_info}, check_timeout=2.5) observed_timeouts = [] @@ -817,6 +956,7 @@ async def fake_wait_for(awaitable, timeout): def test_shutdown_barrier_keeps_failed_shutdown_group_inactive(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -838,6 +978,7 @@ def test_shutdown_barrier_keeps_failed_shutdown_group_inactive(self): def test_restart_barrier_keeps_failed_recovery_group_inactive(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -859,6 +1000,7 @@ def test_restart_barrier_keeps_failed_recovery_group_inactive(self): def test_restart_barrier_notifies_recovered_group_after_success(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", session_url="http://session-0", @@ -884,6 +1026,7 @@ def test_restart_barrier_notifies_recovered_group_after_success(self): def test_recovered_listener_runs_under_operation_lock(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index ae4572334..17448f8c9 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -19,7 +19,7 @@ RolloutConfig, get_rollout_worker_base_cls, ) -from .worker_registry import RolloutWorkerMetadata, RolloutWorkerRegistry +from .worker_registry import RolloutWorkerRegistry # Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller @@ -61,19 +61,15 @@ def __init__( registry=self.registry, worker_lifecycle_listeners=[self.proxy_manager] if self.proxy_manager is not None else None, ) - self.health_manager.start() + self.health_manager.start_background_checks() - def get_rollout_metadata(self) -> RolloutWorkerMetadata: + def get_rollout_metadata(self) -> dict: """Get information about the current rollout setup. Returns: - dict: A dictionary containing the engine mesh list, server URL - dictionary, and the rollout configuration. + Legacy trainer/update-weight rollout metadata dictionary. """ - rollout_metadata = self.registry.training_metadata_snapshot() - self.logger.info(f"Rollout worker server URLs: {rollout_metadata['server_url_dict']}") - self.logger.info(f"Rollout worker session server URLs: {rollout_metadata['worker_session_url_dict']}") - return rollout_metadata + return self.registry.metadata().to_legacy() def register_active_workers_to_proxy(self) -> None: if self.proxy_manager is None: @@ -133,7 +129,7 @@ def set_enable_partial_rollout(self, enable: bool) -> None: ) def pause_generation(self): - self.health_manager.pause() + self.health_manager.pause_background_checks() active_workers = self.registry.active_workers() futures = [ worker.actor.pause_generation.remote() # type: ignore[attr-defined] @@ -164,7 +160,7 @@ async def restart_inactive_workers(self): def continue_generation(self): self._broadcast_to_active_workers("continue_generation") - self.health_manager.resume() + self.health_manager.resume_background_checks() def offload(self): self._broadcast_to_active_workers("offload") @@ -181,7 +177,7 @@ def onload_kvcache(self): def shutdown(self): """Shut down all rollout workers tracked by the controller.""" - self.health_manager.stop() + self.health_manager.stop_background_checks() actors = self.registry.all_actors() ray.get( [actor.shutdown.remote(stop_session_server=True) for actor in actors], # type: ignore[attr-defined] @@ -203,13 +199,16 @@ def _build_remote_worker_cls(self, worker_base_cls): }, )(worker_base_cls) - def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistry: + def _init_workers( + self, + placement_group: PlacementGroup, + ) -> RolloutWorkerRegistry: """Initializes and configures the pool of RolloutWorker actors. This method follows the same high-level flow as the legacy implementation: - create workers, initialize worker-local ports, build engine groups, - select workers that launch rollout servers, launch servers, and - expose request-entrypoint server URLs to rollout traffic. + create workers, initialize worker-local ports, build the bound rollout + topology, launch rollout servers, and expose request-entrypoint server + URLs to rollout traffic. Returns: A registry containing all server-process workers and the public @@ -222,79 +221,63 @@ def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistr workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( worker_cls, self.config, placement_group ) - rank_to_actor = {rank: worker for (rank, _), worker in zip(rank_bundle_idx_list, workers)} - - # Reserve worker-local ports for all actors first. build_engine_launch_specs - # uses the returned addresses to bind each ServerProcessSpec to its - # logical engine rendezvous address; only server-process owners call init(). - rank_to_dist_init_addr = { - rank: dist_init_addr - for (rank, _), dist_init_addr in zip( - rank_bundle_idx_list, - ray.get([worker.init_dist_port.remote() for worker in workers]), # type: ignore[attr-defined] - ) + dist_init_results = ray.get( + [ + worker.init_dist_port.remote() # type: ignore[attr-defined] + for worker in workers + ] + ) + rank_to_worker = { + rank: worker for worker, (rank, _dist_init_addr) in zip(workers, dist_init_results, strict=True) } + rank_to_dist_init_addr = dict(dist_init_results) - # Build engine groups and server-process specs from the rank/bundle mapping. - engine_launch_specs = worker_base_cls.build_engine_launch_specs( + rollout_topology = worker_base_cls.build_rollout_topology( self.config, rank_bundle_idx_list, rank_to_dist_init_addr, ) - # Keep the public metadata mesh compatible with origin/main. Backends - # may expose a different update-weight mesh than their internal launch - # topology, e.g. LMDeploy EP has one logical engine but one public entry - # per request-serving EP rank. - engine_rank_mesh_array = worker_base_cls.build_metadata_engine_rank_mesh_array(engine_launch_specs) - - # Launch every server process described by the backend-specific specs. - server_rank_to_url = dict( - ray.get( - [ - rank_to_actor[server_process.worker_rank].init.remote( # type: ignore[attr-defined] - engine_launch_spec=engine_spec, - ) - for engine_spec in engine_launch_specs - for server_process in engine_spec.server_processes - ] - ) + server_launch_specs = rollout_topology.server_launch_specs() + server_workers = tuple( + (launch_spec, rank_to_worker[launch_spec.worker_rank]) for launch_spec in server_launch_specs ) - session_url_by_rank = dict( + + ray.get( + [ + worker.bind_server_launch_spec.remote(launch_spec) # type: ignore[attr-defined] + for launch_spec, worker in server_workers + ] + ) + init_results = tuple( ray.get( [ - ( - rank_to_actor[server_process.worker_rank].get_session_server_info.remote() # type: ignore[attr-defined] - ) - for engine_spec in engine_launch_specs - for server_process in engine_spec.server_processes + worker.init.remote() # type: ignore[attr-defined] + for _launch_spec, worker in server_workers ] ) ) - - registry = RolloutWorkerRegistry( - engine_rank_mesh_array=engine_rank_mesh_array, - rollout_config=self.config, - ) - for engine_spec in engine_launch_specs: - for server_process in engine_spec.server_processes: - rank = server_process.worker_rank - url = server_rank_to_url[rank] - session_url = session_url_by_rank.get(rank) - if server_process.accepts_rollout_requests and session_url is None: - raise RuntimeError(f"Rollout worker rank={rank} did not return session server URL during init.") - registry.register_started_server( - rank=rank, - actor=rank_to_actor[rank], - server_url=url, - session_url=session_url, - lifecycle_group_ranks=engine_spec.server_worker_ranks, - is_request_entrypoint=server_process.accepts_rollout_requests, + registry = RolloutWorkerRegistry(rollout_topology=rollout_topology, rollout_config=self.config) + for init_result in init_results: + if rollout_topology.is_request_entrypoint_rank(init_result.rank) and init_result.session_url is None: + raise RuntimeError( + f"Rollout worker rank={init_result.rank} did not return session server URL during init." ) + registry.register_started_server( + rank=init_result.rank, + actor=rank_to_worker[init_result.rank], + server_url=init_result.server_url, + session_url=init_result.session_url, + ) - server_process_workers_info = registry.all_workers() - self.logger.info(f"Rollout server-process worker URLs: {[info.url for info in server_process_workers_info]}") - lifecycle_groups = sorted({info.lifecycle_group_ranks for info in server_process_workers_info}) - self.logger.info(f"Rollout worker lifecycle groups: {lifecycle_groups}") + rollout_metadata = registry.metadata() + legacy_metadata = rollout_metadata.to_legacy() + self.logger.info( + "Rollout worker registry snapshot: " + f"server_urls={legacy_metadata['server_url_dict']}, " + f"session_urls={legacy_metadata['worker_session_url_dict']}, " + f"server_process_urls={[worker.url for worker in registry.all_workers()]}, " + f"lifecycle_groups={registry.lifecycle_groups()}" + ) return registry diff --git a/xtuner/v1/rl/rollout/health_manager.py b/xtuner/v1/rl/rollout/health_manager.py index 846cd37bd..6fa8109b6 100644 --- a/xtuner/v1/rl/rollout/health_manager.py +++ b/xtuner/v1/rl/rollout/health_manager.py @@ -64,7 +64,7 @@ def __init__( self._worker_health_failure_counts: dict[int, int] = {} self._stopped = False - def start(self) -> None: + def start_background_checks(self) -> None: health_thread_alive = self._thread is not None and self._thread.is_alive() if health_thread_alive: return @@ -73,11 +73,11 @@ def start(self) -> None: self._stop_event = threading.Event() self._pause_event = threading.Event() self._pause_event.set() - self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread = threading.Thread(target=self._run_background_check_loop, daemon=True) self._thread.start() - logger.info("RolloutHealthManager started.") + logger.info("RolloutHealthManager background checks started.") - def stop(self) -> None: + def stop_background_checks(self) -> None: thread = self._thread if not thread: return @@ -99,19 +99,19 @@ def stop(self) -> None: self._thread = None self._stop_event = None self._pause_event = None - logger.info("RolloutHealthManager stopped.") + logger.info("RolloutHealthManager background checks stopped.") - def pause(self) -> None: + def pause_background_checks(self) -> None: if self._pause_event is None: return self._pause_event.set() - logger.info("RolloutHealthManager paused.") + logger.info("RolloutHealthManager background checks paused.") - def resume(self) -> None: + def resume_background_checks(self) -> None: if self._pause_event is None: return self._pause_event.clear() - logger.info("RolloutHealthManager resumed.") + logger.info("RolloutHealthManager background checks resumed.") def _is_paused(self) -> bool: return self._pause_event is None or self._pause_event.is_set() @@ -121,20 +121,20 @@ def _is_stopping(self) -> bool: return self._stopped or (self._stop_event is not None and self._stop_event.is_set()) @contextmanager - def _background_health_checks_paused(self): + def _background_checks_paused(self): was_paused = self._is_paused() if not was_paused: - self.pause() + self.pause_background_checks() try: yield finally: if not was_paused: - self.resume() + self.resume_background_checks() def restart_inactive_workers(self) -> None: """Synchronously restart inactive groups before the next sync-step weight update.""" - with self._background_health_checks_paused(): + with self._background_checks_paused(): with self._operation_lock: failed_groups = list(self._registry.claim_inactive_groups_for_recovery()) if not failed_groups: @@ -217,7 +217,7 @@ def check_and_shutdown_inactive_workers(self) -> None: """Fail-fast health-check active workers, mark failures inactive, and shut down every non-active group so shared resources can be reused by training.""" - with self._background_health_checks_paused(): + with self._background_checks_paused(): self._check_and_deactivate_failed_worker_groups(fail_fast=True) with self._operation_lock: inactive_groups = list(self._registry.claim_inactive_groups_for_recovery()) @@ -248,7 +248,7 @@ def check_and_shutdown_inactive_workers(self) -> None: ) ) - def run_once(self) -> None: + def run_periodic_health_check(self) -> None: logger.debug("RolloutHealthManager running health checks for all workers.") checked_active_count = self._check_and_deactivate_failed_worker_groups() if self._registry.active_workers() or self._is_stopping(): @@ -367,9 +367,9 @@ async def check_workers(workers: list[WorkerSnapshot]) -> list[bool]: return [keep_active_by_rank[worker.rank] for worker in workers_to_check] - def _run_loop(self) -> None: + def _run_background_check_loop(self) -> None: assert self._stop_event is not None and self._pause_event is not None - logger.info("RolloutHealthManager loop started.") + logger.info("RolloutHealthManager background check loop started.") while not self._stop_event.is_set(): while self._pause_event.is_set() and not self._stop_event.is_set(): @@ -385,13 +385,13 @@ def _run_loop(self) -> None: continue try: - self.run_once() + self.run_periodic_health_check() except RuntimeError: if self._is_stopping(): break - logger.exception("RolloutHealthManager run_once failed.") + logger.exception("RolloutHealthManager periodic health check failed.") except Exception: - logger.exception("RolloutHealthManager run_once failed.") + logger.exception("RolloutHealthManager periodic health check failed.") def _shutdown_worker_group( self, @@ -486,8 +486,8 @@ def _restart_worker_group( ) init_results = ray.get( [ - # init() reuses the immutable launch spec cached on each actor - # during controller startup, including placement bundles and dist addr. + # init() reuses the server launch spec bound during + # controller startup. worker.actor.init.remote() # type: ignore[attr-defined] for worker in group.workers ], @@ -505,11 +505,11 @@ def _restart_worker_group( return False for worker, init_result in zip(group.workers, init_results): - init_rank, init_url = init_result - if init_rank != worker.rank or init_url != worker.url: + if init_result.rank != worker.rank or init_result.server_url != worker.url: logger.error( f"Rollout worker restart returned unexpected endpoint: rank={worker.rank}, " - f"init_rank={init_rank}, expected_url={worker.url}, init_url={init_url}." + f"init_rank={init_result.rank}, expected_url={worker.url}, " + f"init_url={init_result.server_url}." ) self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) return False diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index 7c4506238..044eb1c19 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -1,6 +1,6 @@ import os from argparse import Namespace -from typing import Any, Dict, List +from typing import Any, Dict, List, Mapping import numpy as np import ray @@ -10,7 +10,8 @@ from transformers import AutoTokenizer from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams -from .worker import EngineLaunchSpec, EngineLaunchSpecs, RolloutConfig, RolloutWorker, ServerProcessSpec +from .rollout_topology import RolloutTopology +from .worker import RolloutConfig, RolloutWorker SHARED_STORE = "shared_store" @@ -80,118 +81,77 @@ def __init__( self.lmdeploy_actor = None @classmethod - def build_engine_launch_specs( + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build LMDeploy server launch layout. - - LMDeploy EP starts one request-serving server per EP rank. - - Example with expert_parallel_size=2: - rank_bundle_idx_list is [(0, 0), (1, 1), (2, 2), (3, 3)]. - rank identifies the rollout worker; bundle idx identifies the Ray - placement-group bundle that owns the GPU resource. - - If rank_to_dist_init_addr is: - {0: "addr0", 1: "addr1", 2: "addr2", 3: "addr3"} - - The launch specs are: - EngineLaunchSpec( - engine_ranks=(0, 1), - server_processes=( - ServerProcessSpec( - worker_rank=0, - placement_group_bundle_idxs=(0,), - dist_init_addr="addr0", - ), - ServerProcessSpec( - worker_rank=1, - placement_group_bundle_idxs=(1,), - dist_init_addr="addr0", - ), - ), - ) - EngineLaunchSpec( - engine_ranks=(2, 3), - server_processes=( - ServerProcessSpec( - worker_rank=2, - placement_group_bundle_idxs=(2,), - dist_init_addr="addr2", - ), - ServerProcessSpec( - worker_rank=3, - placement_group_bundle_idxs=(3,), - dist_init_addr="addr2", - ), - ), - ) - - Each EP rank launches a server process, so server_worker_ranks is the - same as engine_ranks, and every server accepts rollout requests. - """ - if config.expert_parallel_size <= 1: - return RolloutWorker.build_engine_launch_specs( - config, - rank_bundle_idx_list, - rank_to_dist_init_addr, - ) - - ep_size = config.expert_parallel_size + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + """Build LMDeploy rollout topology with bound engine dist-init + addresses.""" + engines = [] num_workers = len(rank_bundle_idx_list) - if num_workers % ep_size != 0: - raise ValueError(f"num_rollout_workers={num_workers} must be divisible by expert_parallel_size={ep_size}.") - - engine_launch_specs: list[EngineLaunchSpec] = [] - for engine_start in range(0, num_workers, ep_size): - engine_meta = rank_bundle_idx_list[engine_start : engine_start + ep_size] - engine_ranks = tuple(rank for rank, _ in engine_meta) - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] - # LMDeploy EP launches one server process for each EP rank. Each - # server owns exactly one placement-group bundle, and every server - # can be used as a rollout request entrypoint. - engine_launch_specs.append( - EngineLaunchSpec( - engine_ranks=engine_ranks, - server_processes=tuple( - ServerProcessSpec( - worker_rank=server_rank, - placement_group_bundle_idxs=(bundle_idx,), - dist_init_addr=engine_dist_init_addr, - ) - for server_rank, bundle_idx in engine_meta - ), + if config.expert_parallel_size <= 1: + num_gpus_per_engine = config.num_gpus_per_engine + if num_workers % num_gpus_per_engine != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by " + f"num_gpus_per_engine={num_gpus_per_engine}." ) - ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), - ) - - @classmethod - def build_metadata_engine_rank_mesh_array( - cls, - engine_launch_specs: EngineLaunchSpecs, - ) -> list[list[int]]: - """Keep LMDeploy EP metadata compatible with origin/main. - - Pure EP uses one request-serving server per EP rank. The logical engine topology is still stored in - EngineLaunchSpec.engine_ranks for dp_rank and lifecycle operations, but update_weighter expects the public - metadata mesh to contain one single-rank entry per request server. - """ - metadata_engine_rank_mesh_array: list[list[int]] = [] - for engine_spec in engine_launch_specs: - request_entrypoint_servers = engine_spec.request_entrypoint_servers - if len(request_entrypoint_servers) > 1: - metadata_engine_rank_mesh_array.extend( - [server_process.worker_rank] for server_process in request_entrypoint_servers + for engine_start in range(0, num_workers, num_gpus_per_engine): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] + engine_ranks = tuple(rank for rank, _ in engine_meta) + engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) + dist_init_addr_owner_rank = engine_ranks[0] + engines.append( + RolloutTopology.engine( + engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=( + RolloutTopology.server_process( + worker_rank=engine_ranks[0], + placement_group_bundle_idxs=engine_bundle_idxs, + ), + ), + ) + ) + else: + ep_size = config.expert_parallel_size + if num_workers % ep_size != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by expert_parallel_size={ep_size}." + ) + for engine_start in range(0, num_workers, ep_size): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + ep_size] + engine_ranks = tuple(rank for rank, _ in engine_meta) + dist_init_addr_owner_rank = engine_ranks[0] + engines.append( + RolloutTopology.engine( + engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=tuple( + RolloutTopology.server_process( + worker_rank=server_rank, + placement_group_bundle_idxs=(bundle_idx,), + ) + for server_rank, bundle_idx in engine_meta + ), + ) ) + + training_engine_mesh: list[tuple[int, ...]] = [] + for engine in engines: + entrypoint_processes = tuple( + process for process in engine.server_processes if process.accepts_rollout_requests + ) + if len(entrypoint_processes) == 1: + training_engine_mesh.append(tuple(engine.engine_ranks)) else: - metadata_engine_rank_mesh_array.append(list(engine_spec.engine_ranks)) - return metadata_engine_rank_mesh_array + training_engine_mesh.extend((process.worker_rank,) for process in entrypoint_processes) + return RolloutTopology( + engines=tuple(engines), + training_engine_mesh=tuple(training_engine_mesh), + ) def offload(self): """Offloads the model weights and KV cache.""" @@ -342,7 +302,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: "NPU": "ascend", } - extra_config = self.config.extra_rollout_config or dict() + extra_config = self.config.extra_rollout_config lmdeploy_config_kwargs = { k.replace("lmdeploy_", ""): v for k, v in extra_config.items() if k.startswith("lmdeploy_") } @@ -383,14 +343,13 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: if backend == "pytorch" and self.config.max_prefill_token_num: extra_engine_config["max_prefill_token_num"] = self.config.max_prefill_token_num + assert self.server_launch_spec is not None dp_rank = 0 if backend == "pytorch": # currently only support ep > 1 and tp == 1 / ep == 1 and tp > 1 assert ep_size == 1 or tp_size == 1 if ep_size > 1: - engine_launch_spec = self.engine_launch_spec - assert engine_launch_spec is not None - dp_rank = engine_launch_spec.engine_ranks.index(self.rank) + dp_rank = self.server_launch_spec.engine_rank backend_config = ( PytorchEngineConfig( @@ -413,7 +372,10 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: else TurbomindEngineConfig( tp=tp_size, max_batch_size=self.config.rollout_max_batch_size_per_instance, - devices=[bundle_idxs % self.config.gpus_per_node for bundle_idxs in self.engine_bundle_idxs], + devices=[ + bundle_idx % self.config.gpus_per_node + for bundle_idx in self.server_launch_spec.placement_group_bundle_idxs + ], empty_init=self.config.skip_load_weights, session_len=self.config.context_length, model_format="fp8" if self.config.enable_float8 else None, @@ -431,7 +393,9 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: env = { "LMDEPLOY_RAY_EXTERNAL_NS": ray_runtime_ctx.namespace, "LMDEPLOY_RAY_EXTERNAL_PG_NAME": current_pg_name, - "LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES": ",".join(map(str, self.engine_bundle_idxs)), + "LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES": ",".join( + map(str, self.server_launch_spec.placement_group_bundle_idxs) + ), } if self.accelerator == "NPU": @@ -444,7 +408,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: ) if tp_size > 1: - dist_addr, dist_port = self.dist_init_addr.split(":")[:2] + dist_addr, dist_port = self.server_launch_spec.dist_init_addr.split(":")[:2] env.update( { "LMDEPLOY_DIST_MASTER_ADDR": dist_addr, @@ -452,7 +416,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: } ) elif ep_size > 1: - dist_addr, dist_port = self.dist_init_addr.split(":")[:2] + dist_addr, dist_port = self.server_launch_spec.dist_init_addr.split(":")[:2] if speculative_num_draft_tokens is not None: deepep_max_tokens_per_rank = max_batch_size * (1 + speculative_num_draft_tokens) else: diff --git a/xtuner/v1/rl/rollout/rollout_topology.py b/xtuner/v1/rl/rollout/rollout_topology.py new file mode 100644 index 000000000..db54f8050 --- /dev/null +++ b/xtuner/v1/rl/rollout/rollout_topology.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +__all__ = ["RolloutTopology", "ServerLaunchSpec"] + + +@dataclass(frozen=True) +class _ServerProcess: + """Rollout topology expression for one worker-owned server process.""" + + worker_rank: int + placement_group_bundle_idxs: tuple[int, ...] + accepts_rollout_requests: bool = True + node_rank: int = 0 + nnodes: int = 1 + + +@dataclass(frozen=True) +class _Engine: + """Rollout layout for one logical inference engine.""" + + engine_ranks: tuple[int, ...] + dist_init_addr: str + server_processes: tuple[_ServerProcess, ...] + + +@dataclass(frozen=True) +class ServerLaunchSpec: + """Worker-facing launch data projected from rollout topology.""" + + worker_rank: int + placement_group_bundle_idxs: tuple[int, ...] + dist_init_addr: str + engine_rank: int + node_rank: int = 0 + nnodes: int = 1 + + +@dataclass(frozen=True) +class RolloutTopology: + """Immutable rollout engine layout after dist-init addresses are resolved. + + Actor handles, server URLs, session URLs, and lifecycle state belong to RolloutWorkerRegistry. + """ + + engines: tuple[_Engine, ...] + training_engine_mesh: tuple[tuple[int, ...], ...] + _server_process_by_rank: dict[int, _ServerProcess] = field(init=False, repr=False, compare=False) + _lifecycle_group_by_rank: dict[int, tuple[int, ...]] = field(init=False, repr=False, compare=False) + + def __post_init__(self) -> None: + server_process_by_rank: dict[int, _ServerProcess] = {} + lifecycle_group_by_rank: dict[int, tuple[int, ...]] = {} + for engine in self.engines: + lifecycle_group = tuple(server.worker_rank for server in engine.server_processes) + for server in engine.server_processes: + if server.worker_rank in server_process_by_rank: + raise ValueError(f"Duplicate rollout server process worker_rank={server.worker_rank}.") + server_process_by_rank[server.worker_rank] = server + lifecycle_group_by_rank[server.worker_rank] = lifecycle_group + + object.__setattr__(self, "_server_process_by_rank", server_process_by_rank) + object.__setattr__(self, "_lifecycle_group_by_rank", lifecycle_group_by_rank) + + @staticmethod + def engine( + *, + engine_ranks: tuple[int, ...], + dist_init_addr: str, + server_processes: tuple[_ServerProcess, ...], + ) -> _Engine: + return _Engine( + engine_ranks=engine_ranks, + dist_init_addr=dist_init_addr, + server_processes=server_processes, + ) + + @staticmethod + def server_process( + *, + worker_rank: int, + placement_group_bundle_idxs: tuple[int, ...], + accepts_rollout_requests: bool = True, + node_rank: int = 0, + nnodes: int = 1, + ) -> _ServerProcess: + return _ServerProcess( + worker_rank=worker_rank, + placement_group_bundle_idxs=placement_group_bundle_idxs, + accepts_rollout_requests=accepts_rollout_requests, + node_rank=node_rank, + nnodes=nnodes, + ) + + def server_launch_specs(self) -> tuple[ServerLaunchSpec, ...]: + return tuple( + ServerLaunchSpec( + worker_rank=server.worker_rank, + placement_group_bundle_idxs=server.placement_group_bundle_idxs, + dist_init_addr=engine.dist_init_addr, + engine_rank=engine.engine_ranks.index(server.worker_rank), + node_rank=server.node_rank, + nnodes=server.nnodes, + ) + for engine in self.engines + for server in engine.server_processes + ) + + def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: + return tuple(dict.fromkeys(self._lifecycle_group_by_rank.values())) + + def is_request_entrypoint_rank(self, rank: int) -> bool: + server = self._server_process_by_rank.get(rank) + return server is not None and server.accepts_rollout_requests + + def lifecycle_group_for_server_rank(self, rank: int) -> tuple[int, ...]: + try: + return self._lifecycle_group_by_rank[rank] + except KeyError: + raise KeyError(f"rank={rank} does not own a rollout server process.") from None diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index 047fa2d5a..89dd125c7 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -1,6 +1,6 @@ import base64 import os -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Mapping, Union import numpy as np import ray @@ -11,13 +11,8 @@ from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.utils import XTUNER_DETERMINISTIC -from .worker import ( - EngineLaunchSpec, - EngineLaunchSpecs, - RolloutConfig, - RolloutWorker, - ServerProcessSpec, -) +from .rollout_topology import RolloutTopology +from .worker import RolloutConfig, RolloutWorker class SGLangWorker(RolloutWorker): @@ -49,53 +44,14 @@ def __init__( self.enable_return_routed_experts = self.config.enable_return_routed_experts @classmethod - def build_engine_launch_specs( + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build SGLang server launch layout. - - SGLang starts one server per node in a logical engine. Only node 0 is - used as the rollout request entrypoint. - - Example with expert_parallel_size=16 and gpus_per_node=8: - rank_bundle_idx_list is: - [(0, 0), (1, 1), ..., (15, 15)] - - If rank_to_dist_init_addr is: - {0: "addr0", 1: "addr1", ..., 15: "addr15"} - - The launch spec is: - EngineLaunchSpec( - engine_ranks=(0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15), - server_processes=( - ServerProcessSpec( - worker_rank=0, - placement_group_bundle_idxs=(0, 1, 2, 3, 4, 5, 6, 7), - dist_init_addr="addr0", - accepts_rollout_requests=True, - node_rank=0, - nnodes=2, - ), - ServerProcessSpec( - worker_rank=8, - placement_group_bundle_idxs=(8, 9, 10, 11, 12, 13, 14, 15), - dist_init_addr="addr0", - accepts_rollout_requests=False, - node_rank=1, - nnodes=2, - ), - ), - ) - - SGLang starts one server per node, so server_worker_ranks is (0, 8). - Only the node-0 server accepts rollout requests. - """ + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: num_workers = len(rank_bundle_idx_list) - num_gpus_per_engine = cls._get_num_gpus_per_engine(config) + num_gpus_per_engine = config.num_gpus_per_engine if num_workers % num_gpus_per_engine != 0: raise ValueError( f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." @@ -106,7 +62,7 @@ def build_engine_launch_specs( ) nnodes = max(1, num_gpus_per_engine // config.gpus_per_node) - engine_launch_specs: list[EngineLaunchSpec] = [] + engines = [] for engine_start in range(0, num_workers, num_gpus_per_engine): engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] engine_ranks = tuple(rank for rank, _ in engine_meta) @@ -115,30 +71,30 @@ def build_engine_launch_specs( # first rank of each node owns that node's bundles, while only node # 0 is exposed as the rollout request entrypoint. server_ranks = engine_ranks[:: config.gpus_per_node] - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[server_ranks[0]] - server_processes: list[ServerProcessSpec] = [] + dist_init_addr_owner_rank = server_ranks[0] + server_processes = [] for node_rank, server_rank in enumerate(server_ranks): node_bundle_start = node_rank * config.gpus_per_node node_bundle_end = node_bundle_start + config.gpus_per_node server_processes.append( - ServerProcessSpec( + RolloutTopology.server_process( worker_rank=server_rank, placement_group_bundle_idxs=engine_bundle_idxs[node_bundle_start:node_bundle_end], - dist_init_addr=engine_dist_init_addr, accepts_rollout_requests=node_rank == 0, node_rank=node_rank, nnodes=nnodes, ) ) - engine_launch_specs.append( - EngineLaunchSpec( + engines.append( + RolloutTopology.engine( engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], server_processes=tuple(server_processes), ) ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), + return RolloutTopology( + engines=tuple(engines), + training_engine_mesh=tuple(tuple(engine.engine_ranks) for engine in engines), ) def _get_request_payload(self, rollout_state: RolloutState) -> dict: @@ -325,7 +281,7 @@ def _transform_rollout_config_to_server_configs(self): os.environ.pop("CUDA_VISIBLE_DEVICES", None) from sglang.srt.server_args import ServerArgs - extra_config = self.config.extra_rollout_config or dict() + extra_config = self.config.extra_rollout_config sglang_config_kwargs = { k.replace("sglang_", ""): v for k, v in extra_config.items() if k.startswith("sglang_") } @@ -338,13 +294,7 @@ def _transform_rollout_config_to_server_configs(self): ) tp_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.tensor_parallel_size ep_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.expert_parallel_size - server_process_spec = self._get_current_server_process_spec() - nnodes = ( - server_process_spec.nnodes - if server_process_spec is not None - else max(1, num_gpus_per_engine // self.config.gpus_per_node) - ) - node_rank = server_process_spec.node_rank if server_process_spec is not None else 0 + assert self.server_launch_spec is not None assigned_gpu_id = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) # SGLang 0.5.10 默认启用的 Piecewise CUDA Graph 在启动 warmup compile 阶段会报错。sglang的文档提到这个功能还是实验功能,可能还不太稳定(https://sgl-project-sglang-93.mintlify.app/optimization/cuda-graph#bug-report)。暂时先通过disable_piecewise_cuda_graph=True关掉改功能 @@ -354,11 +304,11 @@ def _transform_rollout_config_to_server_configs(self): host=self.host, port=self.server_port, nccl_port=self.nccl_port, - dist_init_addr=self.dist_init_addr, + dist_init_addr=self.server_launch_spec.dist_init_addr, base_gpu_id=assigned_gpu_id, gpu_id_step=1, - nnodes=nnodes, - node_rank=node_rank, + nnodes=self.server_launch_spec.nnodes, + node_rank=self.server_launch_spec.node_rank, skip_server_warmup=True, mem_fraction_static=self.config.gpu_memory_utilization, enable_memory_saver=True, diff --git a/xtuner/v1/rl/rollout/vllm.py b/xtuner/v1/rl/rollout/vllm.py index 8cbeaef69..a999baf98 100644 --- a/xtuner/v1/rl/rollout/vllm.py +++ b/xtuner/v1/rl/rollout/vllm.py @@ -2,7 +2,7 @@ import os import traceback from argparse import Namespace -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Mapping, Union import numpy as np import ray @@ -16,6 +16,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_status_from_finish_reason from xtuner.v1.utils.device import get_device, get_torch_device_module +from .rollout_topology import RolloutTopology from .worker import RolloutConfig, RolloutWorker @@ -131,6 +132,15 @@ def run_lmdeploy_server_wrapper(server_namespace: Namespace): class vLLMWorker(RolloutWorker): + @classmethod + def build_rollout_topology( + cls, + config: RolloutConfig, + rank_bundle_idx_list: list[tuple[int, int]], + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + raise NotImplementedError("vLLM rollout topology has not been verified after topology refactor.") + def __init__( self, config: RolloutConfig, @@ -323,13 +333,14 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: args["limit_mm_per_prompt"] = {"image": 10, "video": 0} args["enable_log_requests"] = False args["uvicorn_log_level"] = "error" + assert self.server_launch_spec is not None env = { "VLLM_VERSION": "0.11.0", "TASK_QUEUE_ENABLE": "0", "CPU_AFFINITY_CONF": "2", "VLLM_USE_V1": "1", "VLLM_RAY_PER_WORKER_GPUS": "0.1", - "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, self.engine_bundle_idxs)), + "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, self.server_launch_spec.placement_group_bundle_idxs)), "VLLM_MONITOR": "1", "VLLM_ACCU_MONITOR": "0", "CUSTOM_SCHEDULE_KV_LIMIT": "0.9", diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index c8c0ce246..4a9800282 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -9,7 +9,7 @@ from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, TypeAlias, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, Optional, Union, cast import httpx import ray @@ -40,6 +40,7 @@ from .constants import ROLLOUT_HTTP_MAX_CONNECTIONS, ROLLOUT_RAY_GENERATE_MAX_CONCURRENCY from .health_manager import ROLLOUT_RAY_GET_TIMEOUT +from .rollout_topology import RolloutTopology, ServerLaunchSpec from .session_server import SessionServerActor from .utils import PartialRolloutHandler @@ -53,56 +54,12 @@ @dataclass(frozen=True) -class ServerProcessSpec: - """How to start one rollout server process.""" - - # Worker rank that owns this server process. - worker_rank: int - # Placement-group bundle indexes assigned to this server process. - placement_group_bundle_idxs: tuple[int, ...] - # Distributed init address used by every server process in the same engine. - # Filled after init_dist_port initializes worker-local ports. - dist_init_addr: str | None = None - # Whether this server is exposed as a rollout request entrypoint. Some - # backends launch extra server processes that must participate in - # lifecycle/health operations but must not be added to worker_server_urls_map - # or receive normal rollout traffic. - accepts_rollout_requests: bool = True - # Node index of this server inside a multi-node logical engine. - node_rank: int = 0 - # Number of nodes used by this logical engine. - nnodes: int = 1 +class RolloutWorkerInitResult: + """Result returned by RolloutWorker.init() after its server starts.""" - -@dataclass(frozen=True) -class EngineLaunchSpec: - """How to launch rollout servers for one logical inference engine.""" - - # All worker ranks that form this logical inference engine. - engine_ranks: tuple[int, ...] - # Server processes required by this engine. - server_processes: tuple[ServerProcessSpec, ...] - - @property - def server_worker_ranks(self) -> tuple[int, ...]: - return tuple(server.worker_rank for server in self.server_processes) - - @property - def request_entrypoint_servers(self) -> tuple[ServerProcessSpec, ...]: - return tuple(server for server in self.server_processes if server.accepts_rollout_requests) - - @property - def request_entrypoint_worker_ranks(self) -> tuple[int, ...]: - return tuple(server.worker_rank for server in self.request_entrypoint_servers) - - @property - def placement_group_bundle_idxs(self) -> tuple[int, ...]: - return tuple( - bundle_idx for server in self.server_processes for bundle_idx in server.placement_group_bundle_idxs - ) - - -EngineLaunchSpecs: TypeAlias = tuple[EngineLaunchSpec, ...] + rank: int + server_url: str + session_url: str | None def get_rollout_worker_base_cls(config: "RolloutConfig") -> type["RolloutWorker"]: @@ -555,8 +512,7 @@ def __init__( self.accelerator = accelerator self.server_func: Callable self.endpoints: dict[str, str] = dict() - self.engine_rank_mesh_array: list[list[int]] = [] - self.engine_launch_spec: EngineLaunchSpec | None = None + self.server_launch_spec: ServerLaunchSpec | None = None # Keep this deliberately large so requests do not queue in the # RolloutWorker/httpx client; the inference engine owns rollout request # scheduling and queueing. @@ -564,7 +520,6 @@ def __init__( limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) self.server_task = None - self.engine_bundle_idxs: list[int] = [] self.server_process: Optional[multiprocessing.Process] = None self.session_server_actor: Any | None = None self.session_server_url: str | None = None @@ -578,205 +533,48 @@ def __init__( self.logger.info(f"Using eos_token: {eos_token} for model at {self.config.model_path}") self.eos_token: List[int] = [eos_token] if isinstance(eos_token, int) else eos_token self.receive_abort_request = threading.Event() - self.dist_init_addr: str = "" self.serverl_url: str = "" self.partial_rollout_handler = PartialRolloutHandler() self.enable_partial_rollout: bool = False - @staticmethod - def _get_num_gpus_per_engine(config: RolloutConfig) -> int: - return config.num_gpus_per_engine - @classmethod - def validate_engine_launch_specs( - cls, - engine_launch_specs: EngineLaunchSpecs, - *, - known_worker_ranks: tuple[int, ...] | None = None, - ) -> EngineLaunchSpecs: - """Validate backend launch layout before the controller launches - servers.""" - if not engine_launch_specs: - raise ValueError("engine_launch_specs must define at least one engine.") - - known_worker_rank_set = set(known_worker_ranks) if known_worker_ranks is not None else None - seen_engine_ranks: set[int] = set() - seen_server_ranks: set[int] = set() - seen_bundle_idxs: set[int] = set() - for engine_index, engine_spec in enumerate(engine_launch_specs): - if not engine_spec.engine_ranks: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one engine rank.") - engine_rank_set = set(engine_spec.engine_ranks) - if len(engine_rank_set) != len(engine_spec.engine_ranks): - raise ValueError( - f"EngineLaunchSpec[{engine_index}] has duplicate engine ranks: {engine_spec.engine_ranks}." - ) - if known_worker_rank_set is not None: - unknown_engine_ranks = sorted( - rank for rank in engine_spec.engine_ranks if rank not in known_worker_rank_set - ) - if unknown_engine_ranks: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] references unknown engine ranks: {unknown_engine_ranks}." - ) - duplicated_engine_ranks = sorted(rank for rank in engine_spec.engine_ranks if rank in seen_engine_ranks) - if duplicated_engine_ranks: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] engine ranks appear in more than one engine: " - f"{duplicated_engine_ranks}." - ) - seen_engine_ranks.update(engine_spec.engine_ranks) - - if not engine_spec.server_processes: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one server process.") - - for server_process in engine_spec.server_processes: - server_rank = server_process.worker_rank - if server_rank not in engine_rank_set: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] server worker_rank={server_rank} " - f"must be part of engine_ranks={engine_spec.engine_ranks}." - ) - if server_rank in seen_server_ranks: - raise ValueError(f"Server worker_rank={server_rank} appears in more than one server process.") - seen_server_ranks.add(server_rank) - - if not server_process.placement_group_bundle_idxs: - raise ValueError(f"Server worker_rank={server_rank} must own at least one placement-group bundle.") - if len(set(server_process.placement_group_bundle_idxs)) != len( - server_process.placement_group_bundle_idxs - ): - raise ValueError( - f"Server worker_rank={server_rank} has duplicate placement-group bundles: " - f"{server_process.placement_group_bundle_idxs}." - ) - duplicated_bundle_idxs = sorted( - bundle_idx - for bundle_idx in server_process.placement_group_bundle_idxs - if bundle_idx in seen_bundle_idxs - ) - if duplicated_bundle_idxs: - raise ValueError( - f"Placement-group bundles are assigned to multiple server processes: {duplicated_bundle_idxs}." - ) - seen_bundle_idxs.update(server_process.placement_group_bundle_idxs) - - if server_process.nnodes < 1: - raise ValueError(f"Server worker_rank={server_rank} must have nnodes >= 1.") - if server_process.node_rank < 0 or server_process.node_rank >= server_process.nnodes: - raise ValueError( - f"Server worker_rank={server_rank} has invalid node_rank={server_process.node_rank} " - f"for nnodes={server_process.nnodes}." - ) - - if not engine_spec.request_entrypoint_servers: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must expose at least one request entrypoint.") - - if known_worker_rank_set is not None: - missing_engine_ranks = sorted(known_worker_rank_set - seen_engine_ranks) - if missing_engine_ranks: - raise ValueError( - f"EngineLaunchSpecs do not cover known worker ranks in engine_ranks: {missing_engine_ranks}." - ) - - return engine_launch_specs - - @classmethod - def build_engine_launch_specs( + @abstractmethod + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build default launch spec: one request-serving server per engine.""" - num_gpus_per_engine = cls._get_num_gpus_per_engine(config) - num_workers = len(rank_bundle_idx_list) - if num_workers % num_gpus_per_engine != 0: - raise ValueError( - f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." - ) - - engine_launch_specs: list[EngineLaunchSpec] = [] - for engine_start in range(0, num_workers, num_gpus_per_engine): - engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] - engine_ranks = tuple(rank for rank, _ in engine_meta) - engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] - engine_launch_specs.append( - EngineLaunchSpec( - engine_ranks=engine_ranks, - server_processes=( - ServerProcessSpec( - worker_rank=engine_ranks[0], - placement_group_bundle_idxs=engine_bundle_idxs, - dist_init_addr=engine_dist_init_addr, - ), - ), - ) - ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), - ) - - @classmethod - def build_metadata_engine_rank_mesh_array( - cls, - engine_launch_specs: EngineLaunchSpecs, - ) -> list[list[int]]: - """Build the public engine mesh returned in rollout metadata. - - By default, the public metadata mesh matches the logical engine topology. Backends with multiple request - servers per logical engine can override this to preserve their legacy update-weight mesh semantics. - """ - return [list(engine_spec.engine_ranks) for engine_spec in engine_launch_specs] - - def _get_current_server_process_spec( - self, - engine_launch_spec: EngineLaunchSpec | None = None, - ) -> ServerProcessSpec | None: - engine_launch_spec = engine_launch_spec or self.engine_launch_spec - if engine_launch_spec is None: - return None - - for server_process_spec in engine_launch_spec.server_processes: - if server_process_spec.worker_rank == self.rank: - return server_process_spec - raise RuntimeError( - f"Engine launch spec does not include rollout worker rank={self.rank} " - f"in server_worker_ranks={engine_launch_spec.server_worker_ranks}." - ) + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + raise NotImplementedError("Concrete rollout worker classes must implement build_rollout_topology().") def set_enable_partial_rollout(self, enable: bool) -> None: self.enable_partial_rollout = enable + def bind_server_launch_spec(self, server_launch_spec: ServerLaunchSpec) -> None: + if server_launch_spec.worker_rank != self.rank: + raise ValueError( + f"Server launch spec rank={server_launch_spec.worker_rank} does not match worker rank={self.rank}." + ) + self.server_launch_spec = server_launch_spec + def init( self, - *, - engine_launch_spec: EngineLaunchSpec | None = None, - ) -> tuple[int, str]: + ) -> RolloutWorkerInitResult: """Initialize the worker and launch the server. Returns: - Tuple[int, str]: A tuple containing the worker's rank and its - server URL. + Startup result containing rank, server URL, and session URL. """ - if engine_launch_spec is not None: - # Initial controller startup passes the immutable launch spec and caches - # it on the actor. Recovery calls init() without arguments after - # shutdown, intentionally reusing this cached placement/dist layout. - self.engine_launch_spec = engine_launch_spec - server_process_spec = cast( - ServerProcessSpec, - self._get_current_server_process_spec(engine_launch_spec), - ) - self.engine_bundle_idxs = list(server_process_spec.placement_group_bundle_idxs) - if server_process_spec.dist_init_addr is not None: - self.dist_init_addr = server_process_spec.dist_init_addr + if self.server_launch_spec is None: + raise RuntimeError("RolloutWorker.bind_server_launch_spec() must be called before init().") self.receive_abort_request.clear() self._launch_server() self._start_session_server() - return (self.rank, self.server_url) + return RolloutWorkerInitResult( + rank=self.rank, + server_url=self.server_url, + session_url=self.session_server_url, + ) def set_skip_load_weights(self, skip_load_weights: bool) -> None: self.config = self.config.model_copy(update={"skip_load_weights": skip_load_weights}) @@ -784,7 +582,7 @@ def set_skip_load_weights(self, skip_load_weights: bool) -> None: def restore_skip_load_weights(self) -> None: self.config = self.config.model_copy(update={"skip_load_weights": self._default_skip_load_weights}) - def init_dist_port(self) -> str: + def init_dist_port(self) -> tuple[int, str]: """Initialize distributed communication ports. This method initializes four fixed ports for the distributed setup: @@ -792,7 +590,7 @@ def init_dist_port(self) -> str: for NCCL, and one for the session server. Returns: - str: The distributed initialization address (host:port). + Worker rank and distributed initialization address (host:port). """ local_rank = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) base_port = self.config.dist_port_base + local_rank * 4 @@ -801,9 +599,9 @@ def init_dist_port(self) -> str: self.server_port = base_port + 1 self.nccl_port = base_port + 2 self.session_server_port = base_port + 3 - self.dist_init_addr = f"{self.host}:{self.dist_port}" + dist_init_addr = f"{self.host}:{self.dist_port}" self.server_url = f"http://{self.host}:{self.server_port}" - return self.dist_init_addr + return self.rank, dist_init_addr def shutdown(self, *, stop_session_server: bool = False): """Shut down the worker, its server task, and any child processes.""" @@ -849,11 +647,12 @@ def _start_session_server(self) -> None: if self.session_server_actor is not None: return + assert self.server_launch_spec is not None current_pg = ray.util.get_current_placement_group() scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=current_pg, placement_group_capture_child_tasks=False, - placement_group_bundle_index=self.engine_bundle_idxs[0], + placement_group_bundle_index=self.server_launch_spec.placement_group_bundle_idxs[0], ) self.session_server_actor = ( ray.remote(SessionServerActor) @@ -883,9 +682,6 @@ def _stop_session_server(self) -> None: self.session_server_actor = None self.session_server_url = None - def get_session_server_info(self) -> tuple[int, str | None]: - return self.rank, self.session_server_url - async def pause_generation(self): """Pause the worker's generation process.""" self.receive_abort_request.set() @@ -1194,11 +990,12 @@ def _launch_server(self): else: # launch the server as ray task # so that the lmdeploy backend could get externl pg + assert self.server_launch_spec is not None current_pg = ray.util.get_current_placement_group() scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=current_pg, placement_group_capture_child_tasks=True, - placement_group_bundle_index=self.engine_bundle_idxs[0], + placement_group_bundle_index=self.server_launch_spec.placement_group_bundle_idxs[0], ) assert ray.is_initialized() ray_kwargs = ( diff --git a/xtuner/v1/rl/rollout/worker_registry.py b/xtuner/v1/rl/rollout/worker_registry.py index 4770dd59b..25f6a785d 100644 --- a/xtuner/v1/rl/rollout/worker_registry.py +++ b/xtuner/v1/rl/rollout/worker_registry.py @@ -3,13 +3,15 @@ import threading from dataclasses import dataclass, replace from enum import Enum -from typing import TYPE_CHECKING, Iterable, TypedDict +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .rollout_topology import RolloutTopology from .worker import RolloutConfig, RolloutWorker __all__ = [ + "RolloutWorkerEndpointMetadata", "RolloutWorkerMetadata", "RolloutWorkerRegistry", "WorkerGroup", @@ -31,20 +33,12 @@ class WorkerLifecycleState(str, Enum): class WorkerSnapshot: """Read-only snapshot for one rollout server process.""" + rank: int actor: RolloutWorker url: str session_url: str | None = None - lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE - lifecycle_group_ranks: tuple[int, ...] = () is_request_entrypoint: bool = True - rank: int = -1 - - def __post_init__(self) -> None: - lifecycle_state = ( - WorkerLifecycleState.ACTIVE if self.lifecycle_state is None else WorkerLifecycleState(self.lifecycle_state) - ) - object.__setattr__(self, "lifecycle_state", lifecycle_state) - object.__setattr__(self, "lifecycle_group_ranks", tuple(self.lifecycle_group_ranks)) + lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE def is_active(self) -> bool: return self.lifecycle_state is WorkerLifecycleState.ACTIVE @@ -56,31 +50,49 @@ class WorkerGroup: workers: tuple[WorkerSnapshot, ...] -class RolloutWorkerMetadata(TypedDict): - """Legacy rollout worker metadata consumed by trainer/update-weight - code.""" +@dataclass(frozen=True) +class RolloutWorkerEndpointMetadata: + """URL and lifecycle state for one request-serving rollout endpoint.""" - engine_rank_mesh_array: list[list[int]] - server_url_dict: dict[int, str] - rollout_config: RolloutConfig - worker_server_urls_status: dict[str, bool] - worker_session_url_dict: dict[int, str] - worker_session_urls_status: dict[str, bool] + rank: int + server_url: str + session_url: str | None + lifecycle_state: WorkerLifecycleState + + @property + def is_active(self) -> bool: + return self.lifecycle_state is WorkerLifecycleState.ACTIVE -def _build_worker_groups(workers: Iterable[WorkerSnapshot]) -> dict[tuple[int, ...], WorkerGroup]: - grouped_workers: dict[tuple[int, ...], list[WorkerSnapshot]] = {} - for worker in workers: - group_ranks = worker.lifecycle_group_ranks or (worker.rank,) - grouped_workers.setdefault(group_ranks, []).append(worker) +@dataclass(frozen=True) +class RolloutWorkerMetadata: + """Structured rollout worker metadata consumed by trainer/update-weight + code.""" - return { - group_ranks: WorkerGroup( - ranks=group_ranks, - workers=tuple(sorted(group_workers, key=lambda worker: worker.rank)), - ) - for group_ranks, group_workers in grouped_workers.items() - } + rollout_config: RolloutConfig + training_engine_mesh: tuple[tuple[int, ...], ...] + request_endpoints: tuple[RolloutWorkerEndpointMetadata, ...] + + def to_legacy(self) -> dict[str, Any]: + """Serialize to the current trainer-facing rollout metadata dict.""" + return { + "engine_rank_mesh_array": [list(engine_ranks) for engine_ranks in self.training_engine_mesh], + "server_url_dict": {endpoint.rank: endpoint.server_url for endpoint in self.request_endpoints}, + "rollout_config": self.rollout_config, + "worker_server_urls_status": { + endpoint.server_url: endpoint.is_active for endpoint in self.request_endpoints + }, + "worker_session_url_dict": { + endpoint.rank: endpoint.session_url + for endpoint in self.request_endpoints + if endpoint.session_url is not None + }, + "worker_session_urls_status": { + endpoint.session_url: endpoint.is_active + for endpoint in self.request_endpoints + if endpoint.session_url is not None + }, + } class RolloutWorkerRegistry: @@ -90,12 +102,11 @@ class RolloutWorkerRegistry: def __init__( self, *, - engine_rank_mesh_array: list[list[int]], + rollout_topology: RolloutTopology, rollout_config: RolloutConfig, ): - """Initialize an empty registry with the training-side metadata - projection.""" - self._engine_rank_mesh_array = [list(engine_ranks) for engine_ranks in engine_rank_mesh_array] + """Initialize an empty registry with the rollout topology.""" + self._rollout_topology = rollout_topology self._rollout_config = rollout_config self._workers: dict[int, WorkerSnapshot] = {} self._lock = threading.RLock() @@ -107,8 +118,7 @@ def register_started_server( actor: RolloutWorker, server_url: str, session_url: str | None = None, - lifecycle_group_ranks: tuple[int, ...] = (), - is_request_entrypoint: bool = True, + lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE, ) -> None: """Register one worker actor after its rollout server process has started.""" @@ -118,8 +128,8 @@ def register_started_server( actor=actor, url=server_url, session_url=session_url, - lifecycle_group_ranks=lifecycle_group_ranks or (rank,), - is_request_entrypoint=is_request_entrypoint, + is_request_entrypoint=self._rollout_topology.is_request_entrypoint_rank(rank), + lifecycle_state=lifecycle_state, ) def all_workers(self) -> tuple[WorkerSnapshot, ...]: @@ -162,11 +172,28 @@ def active_entrypoint_by_rank(self, rank: int) -> WorkerSnapshot | None: return None return worker + def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: + """Return registered lifecycle groups in rank order.""" + with self._lock: + return tuple(sorted(self._rollout_topology.lifecycle_groups())) + + def _build_worker_groups(self) -> dict[tuple[int, ...], WorkerGroup]: + grouped_ranks = { + self._rollout_topology.lifecycle_group_for_server_rank(worker.rank) for worker in self._workers.values() + } + return { + group_ranks: WorkerGroup( + ranks=group_ranks, + workers=tuple(self._workers[rank] for rank in group_ranks if rank in self._workers), + ) + for group_ranks in grouped_ranks + } + def claim_inactive_groups_for_recovery(self) -> tuple[WorkerGroup, ...]: """Claim non-active worker groups by moving them to recovering state.""" with self._lock: - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() inactive_groups = [ group for group in worker_groups.values() @@ -184,16 +211,14 @@ def mark_unhealthy_ranks(self, ranks: set[int]) -> tuple[WorkerGroup, ...]: """Mark every lifecycle group containing a failed rank as inactive.""" with self._lock: failed_group_ranks = { - worker.lifecycle_group_ranks or (worker.rank,) - for rank, worker in self._workers.items() - if rank in ranks + self._rollout_topology.lifecycle_group_for_server_rank(rank) for rank in ranks if rank in self._workers } for group_ranks in failed_group_ranks: for rank in group_ranks: worker = self._workers.get(rank) if worker is not None: self._workers[rank] = replace(worker, lifecycle_state=WorkerLifecycleState.INACTIVE) - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() return tuple( worker_groups[group_ranks] for group_ranks in sorted(failed_group_ranks) @@ -214,29 +239,24 @@ def set_group_recovery_result( worker = self._workers.get(rank) if worker is not None: self._workers[rank] = replace(worker, lifecycle_state=lifecycle_state) - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() return worker_groups.get(group.ranks) - def training_metadata_snapshot(self) -> RolloutWorkerMetadata: - """Build the legacy trainer/update-weight metadata from one registry - snapshot.""" + def metadata(self) -> RolloutWorkerMetadata: + """Build trainer/update-weight metadata from one registry snapshot.""" with self._lock: - request_entrypoints = {rank: info for rank, info in self._workers.items() if info.is_request_entrypoint} - worker_server_urls_map = {rank: info.url for rank, info in request_entrypoints.items()} - worker_server_urls_status = {info.url: info.is_active() for info in request_entrypoints.values()} - worker_session_url_dict: dict[int, str] = {} - worker_session_urls_status: dict[str, bool] = {} - for rank, info in request_entrypoints.items(): - if info.session_url is None: - continue - worker_session_url_dict[rank] = info.session_url - worker_session_urls_status[info.session_url] = info.is_active() - - return { - "engine_rank_mesh_array": [list(engine_ranks) for engine_ranks in self._engine_rank_mesh_array], - "server_url_dict": worker_server_urls_map, - "rollout_config": self._rollout_config, - "worker_server_urls_status": worker_server_urls_status, - "worker_session_url_dict": worker_session_url_dict, - "worker_session_urls_status": worker_session_urls_status, - } + request_endpoints = tuple( + RolloutWorkerEndpointMetadata( + rank=worker.rank, + server_url=worker.url, + session_url=worker.session_url, + lifecycle_state=worker.lifecycle_state, + ) + for worker in self.all_workers() + if worker.is_request_entrypoint + ) + return RolloutWorkerMetadata( + rollout_config=self._rollout_config, + training_engine_mesh=self._rollout_topology.training_engine_mesh, + request_endpoints=request_endpoints, + )