Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 177 additions & 34 deletions tests/rl/test_rollout_logic.py

Large diffs are not rendered by default.

133 changes: 58 additions & 75 deletions xtuner/v1/rl/rollout/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
RolloutConfig,
get_rollout_worker_base_cls,
)
from .worker_registry import RolloutWorkerMetadata, RolloutWorkerRegistry
from .worker_registry import RolloutWorkerRegistry


# Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller
Expand Down Expand Up @@ -61,19 +61,15 @@ def __init__(
registry=self.registry,
worker_lifecycle_listeners=[self.proxy_manager] if self.proxy_manager is not None else None,
)
self.health_manager.start()
self.health_manager.start_background_checks()

def get_rollout_metadata(self) -> RolloutWorkerMetadata:
def get_rollout_metadata(self) -> dict:
"""Get information about the current rollout setup.

Returns:
dict: A dictionary containing the engine mesh list, server URL
dictionary, and the rollout configuration.
Legacy trainer/update-weight rollout metadata dictionary.
"""
rollout_metadata = self.registry.training_metadata_snapshot()
self.logger.info(f"Rollout worker server URLs: {rollout_metadata['server_url_dict']}")
self.logger.info(f"Rollout worker session server URLs: {rollout_metadata['worker_session_url_dict']}")
return rollout_metadata
return self.registry.metadata().to_legacy()

def register_active_workers_to_proxy(self) -> None:
if self.proxy_manager is None:
Expand Down Expand Up @@ -133,7 +129,7 @@ def set_enable_partial_rollout(self, enable: bool) -> None:
)

def pause_generation(self):
self.health_manager.pause()
self.health_manager.pause_background_checks()
active_workers = self.registry.active_workers()
futures = [
worker.actor.pause_generation.remote() # type: ignore[attr-defined]
Expand Down Expand Up @@ -164,7 +160,7 @@ async def restart_inactive_workers(self):

def continue_generation(self):
self._broadcast_to_active_workers("continue_generation")
self.health_manager.resume()
self.health_manager.resume_background_checks()

def offload(self):
self._broadcast_to_active_workers("offload")
Expand All @@ -181,7 +177,7 @@ def onload_kvcache(self):

def shutdown(self):
"""Shut down all rollout workers tracked by the controller."""
self.health_manager.stop()
self.health_manager.stop_background_checks()
actors = self.registry.all_actors()
ray.get(
[actor.shutdown.remote(stop_session_server=True) for actor in actors], # type: ignore[attr-defined]
Expand All @@ -203,13 +199,16 @@ def _build_remote_worker_cls(self, worker_base_cls):
},
)(worker_base_cls)

def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistry:
def _init_workers(
self,
placement_group: PlacementGroup,
) -> RolloutWorkerRegistry:
"""Initializes and configures the pool of RolloutWorker actors.

This method follows the same high-level flow as the legacy implementation:
create workers, initialize worker-local ports, build engine groups,
select workers that launch rollout servers, launch servers, and
expose request-entrypoint server URLs to rollout traffic.
create workers, initialize worker-local ports, build the bound rollout
topology, launch rollout servers, and expose request-entrypoint server
URLs to rollout traffic.

Returns:
A registry containing all server-process workers and the public
Expand All @@ -222,79 +221,63 @@ def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistr
workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
worker_cls, self.config, placement_group
)
rank_to_actor = {rank: worker for (rank, _), worker in zip(rank_bundle_idx_list, workers)}

# Reserve worker-local ports for all actors first. build_engine_launch_specs
# uses the returned addresses to bind each ServerProcessSpec to its
# logical engine rendezvous address; only server-process owners call init().
rank_to_dist_init_addr = {
rank: dist_init_addr
for (rank, _), dist_init_addr in zip(
rank_bundle_idx_list,
ray.get([worker.init_dist_port.remote() for worker in workers]), # type: ignore[attr-defined]
)
dist_init_results = ray.get(
[
worker.init_dist_port.remote() # type: ignore[attr-defined]
for worker in workers
]
)
rank_to_worker = {
rank: worker for worker, (rank, _dist_init_addr) in zip(workers, dist_init_results, strict=True)
}
rank_to_dist_init_addr = dict(dist_init_results)

# Build engine groups and server-process specs from the rank/bundle mapping.
engine_launch_specs = worker_base_cls.build_engine_launch_specs(
rollout_topology = worker_base_cls.build_rollout_topology(
self.config,
rank_bundle_idx_list,
rank_to_dist_init_addr,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit — `rank_to_worker` is built by iterating `zip(workers, dist_init_results, strict=True)`. This correctly pairs each worker actor with the `(rank, dist_init_addr)` tuple it returned. Compared to the old code, which relied on index-aligned ordering between `workers` and `rank_bundle_idx_list`, the new approach is safer because it uses the rank returned by the worker itself.

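However, the variable name `_dist_init_addr` in the comprehension could be clearer — at first glance it looks like a bug (why is the address discarded?). It is unused only inside this comprehension; the addresses themselves are still collected by `rank_to_dist_init_addr = dict(dist_init_results)` on the following line. A brief comment would help: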

Suggested change
)
rank_to_worker = {
rank: worker
for worker, (rank, _dist_init_addr) in zip(workers, dist_init_results, strict=True) # rank from worker
}

# Keep the public metadata mesh compatible with origin/main. Backends
# may expose a different update-weight mesh than their internal launch
# topology, e.g. LMDeploy EP has one logical engine but one public entry
# per request-serving EP rank.
engine_rank_mesh_array = worker_base_cls.build_metadata_engine_rank_mesh_array(engine_launch_specs)

# Launch every server process described by the backend-specific specs.
server_rank_to_url = dict(
ray.get(
[
rank_to_actor[server_process.worker_rank].init.remote( # type: ignore[attr-defined]
engine_launch_spec=engine_spec,
)
for engine_spec in engine_launch_specs
for server_process in engine_spec.server_processes
]
)
server_launch_specs = rollout_topology.server_launch_specs()
server_workers = tuple(
(launch_spec, rank_to_worker[launch_spec.worker_rank]) for launch_spec in server_launch_specs
)
session_url_by_rank = dict(

ray.get(
[
worker.bind_server_launch_spec.remote(launch_spec) # type: ignore[attr-defined]
for launch_spec, worker in server_workers
]
)
init_results = tuple(
ray.get(
[
(
rank_to_actor[server_process.worker_rank].get_session_server_info.remote() # type: ignore[attr-defined]
)
for engine_spec in engine_launch_specs
for server_process in engine_spec.server_processes
worker.init.remote() # type: ignore[attr-defined]
for _launch_spec, worker in server_workers
]
)
)

registry = RolloutWorkerRegistry(
engine_rank_mesh_array=engine_rank_mesh_array,
rollout_config=self.config,
)
for engine_spec in engine_launch_specs:
for server_process in engine_spec.server_processes:
rank = server_process.worker_rank
url = server_rank_to_url[rank]
session_url = session_url_by_rank.get(rank)
if server_process.accepts_rollout_requests and session_url is None:
raise RuntimeError(f"Rollout worker rank={rank} did not return session server URL during init.")
registry.register_started_server(
rank=rank,
actor=rank_to_actor[rank],
server_url=url,
session_url=session_url,
lifecycle_group_ranks=engine_spec.server_worker_ranks,
is_request_entrypoint=server_process.accepts_rollout_requests,
registry = RolloutWorkerRegistry(rollout_topology=rollout_topology, rollout_config=self.config)
for init_result in init_results:
if rollout_topology.is_request_entrypoint_rank(init_result.rank) and init_result.session_url is None:
raise RuntimeError(
f"Rollout worker rank={init_result.rank} did not return session server URL during init."
)
registry.register_started_server(
rank=init_result.rank,
actor=rank_to_worker[init_result.rank],
server_url=init_result.server_url,
session_url=init_result.session_url,
)
Comment thread
YanhuiDua marked this conversation as resolved.

server_process_workers_info = registry.all_workers()
self.logger.info(f"Rollout server-process worker URLs: {[info.url for info in server_process_workers_info]}")
lifecycle_groups = sorted({info.lifecycle_group_ranks for info in server_process_workers_info})
self.logger.info(f"Rollout worker lifecycle groups: {lifecycle_groups}")
rollout_metadata = registry.metadata()
legacy_metadata = rollout_metadata.to_legacy()
self.logger.info(
"Rollout worker registry snapshot: "
f"server_urls={legacy_metadata['server_url_dict']}, "
f"session_urls={legacy_metadata['worker_session_url_dict']}, "
f"server_process_urls={[worker.url for worker in registry.all_workers()]}, "
f"lifecycle_groups={registry.lifecycle_groups()}"
)
return registry


Expand Down
50 changes: 25 additions & 25 deletions xtuner/v1/rl/rollout/health_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self._worker_health_failure_counts: dict[int, int] = {}
self._stopped = False

def start(self) -> None:
def start_background_checks(self) -> None:
health_thread_alive = self._thread is not None and self._thread.is_alive()
if health_thread_alive:
return
Expand All @@ -73,11 +73,11 @@ def start(self) -> None:
self._stop_event = threading.Event()
self._pause_event = threading.Event()
self._pause_event.set()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread = threading.Thread(target=self._run_background_check_loop, daemon=True)
self._thread.start()
logger.info("RolloutHealthManager started.")
logger.info("RolloutHealthManager background checks started.")

def stop(self) -> None:
def stop_background_checks(self) -> None:
thread = self._thread
if not thread:
return
Expand All @@ -99,19 +99,19 @@ def stop(self) -> None:
self._thread = None
self._stop_event = None
self._pause_event = None
logger.info("RolloutHealthManager stopped.")
logger.info("RolloutHealthManager background checks stopped.")

def pause(self) -> None:
def pause_background_checks(self) -> None:
if self._pause_event is None:
return
self._pause_event.set()
logger.info("RolloutHealthManager paused.")
logger.info("RolloutHealthManager background checks paused.")

def resume(self) -> None:
def resume_background_checks(self) -> None:
if self._pause_event is None:
return
self._pause_event.clear()
logger.info("RolloutHealthManager resumed.")
logger.info("RolloutHealthManager background checks resumed.")

def _is_paused(self) -> bool:
return self._pause_event is None or self._pause_event.is_set()
Expand All @@ -121,20 +121,20 @@ def _is_stopping(self) -> bool:
return self._stopped or (self._stop_event is not None and self._stop_event.is_set())

@contextmanager
def _background_health_checks_paused(self):
def _background_checks_paused(self):
was_paused = self._is_paused()
if not was_paused:
self.pause()
self.pause_background_checks()
try:
yield
finally:
if not was_paused:
self.resume()
self.resume_background_checks()

def restart_inactive_workers(self) -> None:
"""Synchronously restart inactive groups before the next sync-step
weight update."""
with self._background_health_checks_paused():
with self._background_checks_paused():
with self._operation_lock:
failed_groups = list(self._registry.claim_inactive_groups_for_recovery())
if not failed_groups:
Expand Down Expand Up @@ -217,7 +217,7 @@ def check_and_shutdown_inactive_workers(self) -> None:
"""Fail-fast health-check active workers, mark failures inactive, and
shut down every non-active group so shared resources can be reused by
training."""
with self._background_health_checks_paused():
with self._background_checks_paused():
self._check_and_deactivate_failed_worker_groups(fail_fast=True)
with self._operation_lock:
inactive_groups = list(self._registry.claim_inactive_groups_for_recovery())
Expand Down Expand Up @@ -248,7 +248,7 @@ def check_and_shutdown_inactive_workers(self) -> None:
)
)

def run_once(self) -> None:
def run_periodic_health_check(self) -> None:
logger.debug("RolloutHealthManager running health checks for all workers.")
checked_active_count = self._check_and_deactivate_failed_worker_groups()
if self._registry.active_workers() or self._is_stopping():
Expand Down Expand Up @@ -367,9 +367,9 @@ async def check_workers(workers: list[WorkerSnapshot]) -> list[bool]:

return [keep_active_by_rank[worker.rank] for worker in workers_to_check]

def _run_loop(self) -> None:
def _run_background_check_loop(self) -> None:
assert self._stop_event is not None and self._pause_event is not None
logger.info("RolloutHealthManager loop started.")
logger.info("RolloutHealthManager background check loop started.")

while not self._stop_event.is_set():
while self._pause_event.is_set() and not self._stop_event.is_set():
Expand All @@ -385,13 +385,13 @@ def _run_loop(self) -> None:
continue

try:
self.run_once()
self.run_periodic_health_check()
except RuntimeError:
if self._is_stopping():
break
logger.exception("RolloutHealthManager run_once failed.")
logger.exception("RolloutHealthManager periodic health check failed.")
except Exception:
logger.exception("RolloutHealthManager run_once failed.")
logger.exception("RolloutHealthManager periodic health check failed.")

def _shutdown_worker_group(
self,
Expand Down Expand Up @@ -486,8 +486,8 @@ def _restart_worker_group(
)
init_results = ray.get(
[
# init() reuses the immutable launch spec cached on each actor
# during controller startup, including placement bundles and dist addr.
# init() reuses the server launch spec bound during
# controller startup.
worker.actor.init.remote() # type: ignore[attr-defined]
for worker in group.workers
],
Expand All @@ -505,11 +505,11 @@ def _restart_worker_group(
return False

for worker, init_result in zip(group.workers, init_results):
init_rank, init_url = init_result
if init_rank != worker.rank or init_url != worker.url:
if init_result.rank != worker.rank or init_result.server_url != worker.url:
logger.error(
f"Rollout worker restart returned unexpected endpoint: rank={worker.rank}, "
f"init_rank={init_rank}, expected_url={worker.url}, init_url={init_url}."
f"init_rank={init_result.rank}, expected_url={worker.url}, "
f"init_url={init_result.server_url}."
)
self._shutdown_worker_group(group, wait_server_down=False, best_effort=True)
return False
Expand Down
Loading
Loading