diff --git a/tests/rl/test_rl_trainer_checkpoint.py b/tests/rl/test_rl_trainer_checkpoint.py index d8750d545e..b3f2db0bc3 100644 --- a/tests/rl/test_rl_trainer_checkpoint.py +++ b/tests/rl/test_rl_trainer_checkpoint.py @@ -127,11 +127,17 @@ def __init__(self): self.update_weights_count = 0 self.rollout_info = None - def set_train_rollout_mode(self, mode: str): - self.train_rollout_mode = mode - - def update_rollout_info(self, info): + def update_rollout_info( + self, + info, + train_rollout_mode, + weight_update_host, + weight_update_port + ): self.rollout_info = info + self.train_rollout_mode = train_rollout_mode + self.weight_update_host = weight_update_host + self.weight_update_port = weight_update_port def onload(self, target="all"): return f"onload:{target}" diff --git a/tests/rl/test_update_weight_disaggregated.py b/tests/rl/test_update_weight_disaggregated.py index e5d0c4c342..b53c9121d4 100644 --- a/tests/rl/test_update_weight_disaggregated.py +++ b/tests/rl/test_update_weight_disaggregated.py @@ -1,23 +1,20 @@ import os -import hashlib -import sys import tempfile -import time import unittest -from pathlib import Path - import ray -import torch -import torch.distributed as dist +import requests -from xtuner.v1.data_proto.rl_data import SampleParams, RolloutState -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams +from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense4BConfig +from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import ( + TrainingController, + TrainingWorker as BaseTrainingWorker, + WorkerConfig, +) from xtuner.v1.rl.utils import ( AcceleratorResourcesConfig, AutoAcceleratorWorkers, @@ -25,93 +22,18 @@ clear_cpu_resource_manager, set_cpu_resource_manager, ) -from xtuner.v1.rl.trainer import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.utils import ray_method -import re TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] -MODEL_PATH = os.environ.get("QWEN3_MOE_PATH") - -def _is_sglang_update_weight_sha256_test_enabled(): - """Return whether the SGLang-side received-weight SHA256 check is enabled. - - This test-only switch controls whether the unit test expects SGLang to - compute and return received bucket hashes for sent/received hash - comparison. - - ! Note that upstream SGLang does not provide this SHA256 check - by default. - """ - return os.environ.get("SGLANG_ENABLE_UPDATE_WEIGHT_SHA256_TEST", "0") == "1" - -class HashingTrainingWorker(BaseTrainingWorker): - _RECEIVED_SHA256_PATTERN = re.compile(r"received_sha256=([0-9a-fA-F]{64})") - def _init_update_weighter(self): - super()._init_update_weighter() - self._test_update_weight_sent_sha256_list = [] - self._test_update_weight_received_sha256_list = [] - - @ray_method - def reset_update_weight_sha256(self): - self._test_update_weight_sent_sha256_list = [] - self._test_update_weight_received_sha256_list = [] - - @ray_method - def get_update_weight_sha256(self): - return { - "rank": self.rank, - "sent_sha256_list": self._test_update_weight_sent_sha256_list, - "received_sha256_list": self._test_update_weight_received_sha256_list, - "bucket_count": len(self._test_update_weight_sent_sha256_list), - } - - def request_update_params(self, state_dict, train_enable_ep=False, finished=False, profile_context=None): - if state_dict and dist.get_rank() == 0: - bucket_sha256 = hashlib.sha256() - for name, tensor in sorted(state_dict.items(), key=lambda x: x[0]): - tensor = tensor.detach().contiguous().cpu() - bucket_sha256.update(name.encode("utf-8")) - bucket_sha256.update(str(tensor.dtype).encode("utf-8")) - bucket_sha256.update(str(tuple(tensor.shape)).encode("utf-8")) - bucket_sha256.update(tensor.view(torch.uint8).numpy().tobytes()) - self._test_update_weight_sent_sha256_list.append(bucket_sha256.hexdigest()) - return super().request_update_params( - state_dict, - train_enable_ep=train_enable_ep, - finished=finished, - ) - - def _hook_compare_test_sent_and_received_weight_hash( - self, - result: dict, - *, - bucket_idx=None, - names=None, - ) -> None: - """Record the received bucket SHA256 returned by SGLang for test comparison. +MODEL_PATH = os.environ["QWEN3_VL_DENSE_PATH"] - This unit-test override parses the SGLang response message and stores the - received bucket hash so the test can compare training-side sent hashes with - rollout-side received hashes. - """ - if not _is_sglang_update_weight_sha256_test_enabled(): - return - if dist.get_rank() != 0: - return - message = result.get("message", "") - match = self._RECEIVED_SHA256_PATTERN.search(message) - if match is not None: - self._test_update_weight_received_sha256_list.append(match.group(1)) - -class TestUpdateWeight(unittest.TestCase): +class TestUpdateWeightDisaggregated(unittest.TestCase): @classmethod def setUpClass(cls) -> None: if MODEL_PATH is None: - raise unittest.SkipTest("QWEN3_MOE_PATH is not set") + raise unittest.SkipTest("MODEL_PATH is not set") os.environ["XTUNER_USE_FA3"] = "1" - # TODO(shipengcheng) 当前训推分离sglang的权重同步不能用NCCL_CUMEM,后面需要排查一下原因 + # TODO(shipengcheng): SGLang disaggregated weight update cannot use + # NCCL_CUMEM for now. Remove this after the root cause is fixed. os.environ["NCCL_CUMEM_ENABLE"] = "0" @classmethod @@ -124,9 +46,12 @@ def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") self.init_config() - self.pg = AutoAcceleratorWorkers.build_placement_group( - self.train_resources_cfg, - name=f"test_update_weight_train_{id(self)}", + self.train_pg = AutoAcceleratorWorkers.build_placement_group(self.train_resources_cfg, + name=f"test_update_weight_train_{id(self)}") + self.rollout_pg = AutoAcceleratorWorkers.build_placement_group(self.rollout_resources_cfg, + name=f"test_update_weight_rollout_{id(self)}") + set_cpu_resource_manager( + CPUResourceManager(accelerator_placement_groups=[self.train_pg, self.rollout_pg]) ) def tearDown(self): @@ -137,44 +62,40 @@ def tearDown(self): def init_config(self): train_num_workers = int(os.environ.get("TRAIN_NUM_WORKERS", "4")) rollout_num_workers = int(os.environ.get("ROLLOUT_NUM_WORKERS", "4")) - rollout_tp_size = int(os.environ.get("ROLLOUT_TP_SIZE", str(rollout_num_workers))) - rollout_ep_size = int(os.environ.get("ROLLOUT_EP_SIZE", "1")) - train_ep_size = int(os.environ.get("TRAIN_EP_SIZE", "1")) self.train_resources_cfg = AcceleratorResourcesConfig( accelerator="GPU", num_workers=train_num_workers, - num_cpus_per_worker=float(os.environ.get("TRAIN_CPUS_PER_WORKER", "12")), - cpu_memory_per_worker=int(os.environ.get("TRAIN_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, ) self.rollout_resources_cfg = AcceleratorResourcesConfig( accelerator="GPU", num_workers=rollout_num_workers, - num_cpus_per_worker=float(os.environ.get("ROLLOUT_CPUS_PER_WORKER", "12")), - cpu_memory_per_worker=int(os.environ.get("ROLLOUT_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, ) self.rollout_cfg = RolloutConfig( env="test_rollout", model_path=MODEL_PATH, model_name=os.path.basename(MODEL_PATH).lower(), tokenizer_path=MODEL_PATH, - rollout_cross_node_comm=os.environ.get("XTUNER_USE_SGLANG", "0") != "0", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpus_per_node=int(os.environ.get("GPUS_PER_NODE", "8")), # gpu: 8, npu: 16 + rollout_cross_node_comm=False, + tensor_parallel_size=int(os.environ.get("ROLLOUT_TP_SIZE", "4")), + expert_parallel_size=1, + gpus_per_node=int(os.environ.get("GPUS_PER_NODE", "8")), dtype="bfloat16", skip_load_weights=True, - context_length=int(os.environ.get("ROLLOUT_CONTEXT_LENGTH", "256")), + context_length=256, worker_log_dir=self.worker_log_dir, gpu_memory_utilization=float(os.environ.get("ROLLOUT_GPU_MEMORY_UTILIZATION", "0.5")), ) - # training config - model_cfg = Qwen3MoE30BA3Config() - optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=train_ep_size) + model_cfg = Qwen3VLDense4BConfig() + optim_cfg = AdamWConfig(lr=5e-7, foreach=False) + fsdp_cfg = FSDPConfig(ep_size=1) lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) - self.worker_cfg: WorkerConfig = WorkerConfig( + self.worker_cfg = WorkerConfig( model_cfg=model_cfg, optim_cfg=optim_cfg, loss_cfg=LossConfig( @@ -187,7 +108,8 @@ def init_config(self): use_kl_loss=False, kl_loss_coef=0.001, kl_loss_type="low_var_kl", - mode="eager"), + mode="eager", + ), lr_cfg=lr_cfg, fsdp_cfg=fsdp_cfg, load_from=MODEL_PATH, @@ -195,34 +117,27 @@ def init_config(self): pack_max_length=1024, ) - def _build_train_controller(self, worker_cls=BaseTrainingWorker): - TrainingWorker = ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - } - }, - )(worker_cls) - train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, self.worker_cfg, self.pg - ) - ray.get([worker.test_all_reduce.remote() for worker in train_workers]) - train_controller = TrainingController(workers=train_workers) - return train_controller - - def _build_sglang_rollout_controller(self): - rollout_pg = AutoAcceleratorWorkers.build_placement_group( - self.rollout_resources_cfg, - name=f"test_update_weight_rollout_{id(self)}", - ) - set_cpu_resource_manager(CPUResourceManager(accelerator_placement_groups=[self.pg, rollout_pg])) - self.rollout_cfg.skip_load_weights = False - return self.rollout_cfg.build(rollout_pg) + def _check_sglang_weights(self, rollout_controller, action): + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + active_urls = [ + url + for url, is_active in info_dict["worker_server_urls_status"].items() + if is_active + ] + self.assertGreater(len(active_urls), 0) + results = [] + for url in active_urls: + response = requests.post( + f"{url}/weights_checker", + json={"action": action}, + timeout=300, + ) + response.raise_for_status() + results.append(response.json()) + return results @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang backend is not enabled") def test_sglang_disaggregated_update_weight_and_generate(self): - # init train on a dedicated placement group TrainingWorker = ray.remote( runtime_env={ "env_vars": { @@ -232,176 +147,95 @@ def test_sglang_disaggregated_update_weight_and_generate(self): }, )(BaseTrainingWorker) train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, self.worker_cfg, self.pg + TrainingWorker, self.worker_cfg, self.train_pg ) - futures = [worker.test_all_reduce.remote() for worker in train_workers] - ray.get(futures) + ray.get([worker.test_all_reduce.remote() for worker in train_workers]) train_controller = TrainingController(workers=train_workers) - - # init rollout on a separate placement group - rollout_pg = AutoAcceleratorWorkers.build_placement_group( - self.rollout_resources_cfg, - name=f"test_update_weight_rollout_{id(self)}", - ) - set_cpu_resource_manager(CPUResourceManager(accelerator_placement_groups=[self.pg, rollout_pg])) + self.rollout_cfg.skip_load_weights = False - rollout_controller = self.rollout_cfg.build(rollout_pg) + rollout_controller = self.rollout_cfg.build(self.rollout_pg) sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict) - train_controller.set_train_rollout_mode("disaggregated") - + train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") train_controller.update_weights() res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) self.assertEqual(res_update_weight.response, res_baseline.response) ray.get(rollout_controller.shutdown.remote(), timeout=60) - @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang backend is not enabled") - def test_sglang_disaggregated_update_weight_after_pause_and_generate(self): - train_controller = self._build_train_controller() - rollout_controller = self._build_sglang_rollout_controller() - - sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) - input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) - res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict) - train_controller.set_train_rollout_mode("disaggregated") - - ray.get(rollout_controller.pause_generation.remote()) - time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2"))) - train_controller.update_weights() - ray.get(rollout_controller.continue_generation.remote()) - - res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - self.assertEqual(res_update_weight.response, res_baseline.response) - ray.get(rollout_controller.shutdown.remote(), timeout=60) - - @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang backend is not enabled") - def test_sglang_disaggregated_update_weight_sha256_is_stable(self): - train_controller = self._build_train_controller(worker_cls=HashingTrainingWorker) - rollout_controller = self._build_sglang_rollout_controller() - - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict) - train_controller.set_train_rollout_mode("disaggregated") - - ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers]) - train_controller.update_weights() - first_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers]) + @unittest.skip("skip sglang parameter-only weight check test until the parameter-check-only patch is applied") + def test_sglang_disaggregated_update_weight_equal_after_reset(self): + # This test verifies SGLang rollout weight update correctness with a parameter-only check. + # The SGLang parameter-only WeightChecker actions are implemented in + # https://github.com/PengchengShi00/sglang/commit/05e89d63b5a1a80671b267ff4494ad950b2aba75. + # Flow: snapshot_parameters -> reset_parameters -> update_weights -> compare_parameters. + TrainingWorker = ray.remote( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + } + }, + )(BaseTrainingWorker) + train_workers, _ = AutoAcceleratorWorkers.from_placement_group( + TrainingWorker, self.worker_cfg, self.train_pg + ) + ray.get([worker.test_all_reduce.remote() for worker in train_workers]) + train_controller = TrainingController(workers=train_workers) - ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers]) - train_controller.update_weights() - second_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers]) + self.rollout_cfg.skip_load_weights = False + rollout_controller = self.rollout_cfg.build(self.rollout_pg) - first_rank0_hash = next(item for item in first_hashes if item["rank"] == 0) - second_rank0_hash = next(item for item in second_hashes if item["rank"] == 0) - self.assertGreater(first_rank0_hash["bucket_count"], 0) - self.assertEqual(first_rank0_hash["bucket_count"], second_rank0_hash["bucket_count"]) - self.assertEqual(first_rank0_hash["sent_sha256_list"], second_rank0_hash["sent_sha256_list"]) - if _is_sglang_update_weight_sha256_test_enabled(): - self.assertEqual(first_rank0_hash["bucket_count"], len(first_rank0_hash["received_sha256_list"])) - self.assertEqual(first_rank0_hash["sent_sha256_list"], first_rank0_hash["received_sha256_list"]) - self.assertEqual(second_rank0_hash["bucket_count"], len(second_rank0_hash["received_sha256_list"])) - self.assertEqual(second_rank0_hash["sent_sha256_list"], second_rank0_hash["received_sha256_list"]) - self.assertEqual(first_rank0_hash["received_sha256_list"], second_rank0_hash["received_sha256_list"]) + try: + self._check_sglang_weights(rollout_controller, action="snapshot_parameters") + self._check_sglang_weights(rollout_controller, action="reset_parameters") - ray.get(rollout_controller.shutdown.remote(), timeout=60) + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") + train_controller.update_weights() - def _build_lmdeploy_rollout_controller(self): - rollout_pg = AutoAcceleratorWorkers.build_placement_group( - self.rollout_resources_cfg, - name=f"test_update_weight_rollout_{id(self)}", - ) - set_cpu_resource_manager(CPUResourceManager(accelerator_placement_groups=[self.pg, rollout_pg])) - self.rollout_cfg.skip_load_weights = False - return self.rollout_cfg.build(rollout_pg) + self._check_sglang_weights(rollout_controller, action="compare_parameters") + finally: + ray.get(rollout_controller.shutdown.remote(), timeout=60) @unittest.skip("skip lmdeploy disaggregated update-weight generation test until PR4638 is merged") def test_lmdeploy_disaggregated_update_weight_and_generate(self): - train_controller = self._build_train_controller() - rollout_controller = self._build_lmdeploy_rollout_controller() - - sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) - input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) - res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict) - train_controller.set_train_rollout_mode("disaggregated") - - train_controller.update_weights() - - res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - self.assertEqual(res_update_weight.response, res_baseline.response) - ray.get(rollout_controller.shutdown.remote(), timeout=60) + TrainingWorker = ray.remote( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + } + }, + )(BaseTrainingWorker) + train_workers, _ = AutoAcceleratorWorkers.from_placement_group( + TrainingWorker, self.worker_cfg, self.train_pg + ) + ray.get([worker.test_all_reduce.remote() for worker in train_workers]) + train_controller = TrainingController(workers=train_workers) - @unittest.skip("skip lmdeploy disaggregated update-weight generation test until PR4638 is merged") - def test_lmdeploy_disaggregated_update_weight_after_pause_and_generate(self): - train_controller = self._build_train_controller() - rollout_controller = self._build_lmdeploy_rollout_controller() + self.rollout_cfg.skip_load_weights = False + self.rollout_cfg.extra_rollout_config = { + "lmdeploy_backend": "pytorch", + "lmdeploy_distributed_executor_backend": "ray", + } + rollout_controller = self.rollout_cfg.build(self.rollout_pg) sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict) - train_controller.set_train_rollout_mode("disaggregated") - - ray.get(rollout_controller.pause_generation.remote()) - time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2"))) + train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") train_controller.update_weights() - ray.get(rollout_controller.continue_generation.remote()) res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) self.assertEqual(res_update_weight.response, res_baseline.response) ray.get(rollout_controller.shutdown.remote(), timeout=60) - @unittest.skip("skip lmdeploy disaggregated update-weight generation test until PR4638 is merged") - def test_lmdeploy_disaggregated_multi_update_and_generate(self): - """Drive N consecutive update_weights+generate cycles on a single rollout engine. - - LMDeploy's PyTorch backend runs a per-FusedMoE ``update_weights()`` finalize that - REPLACES ``gate_up.weight`` / ``down.weight`` Parameter objects (see - ``lmdeploy/pytorch/nn/moe/default.py`` ``LinearWeights.update_weight``). The CUDA-graph - staleness this introduces is handled by ``reset_graph_runner()`` inside the finalize, - but the second-round behaviour of the transpose-contig-transpose layout transform is - untested. This test catches any regression in back-to-back updates without sleep/wakeup - between them. Same method also exercises ascend / NPU where the finalize is a no-op - and graph capture is disabled (eager mode), so it should be trivially safe there. - """ - train_controller = self._build_train_controller() - rollout_controller = self._build_lmdeploy_rollout_controller() - - sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) - input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) - res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict) - train_controller.set_train_rollout_mode("disaggregated") - - # Trainer never actually steps, so each broadcast carries the same bytes; - # the rollout response should remain identical to baseline across all rounds. - num_iterations = int(os.environ.get("XTUNER_LMDEPLOY_MULTI_UPDATE_ITERS", "2")) - for i in range(num_iterations): - train_controller.update_weights() - res = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - self.assertEqual( - res.response, - res_baseline.response, - f"iteration {i}: response diverged from baseline after multi-update", - ) - - ray.get(rollout_controller.shutdown.remote(), timeout=60) - - if __name__ == "__main__": unittest.main() diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index c8c0ce2466..c50ffa6bfb 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -144,6 +144,10 @@ class RolloutConfig(BaseModel): gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85. random_seed (int): Random seed for reproducible generation. Defaults to 1024. rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False. + weight_update_host (Optional[str]): Host used by train rank 0 to initialize the external NCCL weight update + group. Defaults to None. + weight_update_port (Optional[int]): Port used by train rank 0 to initialize the external NCCL weight update + group. Defaults to 30000. rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it will be determined automatically based on `context_length`. Defaults to 512. allow_over_concurrency_ratio (float): Deprecated compatibility option. Rollout runtime concurrency is @@ -223,6 +227,26 @@ class RolloutConfig(BaseModel): help="Base port number for distributed communication among rollout workers.", ), ] = 25000 + weight_update_host: Annotated[ + Optional[str], + Parameter( + group=infer_group, + help=( + "Host used by train rank 0 to initialize the external NCCL weight update group. " + "Only used for NCCL weight update." + ), + ), + ] = None + weight_update_port: Annotated[ + Optional[int], + Parameter( + group=infer_group, + help=( + "Port used by train rank 0 to initialize the external NCCL weight update group. " + "Only used for NCCL weight update." + ), + ), + ] = 30000 rollout_max_batch_size_per_instance: Annotated[ Optional[int], Parameter( diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index 92b4575398..d0965f2167 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -290,11 +290,18 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"): ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore return - def update_rollout_info(self, info_dict): - ray.get([worker.update_rollout_info.remote(**info_dict) for worker in self.workers]) # type: ignore[attr-defined] - - def set_train_rollout_mode(self, train_rollout_mode: str): - ray.get([worker.set_train_rollout_mode.remote(train_rollout_mode) for worker in self.workers]) + def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None): + ray.get( + [ + worker.update_rollout_info.remote( + **info_dict, + train_rollout_mode=train_rollout_mode, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, + ) + for worker in self.workers + ] + ) def update_weights(self): """Update the weights of the training workers.""" diff --git a/xtuner/v1/rl/trainer/update_weighter.py b/xtuner/v1/rl/trainer/update_weighter.py deleted file mode 100644 index 00b86d8256..0000000000 --- a/xtuner/v1/rl/trainer/update_weighter.py +++ /dev/null @@ -1,1144 +0,0 @@ -import json -import os -import socket -from concurrent.futures import ThreadPoolExecutor -from datetime import timedelta -from itertools import chain -from threading import Lock -from typing import Any, Dict, List, TypeAlias, cast - -import requests -import torch -import torch.distributed as dist -import tqdm -from packaging.version import parse as parse_version -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.distributed_c10d import ( - Backend, - PrefixStore, - Store, - _new_process_group_helper, - _world, - default_pg_timeout, - rendezvous, -) -from torch.distributed.tensor import DTensor - -from xtuner.v1.model.compose.base import BaseComposeConfig -from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration -from xtuner.v1.model.moe.moe import MoE -from xtuner.v1.rl.rollout.worker import RolloutConfig -from xtuner.v1.utils import ( - get_device, - get_torch_device_module, - monkey_unpatch_torch_reductions, - ray_method, -) -from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec - - -DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices -ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping service names to their URLs -RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) -DEVICE = get_device() -DEVICE_MODULE = get_torch_device_module() - - -class UpdateWeighter: - rank: int - logger: Any - config: Any - - def _init_update_weighter(self): - # Used to update weight to rollout engine - self.rollout_device_mesh: DeviceMesh | None = None - self.rollout_url: str | None = None - self.rollout_cfg_info: dict = dict() - self.endpoints: dict[str, str] = dict() - self.endpoints["update_weights"] = "update_weights" - - self.rollout_engine_rank_mesh_array: DeviceMeshRaw = [] - self.rollout_server_url_dict: ServiceUrlMap = {} - self.worker_server_urls_status: dict[str, bool] = {} - - self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict() - self._default_ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3) - self._ipc_tensor_bytes_dict_by_dtype: dict[torch.dtype, int] = {} - self._update_params_ipc_tensor_dict_by_dtype: dict[torch.dtype, torch.Tensor] = {} - self._last_update_params_ipc_tensor_dtype: torch.dtype | None = None - self._update_params_ipc_event = None - self._sglang_disagg_group: dist.ProcessGroup | None = None - self._sglang_disagg_group_name: str | None = None - self._sglang_disagg_engine_urls: list[str] = [] - self._sglang_disagg_executor: ThreadPoolExecutor | None = None - self._train_update_sync_group: dist.ProcessGroup | None = None - self._sglang_disagg_update_lock = Lock() - self._lmdeploy_disagg_group: dist.ProcessGroup | None = None - self._lmdeploy_disagg_group_name: str | None = None - self._lmdeploy_disagg_engine_urls: list[str] = [] - self._lmdeploy_disagg_executor: ThreadPoolExecutor | None = None - self._lmdeploy_disagg_update_lock = Lock() - self.use_fake_weight_update = ( - False # 仅在 lmdeploy turbomind 后端的 disaggregated 模式下使用,表示是否使用 fake 接口进行权重更新 - ) - - def _hook_compare_test_sent_and_received_weight_hash( - self, - result: dict[str, Any], - *, - bucket_idx: int | None = None, - names: list[str] | None = None, - ) -> None: - """Test hook for comparing sent and received weight hashes. - - This hook is intentionally a no-op in production code and is expected to be overridden in unit tests that need - to compare training-side sent hashes with rollout-side received hashes returned by SGLang. - """ - return - - @ray_method - def update_rollout_info( - self, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, - rollout_config: RolloutConfig, - worker_server_urls_status: Dict[str, bool], - worker_session_url_dict: ServiceUrlMap | None = None, - worker_session_urls_status: Dict[str, bool] | None = None, - ): - """Update the rollout information for the training worker.""" - tp = rollout_config.tensor_parallel_size - ep = rollout_config.expert_parallel_size - assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." - rollout_server_url = server_url_dict.get(self.rank, "") - if worker_server_urls_status.get(rollout_server_url, "False") is False: - self.logger.error(f"Rollout server url {rollout_server_url} is not available.") - self.rollout_url = None - else: - self.rollout_url = rollout_server_url - - self.rollout_engine_rank_mesh_array = [[int(rank) for rank in ranks] for ranks in engine_rank_mesh_array] - self.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()} - self.worker_server_urls_status = worker_server_urls_status - - self.rollout_cfg_info["tp"] = tp - self.rollout_cfg_info["ep"] = ep - self.rollout_cfg_info["api_key"] = rollout_config.api_key - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - self.rollout_cfg_info["backend"] = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - self.rollout_cfg_info["backend"] = "vllm" - else: - self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get( - "lmdeploy_backend", "pytorch" - ) - - def _ensure_rollout_device_mesh(self) -> DeviceMesh: - if self.rollout_device_mesh is None: - # 非共卡 SGLang 不使用这个 mesh;只有共卡/旧权重同步路径需要 - # 用 rollout rank 构造 torch DeviceMesh。 - self.rollout_device_mesh = DeviceMesh( - "cpu", - mesh=self.rollout_engine_rank_mesh_array, - mesh_dim_names=("engine_instance", "engine_parallel"), - ) - return self.rollout_device_mesh - - @ray_method - def set_train_rollout_mode(self, train_rollout_mode: str): - mode = train_rollout_mode.lower() - if mode == "colocate": - self.is_train_rollout_colocated = True - elif mode == "disaggregated": - self.is_train_rollout_colocated = False - - backend = self.rollout_cfg_info.get("backend", "").lower() - if backend == "vllm": - raise NotImplementedError("Disaggregated train-rollout mode is not supported for vLLM backend.") - - elif backend == "pytorch": - self.use_fake_weight_update = False - - elif backend == "turbomind": - self.logger.warning( - "Disaggregated train-rollout mode for lmdeploy turbomind backend is not yet supported. " - "A fake no-op interface will be used temporarily.", - ) - self.use_fake_weight_update = True # 后续 fake 接口可根据这个标志跳过实际同步 - - elif backend == "sglang": - self.use_fake_weight_update = False - else: - raise ValueError( - f"Unsupported rollout backend for disaggregated mode: {backend!r}. " - "Expected 'vllm', 'pytorch', 'turbomind' or 'sglang'." - ) - - else: - raise ValueError( - f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." - ) - - if self.is_train_rollout_colocated: - self._reset_sglang_disagg_group() - self._reset_lmdeploy_disagg_group() - - def _reset_sglang_disagg_group(self): - if self._sglang_disagg_executor is not None: - self._sglang_disagg_executor.shutdown(wait=False, cancel_futures=True) - try: - if self._sglang_disagg_group is not None: - dist.destroy_process_group(self._sglang_disagg_group) - except Exception: - pass - self._sglang_disagg_group = None - self._sglang_disagg_group_name = None - self._sglang_disagg_engine_urls = [] - self._sglang_disagg_executor = None - - def _reset_lmdeploy_disagg_group(self): - if self._lmdeploy_disagg_executor is not None: - self._lmdeploy_disagg_executor.shutdown(wait=False, cancel_futures=True) - try: - if self._lmdeploy_disagg_group is not None: - dist.destroy_process_group(self._lmdeploy_disagg_group) - except Exception: - pass - self._lmdeploy_disagg_group = None - self._lmdeploy_disagg_group_name = None - self._lmdeploy_disagg_engine_urls = [] - self._lmdeploy_disagg_executor = None - - def _get_train_update_sync_group(self) -> dist.ProcessGroup: - if self._train_update_sync_group is None: - ranks = list(range(dist.get_world_size())) - self._train_update_sync_group = dist.new_group(ranks=ranks, backend="gloo") - return self._train_update_sync_group - - @ray_method - def update_weights(self): - """Update the model weights.""" - if not hasattr(self, "is_train_rollout_colocated"): - raise RuntimeError( - "train/rollout mode is not set. Please call set_train_rollout_mode() before update_weights()." - ) - - if self.is_train_rollout_colocated: - self._update_weights_colocated() - else: - self._update_weights_disaggregated() - - def _update_weights_colocated(self): - DEVICE_MODULE.empty_cache() - self._update_params_ipc_event = DEVICE_MODULE.Event(interprocess=True) - if self.rollout_cfg_info.get("backend") == "turbomind": - self._update_weights_by_layer() - else: - if isinstance(self.config.model_cfg, BaseComposeConfig): - self._update_weights_hf_generator(submodule="language_model", final_update=False) - self._update_weights_hf_generator(submodule="vision_tower", final_update=False) - self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True) - else: - self._update_weights_hf_generator(final_update=True) - self._update_params_ipc_tensor_dict_by_dtype = {} - self._last_update_params_ipc_tensor_dtype = None - self._update_params_ipc_event = None - - DEVICE_MODULE.empty_cache() - - def _update_weights_disaggregated(self): - if self.use_fake_weight_update: - self.logger.warning( - "Using fake weight update interface, no actual weight synchronization will happen. This is only for testing purposes and should not be used in production." - ) - return - - DEVICE_MODULE.empty_cache() - try: - if isinstance(self.config.model_cfg, BaseComposeConfig): - self._update_weights_hf_generator(submodule="language_model", final_update=False) - self._update_weights_hf_generator(submodule="vision_tower", final_update=False) - self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True) - else: - self._update_weights_hf_generator(final_update=True) - finally: - DEVICE_MODULE.empty_cache() - - def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int): - fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED) - model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size - if not fused_param_groups: - return - - def _get_hf_params( - fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]], - ) -> tuple[list[torch.Tensor], list[str]]: - hf_keys_list: list[str] = [] - hf_tensor_list: list[torch.Tensor] = [] - - for fsdp_tensor, load_spec in fsdp_tensor_list: - hf_keys = load_spec.hf_keys - if model_ep_size > 1 and model.ep_mesh is not None: - if load_spec.name not in self._global_hf_keys_mapping_cache: - global_hf_keys: list[list[str] | None] = [None] * model_ep_size - dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group()) - global_hf_keys_gathered = cast(list[list[str]], global_hf_keys) - self._global_hf_keys_mapping_cache[load_spec.name] = list( - chain.from_iterable(global_hf_keys_gathered) - ) - hf_keys = self._global_hf_keys_mapping_cache[load_spec.name] - - fused_full_tensor = fsdp_tensor.bfloat16() - if isinstance(fused_full_tensor, DTensor): - fused_full_tensor = fused_full_tensor.full_tensor() - dim = cast(int, load_spec.dim) - num_split = len(hf_keys) - hf_tensor_size = fused_full_tensor.shape[dim] / num_split - assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer" - hf_tensor_size = int(hf_tensor_size) - - hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim) - assert num_split % target_ep_size == 0, ( - f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}" - ) - start_idx = (num_split // target_ep_size) * target_ep_rank - end_idx = (num_split // target_ep_size) * (target_ep_rank + 1) - - hf_keys_list.extend(hf_keys[start_idx:end_idx]) - hf_tensor_list.extend(hf_tensor[start_idx:end_idx]) - - hf_tensor_list = [ - model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list) - ] - - return hf_tensor_list, hf_keys_list - - safetensor_size = 0 - dtype = torch.bfloat16 - tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] - - for param, load_spec in fused_param_groups: - tensor_size = dtype.itemsize * param.numel() // target_ep_size - if safetensor_size + tensor_size > bucket_size and tensor_list: - hf_params, name_list = _get_hf_params(tensor_list) - yield name_list, hf_params - safetensor_size = tensor_size - name_list = load_spec.hf_keys.copy() - tensor_list = [(param, load_spec)] - continue - safetensor_size += tensor_size - tensor_list.append((param, load_spec)) - - if tensor_list: - hf_params, name_list = _get_hf_params(tensor_list) - yield name_list, hf_params - - @torch.no_grad() - def _update_weights_hf_generator(self, submodule=None, final_update=False): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - - model = self._engine.model - if submodule: - model = getattr(model, submodule) - - dtype = torch.bfloat16 - bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) - same_gen = model._get_same_hf_param( - model._group_param_by_load_spec(LoadEnum.SAME), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - ) - - train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 - - if train_enable_ep: - if self.is_train_rollout_colocated and self.rollout_cfg_info["ep"] > 1: - rollout_device_mesh = self._ensure_rollout_device_mesh() - fused_gen = self._rl_get_fused_ep_hf_param( - model, - target_ep_rank=rollout_device_mesh["engine_parallel"].get_coordinate()[0], - target_ep_size=rollout_device_mesh["engine_parallel"].size(), - bucket_size=bucket_size, - ) - else: - # Disaggregated update uses one external trainer+rollout process group. - # Broadcast the same full expert bucket to every rollout rank and let - # the backend loader apply its local TP/EP slicing. - fused_gen = self._rl_get_fused_ep_hf_param( - model, - target_ep_rank=0, - target_ep_size=1, - bucket_size=bucket_size, - ) - else: - fused_gen = model._get_fused_hf_param( - model._group_param_by_load_spec(LoadEnum.FUSED), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - update_weights_for_rl=True, - ) - shard_gen = model._get_shard_hf_param( - model._group_param_by_load_spec(LoadEnum.SHARD), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - ) - - for name_list, fused_param_list in fused_gen: - state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} - self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) - del state_dict, name_list, fused_param_list - - for name_list, param_list in chain(same_gen, shard_gen): - state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} - self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) - del state_dict, name_list, param_list - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update: - self.request_update_params({}, train_enable_ep=train_enable_ep, finished=True) - - if self.is_train_rollout_colocated: - dist.barrier() - else: - dist.barrier(group=self._get_train_update_sync_group()) - DEVICE_MODULE.empty_cache() - return - - def _update_weights_by_layer(self): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - assert self.rollout_device_mesh is not None - - model = self._engine.model - DEVICE_MODULE.empty_cache() - - if isinstance(model.config, BaseComposeConfig): - # TODO: support float8 for vision compose model - dtype = torch.bfloat16 - else: - if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): - dtype = torch.float8_e4m3fn - else: - dtype = torch.bfloat16 - - def get_params(tensor_list, name_list, save_dtype): - _tensor_list, _spec_list = list(zip(*tensor_list)) - fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - if save_dtype == torch.float8_e4m3fn: - fsdp_unshard_tensor_list, name_list = model._to_float8( - fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype - ) - return fsdp_unshard_tensor_list, name_list - - saved_list = [] - is_qwen3vl = False - if isinstance(model.config, BaseComposeConfig): - language_model = model.language_model - if isinstance(model, Qwen3VLForConditionalGeneration): - is_qwen3vl = True - else: - language_model = model - - if is_qwen3vl: - vision_hf_prefix = "model.visual." - projector_hf_prefix = "model.visual." - else: - vision_hf_prefix = "model.vision_tower." - projector_hf_prefix = "model.multi_modal_projector." - - for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): - tensor_list = [] - name_list = [] - for sub_name, param in layer.state_dict().items(): - if isinstance(model.config, BaseComposeConfig): - saved_list.append(f"language_model.layers.{i}.{sub_name}") - else: - saved_list.append(f"layers.{i}.{sub_name}") - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") - - if isinstance(model.config, BaseComposeConfig): - name = f"model.language_model.layers.{i}.{sub_name}" - else: - name = f"model.layers.{i}.{sub_name}" - - if ".experts." in name and ".mlp.experts." not in name: - name = name.replace(".experts.", ".mlp.experts.") - if ".gate." in name and ".mlp.gate." not in name: - name = name.replace(".gate.", ".mlp.gate.") - name_list.append(name) - tensor_list.append((local_tensor, load_spec)) - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - for name, param in model.state_dict().items(): - if name in saved_list: - continue - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = model.load_spec_mapping.get(name) - - if isinstance(model.config, BaseComposeConfig): - if "vision_tower." in name: - name = name.replace("vision_tower.", vision_hf_prefix) - elif "multi_modal_projector." in name: - name = name.replace("multi_modal_projector.", projector_hf_prefix) - elif name == "language_model.norm.weight": - name = "model.language_model.norm.weight" - elif name == "language_model.embed_tokens.weight": - name = "model.language_model.embed_tokens.weight" - elif name == "language_model.lm_head.weight": - name = "lm_head.weight" - else: - if name == "norm.weight": - name = "model.norm.weight" - elif name == "embed_tokens.weight": - name = "model.embed_tokens.weight" - tensor_list = [(local_tensor, load_spec)] - name_list = [name] - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm"): - self.request_update_params({}, finished=True) - - dist.barrier() - DEVICE_MODULE.empty_cache() - return - - @staticmethod - def _compute_state_dict_bytes(state_dict: Dict[str, torch.Tensor]) -> int: - total_bytes = 0 - for tensor in state_dict.values(): - total_bytes += tensor.numel() * tensor.element_size() - return total_bytes - - @staticmethod - def _init_external_process_group( - backend: str | Backend | None = None, - init_method: str | None = None, - timeout: timedelta | None = None, - world_size: int = -1, - rank: int = -1, - store: Store | None = None, - group_name: str | None = None, - pg_options: Any | None = None, - ) -> dist.ProcessGroup: - assert (store is None) or (init_method is None), "Cannot specify both store and init_method." - if store is not None: - assert world_size > 0, "world_size must be positive if using store" - assert rank >= 0, "rank must be non-negative if using store" - elif init_method is None: - init_method = "env://" - - backend = Backend(backend) if backend else Backend("undefined") - if timeout is None: - timeout = default_pg_timeout - - if store is None: - assert init_method is not None - rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) - store, rank, world_size = next(rendezvous_iterator) - store.set_timeout(timeout) - if group_name is not None: - store = PrefixStore(group_name, store) - - pg_options_param_name = ( - "backend_options" if parse_version(torch.__version__) >= parse_version("2.6") else "pg_options" - ) - pg, _ = _new_process_group_helper( - world_size, - rank, - [], - backend, - store, - group_name=group_name, - **{pg_options_param_name: pg_options}, - timeout=timeout, - ) - _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} - return pg - - @staticmethod - def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype): - return torch.empty(size_in_bytes, dtype=torch.uint8, device=DEVICE).view(dtype) - - def _build_lmdeploy_flattened_tensor_data(self, state_dict: dict, flattened_tensor_bucket_cls) -> dict: - # LMDeploy flattened buckets require all tensors in one bucket to share a dtype. - state_dict_dtype = state_dict[next(iter(state_dict))].dtype - update_params_ipc_tensor = self._update_params_ipc_tensor_dict_by_dtype.get(state_dict_dtype, None) - state_dict_bytes = self._compute_state_dict_bytes(state_dict) - ipc_tensor_bytes = self._ipc_tensor_bytes_dict_by_dtype.get( - state_dict_dtype, - self._default_ipc_tensor_bytes, - ) - dtype_changed = ( - self._last_update_params_ipc_tensor_dtype is not None - and state_dict_dtype != self._last_update_params_ipc_tensor_dtype - ) - need_resize = state_dict_bytes > ipc_tensor_bytes - send_ipc_tensor = dtype_changed or need_resize or update_params_ipc_tensor is None - - if update_params_ipc_tensor is not None: - self._update_params_ipc_event.wait() - if need_resize: - torch.cuda.synchronize() - - if update_params_ipc_tensor is None or need_resize: - ipc_tensor_bytes = max(ipc_tensor_bytes, state_dict_bytes) - self._ipc_tensor_bytes_dict_by_dtype[state_dict_dtype] = ipc_tensor_bytes - update_params_ipc_tensor = self._create_ipc_tensor( - ipc_tensor_bytes, - state_dict_dtype, - ) - self._update_params_ipc_tensor_dict_by_dtype[state_dict_dtype] = update_params_ipc_tensor - - flattened_tensor_bucket = flattened_tensor_bucket_cls( - named_tensors=list(state_dict.items()), - flattened_tensor=update_params_ipc_tensor, - ) - flattened_tensor_data = { - "metadata": flattened_tensor_bucket.get_metadata(), - "require_clone": False, - } - self._update_params_ipc_event.record() - self._last_update_params_ipc_tensor_dtype = state_dict_dtype - - if send_ipc_tensor: - flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() - flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() - return flattened_tensor_data - - def _get_disagg_engine_info(self) -> RolloutEngineInfo: - engine_info: RolloutEngineInfo = [] - seen_urls: set[str] = set() - rank_to_engine_size: dict[int, int] = {} - for engine_ranks in self.rollout_engine_rank_mesh_array: - engine_size = len(engine_ranks) - for rank in engine_ranks: - rank_to_engine_size[int(rank)] = engine_size - - for rank, url in sorted(self.rollout_server_url_dict.items(), key=lambda item: int(item[0])): - rank = int(rank) - if not url or url in seen_urls: - continue - if self.worker_server_urls_status.get(url, False) is False: - continue - seen_urls.add(url) - engine_info.append( - ( - rank, - url, - rank_to_engine_size.get( - rank, - max(self.rollout_cfg_info["tp"], self.rollout_cfg_info["ep"]), - ), - ) - ) - return engine_info - - def _ensure_sglang_disagg_group(self): - if self._sglang_disagg_group is not None: - return - engine_info = self._get_disagg_engine_info() - if not engine_info: - self.logger.error("No active rollout engine url, cannot init sglang weight update group") - return - - os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" - backend = "nccl" - - master_address = None - master_port = None - # get address and port for weight-update - try: - import ray - - master_address = ray.util.get_node_ip_address() - except Exception: - master_address = socket.gethostbyname(socket.gethostname()) - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("", 0)) - master_port = int(sock.getsockname()[1]) - - group_name = f"xtuner_sglang_weight_update_{self.rank}" - world_size = sum(engine_size for _, _, engine_size in engine_info) + 1 - - self._sglang_disagg_executor = ThreadPoolExecutor(max_workers=max(1, len(engine_info))) - init_futures = [] - rank_offset = 1 - for _, url, engine_size in engine_info: - payload = { - "master_address": master_address, - "master_port": master_port, - "rank_offset": rank_offset, - "world_size": world_size, - "group_name": group_name, - "backend": backend, - } - init_futures.append( - self._sglang_disagg_executor.submit( - requests.post, - f"{url}/init_weights_update_group", - json=payload, - ) - ) - rank_offset += engine_size - - self._sglang_disagg_group = self._init_external_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - ) - - for init_future in init_futures: - response = init_future.result() - response.raise_for_status() - result = response.json() - assert result.get("success", True), ( - f"SGLang init_weights_update_group failed: {result.get('message', result)}" - ) - - self._sglang_disagg_group_name = group_name - self._sglang_disagg_engine_urls = [url for _, url, _ in engine_info] - - def _ensure_lmdeploy_disagg_group(self): - if self._lmdeploy_disagg_group is not None: - return - engine_info = self._get_disagg_engine_info() - if not engine_info: - self.logger.error("No active rollout engine url, cannot init lmdeploy weight update group") - return - - os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" - backend = "nccl" - - master_address = None - master_port = None - try: - import ray - - master_address = ray.util.get_node_ip_address() - except Exception: - master_address = socket.gethostbyname(socket.gethostname()) - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("", 0)) - master_port = int(sock.getsockname()[1]) - - group_name = f"xtuner_lmdeploy_weight_update_{self.rank}" - world_size = sum(engine_size for _, _, engine_size in engine_info) + 1 - - self._lmdeploy_disagg_executor = ThreadPoolExecutor(max_workers=max(1, len(engine_info))) - init_futures = [] - rank_offset = 1 - for _, url, engine_size in engine_info: - payload = { - "master_address": master_address, - "master_port": master_port, - "rank_offset": rank_offset, - "world_size": world_size, - "group_name": group_name, - "backend": backend, - } - init_futures.append( - self._lmdeploy_disagg_executor.submit( - requests.post, - f"{url}/init_weights_update_group", - json=payload, - ) - ) - rank_offset += engine_size - - self._lmdeploy_disagg_group = self._init_external_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - ) - - for init_future in init_futures: - response = init_future.result() - response.raise_for_status() - result = response.json() - assert result.get("success", True), ( - f"LMDeploy init_weights_update_group failed: {result.get('message', result)}" - ) - - self._lmdeploy_disagg_group_name = group_name - self._lmdeploy_disagg_engine_urls = [url for _, url, _ in engine_info] - - def _request_update_params_sglang_disaggregated(self, state_dict): - if not state_dict: - return - - train_sync_group = self._get_train_update_sync_group() - head_rank = 0 - if dist.get_rank() != head_rank: - dist.barrier(group=train_sync_group) - return - - self._ensure_sglang_disagg_group() - if self._sglang_disagg_group is None: - dist.barrier(group=train_sync_group) - return - - assert self._sglang_disagg_executor is not None - assert self._sglang_disagg_group_name is not None - with self._sglang_disagg_update_lock: - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - except Exception as e: - raise RuntimeError( - "Disaggregated update_weights currently only supports sglang builds " - "that provide `sglang.srt.model_executor.model_runner.FlattenedTensorBucket`." - ) from e - - names = list(state_dict.keys()) - tensors = [ - tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values() - ] - payload = { - "names": names, - "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in tensors], - "shapes": [list(tensor.shape) for tensor in tensors], - "group_name": self._sglang_disagg_group_name, - "load_format": "flattened_bucket", - } - update_futures = [ - self._sglang_disagg_executor.submit( - requests.post, - f"{url}/update_weights_from_distributed", - json=payload, - ) - for url in self._sglang_disagg_engine_urls - ] - assert self._sglang_disagg_group is not None - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(names, tensors))) - flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() - - dist.broadcast(flattened_tensor, src=0, group=self._sglang_disagg_group) - DEVICE_MODULE.synchronize() - for update_future in update_futures: - response = update_future.result() - response.raise_for_status() - result = response.json() - self._hook_compare_test_sent_and_received_weight_hash( - result, - names=names, - ) - assert result.get("success", True), ( - f"SGLang update_weights_from_distributed failed: {result.get('message', result)}" - ) - dist.barrier(group=train_sync_group) - - def _request_update_params_lmdeploy_disaggregated(self, state_dict, finished: bool = False): - if not state_dict and not finished: - return - - train_sync_group = self._get_train_update_sync_group() - head_rank = 0 - if dist.get_rank() != head_rank: - dist.barrier(group=train_sync_group) - return - - self._ensure_lmdeploy_disagg_group() - if self._lmdeploy_disagg_group is None: - dist.barrier(group=train_sync_group) - return - - assert self._lmdeploy_disagg_executor is not None - assert self._lmdeploy_disagg_group_name is not None - with self._lmdeploy_disagg_update_lock: - try: - from lmdeploy.utils import FlattenedTensorBucket - except Exception as e: - raise RuntimeError( - "Disaggregated update_weights for lmdeploy backend requires lmdeploy builds that provide " - "`lmdeploy.utils.FlattenedTensorBucket`." - ) from e - - if state_dict: - names = list(state_dict.keys()) - tensors = [ - tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values() - ] - payload = { - "names": names, - "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in tensors], - "shapes": [list(tensor.shape) for tensor in tensors], - "group_name": self._lmdeploy_disagg_group_name, - "load_format": "flattened_bucket", - "finished": finished, - } - update_futures = [ - self._lmdeploy_disagg_executor.submit( - requests.post, - f"{url}/update_weights_from_distributed", - json=payload, - ) - for url in self._lmdeploy_disagg_engine_urls - ] - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(names, tensors))) - flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() - dist.broadcast(flattened_tensor, src=0, group=self._lmdeploy_disagg_group) - DEVICE_MODULE.synchronize() - for update_future in update_futures: - response = update_future.result() - response.raise_for_status() - result = response.json() - self._hook_compare_test_sent_and_received_weight_hash( - result, - names=names, - ) - assert result.get("success", True), ( - f"LMDeploy update_weights_from_distributed failed: {result.get('message', result)}" - ) - else: - # finalize-only request: no tensors to broadcast, just trigger the - # rollout side's mod.update_weights() finalization hooks. - payload = { - "names": [], - "dtypes": [], - "shapes": [], - "group_name": self._lmdeploy_disagg_group_name, - "load_format": "flattened_bucket", - "finished": True, - } - update_futures = [ - self._lmdeploy_disagg_executor.submit( - requests.post, - f"{url}/update_weights_from_distributed", - json=payload, - ) - for url in self._lmdeploy_disagg_engine_urls - ] - for update_future in update_futures: - response = update_future.result() - response.raise_for_status() - result = response.json() - assert result.get("success", True), ( - f"LMDeploy update_weights_from_distributed (finalize) failed: {result.get('message', result)}" - ) - dist.barrier(group=train_sync_group) - - @ray_method - def request_update_params(self, state_dict, train_enable_ep=False, finished=False): - """Send a request to update the parameters on the rollout workers. - - This method serializes the state dictionary and sends it to the - appropriate rollout worker via an HTTP request. - - Args: - state_dict (dict | list): The state dictionary containing the model - parameters to update. - train_enable_ep (bool): Whether the training engine enables expert parallelism. - Defaults to False. - finished (bool): A flag indicating whether this is the final - batch of updates. Defaults to False. - """ - - if self.rollout_cfg_info["backend"] == "sglang" and not self.is_train_rollout_colocated: - self._request_update_params_sglang_disaggregated(state_dict) - return - - if self.rollout_cfg_info["backend"] == "pytorch" and not self.is_train_rollout_colocated: - self._request_update_params_lmdeploy_disaggregated(state_dict, finished=finished) - return - - cpu_mesh = self._ensure_rollout_device_mesh()["engine_parallel"] - cpu_group = cpu_mesh.get_group() - head_rank = cpu_mesh.mesh[0].item() - if self.rollout_url is None: - self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") - return - - if self.rollout_cfg_info["backend"] == "vllm": - - def serialize_state_dict(state_dict: dict) -> str: - import base64 - from io import BytesIO - from multiprocessing.reduction import ForkingPickler - - from torch.multiprocessing.reductions import reduce_tensor - - data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] - buf = BytesIO() - ForkingPickler(buf).dump(data) - buf.seek(0) - return base64.b64encode(buf.read()).decode("utf-8") - - serialized_data = [None] * self.rollout_cfg_info["tp"] - dist.gather_object( - serialize_state_dict(state_dict), - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - } - data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=finished)) - data = dict(method="update_weight_npu_ipc", args=[data_]) - response = requests.post(f"{self.rollout_url}/collective_rpc", headers=headers, json=data) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - if finished: - dist.barrier(group=cpu_group) - return - - if self.rollout_cfg_info["backend"] == "pytorch": - # TODO(chenchiyu): remove lmdeploy related code - from lmdeploy.utils import serialize_state_dict - - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket and state_dict: - flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data( - state_dict, - FlattenedTensorBucket, - ) - tp_serialized_data = serialize_state_dict(flattened_tensor_data) - else: - tp_serialized_data = serialize_state_dict(state_dict) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - elif self.rollout_cfg_info["backend"] == "pytorch": - if use_flattened_tensor_bucket and state_dict: - flattened_tensor_data = self._build_lmdeploy_flattened_tensor_data( - state_dict, - FlattenedTensorBucket, - ) - serialized_data = serialize_state_dict(flattened_tensor_data) - else: - serialized_data = serialize_state_dict(state_dict) - else: - # for turbomind backend, only head_rank should serialize data - serialized_data = serialize_state_dict(state_dict) if dist.get_rank() == head_rank else None - else: - # sglang - from sglang.srt.utils import MultiprocessingSerializer - from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions - - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - # NOTE: xtuner目前去掉sglang的patch也不会出问题,但为了保险起见,还是保留patch逻辑,并且在update_weights结束后unpatch - monkey_patch_torch_reductions() - state_dict = state_dict.items() - if self.rollout_cfg_info["tp"] == 1: - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - else: - serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - - serialized_data = [serialized_data] - else: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - tp_serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - else: - tp_serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.rollout_cfg_info['api_key']}", - } - if self.rollout_cfg_info["backend"] == "sglang": - payload = { - "serialized_named_tensors": serialized_data, - "flush_cache": False, - } - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - if use_flattened_tensor_bucket: - payload["load_format"] = "flattened_bucket" - - url = f"{self.rollout_url}/update_weights_from_tensor" - response = requests.post(url, json=payload or {}) - response.raise_for_status() - else: - data = dict(serialized_named_tensors=serialized_data, finished=finished) - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if use_flattened_tensor_bucket and state_dict: - data["load_format"] = "flattened_bucket" - response = requests.post( - f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data - ) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - # TODO(chenchiyu): narrow this condition - if finished or ( - self.rollout_cfg_info["backend"] == "pytorch" and train_enable_ep and self.rollout_cfg_info["tp"] > 1 - ): - # This barrier is aim to make each tp head rank sync with other ranks in engine_parallel group - # which could not be barrier by `fsdp_foreach_allgather` of the next state dict. (Happens in same_gen, shard not tested) - # Without barrier, some ranks in engine_parallel group would not wait for current iter data ipc event recording in lmdeploy. - # They would write next iter state_dict into the ipc tensor before lmdeploy load current iter weight. - dist.barrier(group=cpu_group) - - monkey_unpatch_torch_reductions() - return diff --git a/xtuner/v1/rl/trainer/worker.py b/xtuner/v1/rl/trainer/worker.py index e1043e6410..00ce2e067d 100644 --- a/xtuner/v1/rl/trainer/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -48,6 +48,7 @@ from xtuner.v1.profiler import profiling_memory, profiling_time from xtuner.v1.rl.loss import BaseRLLossConfig, BaseRLLossContext, finalize_train_policy_metrics, kl_penalty from xtuner.v1.rl.utils import SingleAcceleratorWorker +from xtuner.v1.rl.weight_update import UpdateWeighter from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ( XTUNER_DETERMINISTIC, @@ -60,7 +61,6 @@ ) from ..rollout_is import merge_rollout_is_metrics -from .update_weighter import UpdateWeighter DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices @@ -201,7 +201,7 @@ class WorkerLogItem(TypedDict): sft_train_metrics: NotRequired[dict[str, float]] -class TrainingWorker(SingleAcceleratorWorker, UpdateWeighter): +class TrainingWorker(SingleAcceleratorWorker): _SAVE_WEIGHTS_DIR = "weights" _SAVE_SFT_DATALOADER_DIR = "sft_dataloader" _SAVE_SFT_TRAIN_STATE_PATH = "sft_train_state.json" @@ -269,7 +269,20 @@ def __init__( if hasattr(worker_cfg.model_cfg.text_config, "mtp_config"): self.mtp_config = worker_cfg.model_cfg.text_config.mtp_config - self._init_update_weighter() + self.update_weighter = UpdateWeighter( + rank=self.rank, + logger=self.logger, + config=self.config, + engine=self._engine, + ) + + @ray_method + def update_rollout_info(self, *args, **kwargs): + return self.update_weighter.update_rollout_info(*args, **kwargs) + + @ray_method + def update_weights(self): + return self.update_weighter.update_weights() def _init_sft(self, worker_cfg: WorkerConfig): self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg diff --git a/xtuner/v1/rl/weight_update/__init__.py b/xtuner/v1/rl/weight_update/__init__.py new file mode 100644 index 0000000000..312f779fe3 --- /dev/null +++ b/xtuner/v1/rl/weight_update/__init__.py @@ -0,0 +1,46 @@ +from .data import ( + DeviceMeshRaw, + RolloutBackend, + RolloutEngineInfo, + RolloutWeightUpdateInfo, + ServiceUrlMap, + TrainRolloutMode, + WeightTransportType, + WeightUpdateBatch, +) +from .transport import ( + IPCBackendAdapter, + IPCWeightTransport, + LMDeployIPCBackendAdapter, + NCCLBackendAdapter, + NCCLWeightTransport, + SGLangIPCBackendAdapter, + SGLangNCCLBackendAdapter, + WeightTransport, + WeightUpdateRequest, +) +from .update_weighter import UpdateWeighter +from .weight_iterator import WeightIterator + + +__all__ = [ + "DeviceMeshRaw", + "IPCBackendAdapter", + "IPCWeightTransport", + "LMDeployIPCBackendAdapter", + "NCCLBackendAdapter", + "NCCLWeightTransport", + "RolloutBackend", + "RolloutEngineInfo", + "RolloutWeightUpdateInfo", + "SGLangIPCBackendAdapter", + "SGLangNCCLBackendAdapter", + "ServiceUrlMap", + "TrainRolloutMode", + "UpdateWeighter", + "WeightIterator", + "WeightTransportType", + "WeightUpdateBatch", + "WeightUpdateRequest", + "WeightTransport", +] diff --git a/xtuner/v1/rl/weight_update/data.py b/xtuner/v1/rl/weight_update/data.py new file mode 100644 index 0000000000..6041ff6433 --- /dev/null +++ b/xtuner/v1/rl/weight_update/data.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Literal, TypeAlias + +import torch +from torch.distributed.device_mesh import DeviceMesh + + +DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices. +ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs. +RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) +TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode. +RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend. +WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types. + + +@dataclass +class RolloutWeightUpdateInfo: + # Common rollout metadata. + api_key: list[str] | str | None = None + rollout_url: str | None = None + backend: RolloutBackend | None = None + tp: int = 1 + ep: int = 1 + train_rollout_mode: TrainRolloutMode | None = None + transport_type: WeightTransportType | None = None + rollout_cfg_info: dict = field(default_factory=dict) + endpoints: dict[str, str] = field(default_factory=lambda: {"update_weights": "update_weights"}) + + # Colocated rollout metadata. + rollout_device_mesh: DeviceMesh | None = None + rollout_engine_rank_mesh_array: DeviceMeshRaw = field(default_factory=list) + + # Disaggregated rollout metadata. + rollout_server_url_dict: ServiceUrlMap = field(default_factory=dict) + worker_server_urls_status: dict[str, bool] = field(default_factory=dict) + weight_update_host: str | None = None + weight_update_port: int | None = None + + +@dataclass +class WeightUpdateBatch: + """A single bucket of weights to send to rollout workers.""" + + state_dict: dict[str, torch.Tensor] + train_enable_ep: bool = False + finished: bool = False diff --git a/xtuner/v1/rl/weight_update/transport.py b/xtuner/v1/rl/weight_update/transport.py new file mode 100644 index 0000000000..2ae3639bec --- /dev/null +++ b/xtuner/v1/rl/weight_update/transport.py @@ -0,0 +1,862 @@ +from __future__ import annotations + +import importlib +import json +import os +import socket +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Callable, Protocol, cast + +import requests +import torch +import torch.distributed as dist +from packaging.version import parse as parse_version +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) + +from xtuner.v1.utils import ( + get_device, + get_torch_device_module, + monkey_unpatch_torch_reductions, +) + +from .data import RolloutWeightUpdateInfo, WeightUpdateBatch + + +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() + + +@dataclass +class WeightUpdateRequest: + endpoint: str + body: dict[str, Any] + + +class WeightTransportAdapter(Protocol): + def before_update(self) -> None: ... + + def after_update_all_groups(self) -> None: ... + + +class WeightTransport(ABC): + def __init__(self, *, rollout_info: RolloutWeightUpdateInfo, logger: Any, rank: int): + self.rollout_info = rollout_info + self.logger = logger + self.rank = rank + self.backend = self.rollout_info.backend + self.rollout_ep = self.rollout_info.ep + self.rollout_tp = self.rollout_info.tp + self._adapter: WeightTransportAdapter | None = None + + self.rollout_url = self.rollout_info.rollout_url + if self.rollout_url is None: + self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") + + @staticmethod + def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict: + headers = {"Content-Type": "application/json"} + # TODO move api key to init + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + response = requests.post(f"{url}/{endpoint}", headers=headers, json=payload) + assert response.status_code == 200, f"response.status_code = {response.status_code}" + response.raise_for_status() + return response.json() + + def update(self, weight_iterator: Any) -> None: + assert self._adapter is not None + self._adapter.before_update() + DEVICE_MODULE.empty_cache() + + try: + for batches in weight_iterator.iter_batch_groups(): + for batch in batches: + self.send(batch) + self.after_update_per_group() + DEVICE_MODULE.empty_cache() + finally: + self.after_update_all_groups() + DEVICE_MODULE.empty_cache() + + @abstractmethod + def send(self, batch: WeightUpdateBatch) -> None: + raise NotImplementedError + + def after_update_all_groups(self) -> None: + return + + def after_update_per_group(self) -> None: + return + + def teardown(self) -> None: + return + + +class IPCBackendAdapter: + # def __init__(self, *, rollout_info: RolloutWeightUpdateInfo): + def __init__(self, *, rollout_tp: int): + self.rollout_tp = rollout_tp + # self.rollout_info = rollout_info + + def before_update(self) -> None: + return + + def after_update(self) -> None: + return + + def build_request( + self, + batch: WeightUpdateBatch, + serialized_data: Any, + ) -> WeightUpdateRequest: + raise NotImplementedError + + def serialize( + self, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> list[Any]: + raise NotImplementedError + + def after_update_per_batch( + self, finished: bool, cpu_group: dist.ProcessGroup, train_enable_ep: bool = False + ) -> None: + return + + def after_update_all_groups(self) -> None: + return + + +class VLLMIPCBackendAdapter(IPCBackendAdapter): + @staticmethod + def _serialize_state_dict(state_dict: dict) -> str: + import base64 + from io import BytesIO + from multiprocessing.reduction import ForkingPickler + + from torch.multiprocessing.reductions import reduce_tensor + + data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] + buf = BytesIO() + ForkingPickler(buf).dump(data) + buf.seek(0) + return base64.b64encode(buf.read()).decode("utf-8") + + def serialize( + self, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> list[Any]: + serialized_data = [None] * self.rollout_tp + dist.gather_object( + self._serialize_state_dict(batch.state_dict), + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + return serialized_data + + def build_request( + self, + batch: WeightUpdateBatch, + serialized_data: list[Any], + ) -> WeightUpdateRequest: + data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=batch.finished)) + data = dict(method="update_weight_npu_ipc", args=[data_]) + return WeightUpdateRequest(endpoint="collective_rpc", body=data) + + def after_update_per_batch( + self, finished: bool, cpu_group: dist.ProcessGroup, train_enable_ep: bool = False + ) -> None: + if finished: + dist.barrier(group=cpu_group) + + +class LMDeployIPCBackendAdapter(IPCBackendAdapter): + def __init__(self, *, rollout_tp: int, backend: str, default_ipc_tensor_bytes: int): + super().__init__(rollout_tp=rollout_tp) + self._default_ipc_tensor_bytes = default_ipc_tensor_bytes + self._ipc_tensor_bytes_by_dtype: dict[torch.dtype, int] = {} + self._update_params_ipc_tensor_by_dtype: dict[torch.dtype, torch.Tensor] = {} + self._last_update_params_ipc_tensor_dtype: torch.dtype | None = None + self._update_params_ipc_event = None + self.backend = backend + self.endpoints: dict[str, str] = dict() + self.endpoints["update_weights"] = "update_weights" + + try: + model_runner = importlib.import_module("lmdeploy.utils") + getattr(model_runner, "FlattenedTensorBucket") + self.use_flattened_tensor_bucket = True + except Exception: + self.use_flattened_tensor_bucket = False + + def before_update(self) -> None: + self._update_params_ipc_event = DEVICE_MODULE.Event(interprocess=True) + + def after_update_all_groups(self) -> None: + self._ipc_tensor_bytes_by_dtype = {} + self._update_params_ipc_tensor_by_dtype = {} + self._last_update_params_ipc_tensor_dtype = None + self._update_params_ipc_event = None + + @staticmethod + def _compute_state_dict_bytes(state_dict: dict[str, torch.Tensor]) -> int: + total_bytes = 0 + for tensor in state_dict.values(): + total_bytes += tensor.numel() * tensor.element_size() + return total_bytes + + @staticmethod + def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype): + return torch.empty(size_in_bytes, dtype=torch.uint8, device=DEVICE).view(dtype) + + def build_flattened_tensor_data( + self, + state_dict: dict, + flattened_tensor_bucket_cls, + ) -> dict: + assert self._update_params_ipc_event is not None + # LMDeploy flattened buckets require all tensors in one bucket to share a dtype. + state_dict_dtype = state_dict[next(iter(state_dict))].dtype + # LMDeploy can reuse the same IPC tensor across batches. A new handle is + # sent only when dtype changes, capacity is insufficient, or this is the first batch. + update_params_ipc_tensor = self._update_params_ipc_tensor_by_dtype.get(state_dict_dtype, None) + state_dict_bytes = self._compute_state_dict_bytes(state_dict) + ipc_tensor_bytes = self._ipc_tensor_bytes_by_dtype.get( + state_dict_dtype, + self._default_ipc_tensor_bytes, + ) + dtype_changed = ( + self._last_update_params_ipc_tensor_dtype is not None + and state_dict_dtype != self._last_update_params_ipc_tensor_dtype + ) + need_resize = state_dict_bytes > ipc_tensor_bytes + send_ipc_tensor = dtype_changed or need_resize or update_params_ipc_tensor is None + + if update_params_ipc_tensor is not None: + # Wait until rollout has consumed the previous IPC tensor before reusing it. + self._update_params_ipc_event.wait() + if need_resize: + # Synchronize before replacing a too-small IPC tensor to avoid freeing + # storage that may still be referenced by the rollout process. + DEVICE_MODULE.synchronize() + + if update_params_ipc_tensor is None or need_resize: + ipc_tensor_bytes = max(ipc_tensor_bytes, state_dict_bytes) + self._ipc_tensor_bytes_by_dtype[state_dict_dtype] = ipc_tensor_bytes + update_params_ipc_tensor = self._create_ipc_tensor( + ipc_tensor_bytes, + state_dict_dtype, + ) + self._update_params_ipc_tensor_by_dtype[state_dict_dtype] = update_params_ipc_tensor + + flattened_tensor_bucket = flattened_tensor_bucket_cls( + named_tensors=list(state_dict.items()), + flattened_tensor=update_params_ipc_tensor, + ) + flattened_tensor_data = { + "metadata": flattened_tensor_bucket.get_metadata(), + "require_clone": False, + } + self._update_params_ipc_event.record() + self._last_update_params_ipc_tensor_dtype = state_dict_dtype + + if send_ipc_tensor: + # Subsequent batches with the same cached IPC tensor only need metadata; the + # tensor handle and event handle are resent only when the cached buffer changes. + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() + return flattened_tensor_data + + def serialize( + self, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> list[Any]: + from lmdeploy.utils import serialize_state_dict + + state_dict = batch.state_dict + + if self.use_flattened_tensor_bucket and state_dict: + from lmdeploy.utils import FlattenedTensorBucket + + flattened_tensor_data = self.build_flattened_tensor_data( + state_dict, + FlattenedTensorBucket, + ) + serialized_data = serialize_state_dict(flattened_tensor_data) + else: + serialized_data = serialize_state_dict(state_dict) + + if self.rollout_tp == 1: + return serialized_data + else: + all_serialized_data = [None] * self.rollout_tp + dist.gather_object( + serialized_data, + all_serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + return all_serialized_data + + def build_request( + self, + batch: WeightUpdateBatch, + serialized_data: tuple[Any, bool], + ) -> WeightUpdateRequest: + state_dict = batch.state_dict + data = dict(serialized_named_tensors=serialized_data, finished=batch.finished) + if self.use_flattened_tensor_bucket and state_dict: + data["load_format"] = "flattened_bucket" + return WeightUpdateRequest(endpoint=self.endpoints["update_weights"], body=data) + + def after_update_per_batch( + self, finished: bool, cpu_group: dist.ProcessGroup, train_enable_ep: bool = False + ) -> None: + # TODO(chenchiyu): narrow this condition. + if finished or (train_enable_ep and self.rollout_tp > 1): + # Make each TP head rank sync with other ranks in engine_parallel group. + # FSDP all-gather of the next state_dict cannot cover this case, so without + # this barrier some ranks could overwrite the IPC tensor before LMDeploy loads it. + dist.barrier(group=cpu_group) + + +class SGLangIPCBackendAdapter(IPCBackendAdapter): + def __init__(self, *, rollout_tp): + super().__init__(rollout_tp=rollout_tp) + + try: + model_runner = importlib.import_module("sglang.srt.model_executor.model_runner") + getattr(model_runner, "FlattenedTensorBucket") + + self.use_flattened_tensor_bucket = True + except Exception: + self.use_flattened_tensor_bucket = False + + def serialize( + self, + batch: WeightUpdateBatch, + cpu_group: dist.ProcessGroup, + head_rank: int, + ) -> list[Any]: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions + + # NOTE: XTuner currently also works without the SGLang patch in some cases, + # but keep the patch/unpatch pair for compatibility with SGLang serialization. + # SGLang overrides torch tensor reduction for multiprocessing serialization. + monkey_patch_torch_reductions() + + from sglang.srt.utils import MultiprocessingSerializer + + state_dict = batch.state_dict + + state_items = state_dict.items() + + if self.use_flattened_tensor_bucket: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_items) + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": flattened_tensor_bucket.get_metadata(), + } + serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + else: + serialized_data = MultiprocessingSerializer.serialize(state_items, output_str=True) + + if self.rollout_tp == 1: + return [serialized_data] + else: + all_serialized_data = [None] * self.rollout_tp + dist.gather_object( + serialized_data, + all_serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + return all_serialized_data + + def build_request( + self, + batch: WeightUpdateBatch, + serialized_data: tuple[list[Any], bool], + ) -> WeightUpdateRequest: + payload = { + "serialized_named_tensors": serialized_data, + "flush_cache": False, + } + if self.use_flattened_tensor_bucket: + payload["load_format"] = "flattened_bucket" + return WeightUpdateRequest(endpoint="update_weights_from_tensor", body=payload) + + def after_update_per_batch( + self, finished: bool, cpu_group: dist.ProcessGroup, train_enable_ep: bool = False + ) -> None: + # TODO(chenchiyu): narrow this condition. + if finished: + # Make each TP head rank sync with other ranks in engine_parallel group. + # FSDP all-gather of the next state_dict cannot cover this case, so without + # this barrier some ranks could overwrite the IPC tensor before LMDeploy loads it. + dist.barrier(group=cpu_group) + + +class IPCWeightTransport(WeightTransport): + _adapter: IPCBackendAdapter + + def __init__( + self, + *, + rank: int, + logger: Any, + config: Any, + rollout_info: RolloutWeightUpdateInfo, + ): + super().__init__(rank=rank, logger=logger, rollout_info=rollout_info) + self.config = config + self._adapter = self._build_adapter() + + assert self.rollout_info.rollout_device_mesh is not None + self.rollout_device_mesh = self.rollout_info.rollout_device_mesh + self.cpu_mesh = self.rollout_info.rollout_device_mesh["engine_parallel"] + self.cpu_group = self.cpu_mesh.get_group() + self.head_rank = int(self.cpu_mesh.mesh[0].item()) + + def _build_adapter(self) -> IPCBackendAdapter: + if self.backend == "vllm": + return VLLMIPCBackendAdapter(rollout_tp=self.rollout_info.tp) + elif self.backend == "sglang": + return SGLangIPCBackendAdapter(rollout_tp=self.rollout_info.tp) + elif self.backend == "pytorch" or self.backend == "turbomind": + return LMDeployIPCBackendAdapter( + rollout_tp=self.rollout_info.tp, + backend=self.backend, + default_ipc_tensor_bytes=int(self.config.update_weight_bucket_size_in_gb * 1024**3), + ) + else: + raise ValueError( + f"Unsupported IPC weight update backend: {self.backend!r}. Expected 'vllm', 'sglang', 'pytorch' or 'turbomind'." + ) + + def after_update_all_groups(self) -> None: + self._adapter.after_update_all_groups() + DEVICE_MODULE.empty_cache() + + def after_update_per_group(self) -> None: + dist.barrier() + + def send(self, batch: WeightUpdateBatch) -> None: + if self.rollout_url is None: + self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") + return + + DEVICE_MODULE.empty_cache() + try: + serialized_data = self._adapter.serialize( + batch, + self.cpu_group, + self.head_rank, + ) + if dist.get_rank() == self.head_rank: + request = self._adapter.build_request(batch, serialized_data) + self.post_json( + self.rollout_url, + request.endpoint, + request.body, + api_key=self.rollout_info.api_key, + ) + + self._adapter.after_update_per_batch(batch.finished, self.cpu_group, batch.train_enable_ep) + + finally: + monkey_unpatch_torch_reductions() + + +class NCCLBackendAdapter: + def __init__(self): + pass + + def build_weight_update_payload(self, batch: WeightUpdateBatch, group_name: str): + pass + + def build_request( + self, + payload: dict[str, Any], + ) -> WeightUpdateRequest: + raise NotImplementedError + + def before_update(self) -> None: + return + + def after_update_all_groups(self) -> None: + return + + +class SGLangNCCLBackendAdapter(NCCLBackendAdapter): + def __init__(self): + super().__init__() + + def build_weight_update_payload(self, batch: WeightUpdateBatch, group_name: str): + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + except Exception as e: + raise RuntimeError( + "Disaggregated update_weights currently only supports sglang builds " + "that provide `sglang.srt.model_executor.model_runner.FlattenedTensorBucket`." + ) from e + + state_dict = batch.state_dict + finished = batch.finished + if not finished: + weight_names = list(state_dict.keys()) + weight_tensors = [ + tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values() + ] + payload = { + "names": weight_names, + "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in weight_tensors], + "shapes": [list(tensor.shape) for tensor in weight_tensors], + "group_name": group_name, + "load_format": "flattened_bucket", + } + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(weight_names, weight_tensors))) + flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() + + return payload, flattened_tensor, weight_names + else: + return None, None, None + + def build_request( + self, + payload: dict[str, Any], + ) -> WeightUpdateRequest: + return WeightUpdateRequest(endpoint="update_weights_from_distributed", body=payload) + + +class LMDeployNCCLBackendAdapter(NCCLBackendAdapter): + def __init__(self): + super().__init__() + + def build_weight_update_payload(self, batch: WeightUpdateBatch, group_name: str): + try: + from lmdeploy.utils import FlattenedTensorBucket + except Exception as e: + raise RuntimeError( + "Disaggregated update_weights for lmdeploy backend requires lmdeploy builds that provide " + "`lmdeploy.utils.FlattenedTensorBucket`." + ) from e + + state_dict = batch.state_dict + finished = batch.finished + # Pytorch backend will send empty state_dict when finished. + if not finished: + weight_names = list(state_dict.keys()) + weight_tensors = [ + tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values() + ] + payload = { + "names": weight_names, + "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in weight_tensors], + "shapes": [list(tensor.shape) for tensor in weight_tensors], + "group_name": group_name, + "load_format": "flattened_bucket", + } + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(weight_names, weight_tensors))) + flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() + + return payload, flattened_tensor, weight_names + else: + # finalize-only request: no tensors to broadcast, just trigger the + # rollout side's mod.update_weights() finalization hooks. + payload = { + "names": [], + "dtypes": [], + "shapes": [], + "group_name": group_name, + "load_format": "flattened_bucket", + "finished": True, + } + return payload, None, None + + def build_request( + self, + payload: dict[str, Any], + ) -> WeightUpdateRequest: + return WeightUpdateRequest(endpoint="update_weights_from_distributed", body=payload) + + +class NCCLWeightTransport(WeightTransport): + _adapter: NCCLBackendAdapter + + def __init__(self, *, rank: int, logger: Any, rollout_info: RolloutWeightUpdateInfo): + super().__init__(rank=rank, logger=logger, rollout_info=rollout_info) + self.group: dist.ProcessGroup | None = None + self.group_name: str | None = None + self.executor: ThreadPoolExecutor | None = None + self.train_update_sync_group: dist.ProcessGroup | None = None + self.hook_compare_test_sent_and_received_weight_hash: Callable[..., None] = lambda result, **kwargs: None + + self.engine_urls: list[str] = [] + self.external_group_world_size: int | None = None + + self._adapter = self._build_adapter() + + def _build_adapter(self) -> NCCLBackendAdapter: + if self.backend == "sglang": + return SGLangNCCLBackendAdapter() + elif self.backend == "pytorch": + return LMDeployNCCLBackendAdapter() + raise ValueError(f"Unsupported NCCL weight update backend: {self.backend!r}") + + def get_train_update_sync_group(self) -> dist.ProcessGroup: + # Create a Gloo process group for synchronization during NCCL weight update. + if self.train_update_sync_group is None: + ranks = list(range(dist.get_world_size())) + self.train_update_sync_group = dist.new_group(ranks=ranks, backend="gloo") + return self.train_update_sync_group + + def get_weight_update_address(self) -> tuple[str, int]: + # NCCL 会建立通信组 [train 0 + all rollout rank] 来进行broadcast,这里需要获得可用ip和port + host = self.rollout_info.weight_update_host + if not host: + try: + import ray + + host = ray.util.get_node_ip_address() + except Exception: + host = socket.gethostbyname(socket.gethostname()) + + port = self.rollout_info.weight_update_port + + return cast(str, host), cast(int, port) + + def ensure_nccl_weight_update_group(self): + """Create the NCCL weight update group if it has not been + initialized.""" + + if self.group is not None: + return + + # Map rollout rank to its engine size. + rank_to_engine_size = { + int(rank): len(engine_ranks) + for engine_ranks in self.rollout_info.rollout_engine_rank_mesh_array + for rank in engine_ranks + } + + # Deduplicate rollout engine URLs while keeping the first rank associated + # with each URL as the representative rank for that engine. + url_to_rank: dict[str, int] = {} + for rank, url in sorted( + self.rollout_info.rollout_server_url_dict.items(), + key=lambda item: int(item[0]), + ): + if url: + url_to_rank.setdefault(url, int(rank)) + + # Collect the representative rank, URL, and engine size needed to create + # the NCCL weight update process group. + engine_info = [ + ( + rank, + url, + rank_to_engine_size.get( + rank, + max(self.rollout_info.tp, self.rollout_info.ep), + ), + ) + for url, rank in url_to_rank.items() + ] + + if not engine_info: + self.logger.error("No active rollout engine url, cannot init sglang weight update group") + return + + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" + backend = "nccl" + address, port = self.get_weight_update_address() + + group_name = f"xtuner_NCCL_weight_update_{self.rank}" + # Train rank 0 is external group rank 0. Rollout engine ranks are assigned + # contiguous offsets starting from rank 1. + world_size = sum(engine_size for _, _, engine_size in engine_info) + 1 + + self.external_group_world_size = world_size + + self.executor = ThreadPoolExecutor(max_workers=max(1, len(engine_info))) + init_futures = [] + rank_offset = 1 + + for _, url, engine_size in engine_info: + payload = { + "master_address": address, + "master_port": port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + } + init_futures.append( + self.executor.submit( + self.post_json, + url, + "init_weights_update_group", + payload, + api_key=self.rollout_info.api_key, + ) + ) + rank_offset += engine_size + + self.group = self._init_external_process_group( + backend=backend, + init_method=f"tcp://{address}:{port}", + world_size=world_size, + rank=0, + group_name=group_name, + ) + + for init_future in init_futures: + result = init_future.result() + assert result.get("success", True), f"init_weights_update_group failed: {result.get('message', result)}" + + self.group_name = group_name + self.engine_urls = [url for _, url, _ in engine_info] + + def send(self, batch: WeightUpdateBatch) -> None: + state_dict = batch.state_dict + if not state_dict: + return + + train_sync_group = self.get_train_update_sync_group() + head_rank = 0 + # Only train rank 0 drives the disaggregated NCCL update. Other train + # ranks wait here so training and rollout steps remain aligned. + if dist.get_rank() != head_rank: + dist.barrier(group=train_sync_group) + return + + self.ensure_nccl_weight_update_group() + if self.group is None: + # If the NCCL weight update group could not be initialized, release the + # other training ranks waiting at the sync barrier and skip this update. + dist.barrier(group=train_sync_group) + return + + assert self.executor is not None + assert self.group_name is not None + payload, flattened_tensor, weight_names = self._adapter.build_weight_update_payload(batch, self.group_name) + if payload is not None: + request = self._adapter.build_request(payload) + # Notify rollout engines first so they can join the external NCCL group and + # prepare receive buffers described by names/dtypes/shapes. + update_futures = [ + self.executor.submit( + self.post_json, + url, + request.endpoint, + request.body, + api_key=self.rollout_info.api_key, + ) + for url in self.engine_urls + ] + if flattened_tensor is not None: + # LMDeploy send empty payload finally. + # Send the flattened weight tensor through the external NCCL group. + dist.broadcast(flattened_tensor, src=0, group=self.group) + DEVICE_MODULE.synchronize() + # Wait for rollout engines to finish loading weights and validate + # backend-specific update results. + for update_future in update_futures: + result = update_future.result() + self.hook_compare_test_sent_and_received_weight_hash( + result, + names=weight_names, + ) + assert result.get("success", True), ( + f"update_weights_from_distributed failed: {result.get('message', result)}" + ) + dist.barrier(group=train_sync_group) + + @staticmethod + def _init_external_process_group( + backend: str | Backend | None = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str | None = None, + pg_options: Any | None = None, + ) -> dist.ProcessGroup: + # Build a process group that includes external rollout processes, which + # cannot be represented by dist.new_group over the current training world. + assert (store is None) or (init_method is None), "Cannot specify both store and init_method." + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + backend = Backend(backend) if backend else Backend("undefined") + if timeout is None: + timeout = default_pg_timeout + + if store is None: + assert init_method is not None + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + if group_name is not None: + store = PrefixStore(group_name, store) + + pg_options_param_name = ( + "backend_options" if parse_version(torch.__version__) >= parse_version("2.6") else "pg_options" + ) + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + return pg + + def after_update_per_group(self) -> None: + dist.barrier(group=self.get_train_update_sync_group()) + + def teardown(self) -> None: + # Reset only resources that depend on rollout metadata. The train-side sync group + # is independent of rollout workers and should live until worker teardown. + if self.group is not None: + try: + dist.destroy_process_group(self.group) + except Exception as e: + self.logger.warning(f"Failed to destroy NCCL weight update group: {e}") + self.group = None + + if self.executor is not None: + self.executor.shutdown(wait=False, cancel_futures=True) + self.executor = None + + self.group_name = None + self.engine_urls = [] + self.external_group_world_size = None diff --git a/xtuner/v1/rl/weight_update/update_weighter.py b/xtuner/v1/rl/weight_update/update_weighter.py new file mode 100644 index 0000000000..7adb7ba8da --- /dev/null +++ b/xtuner/v1/rl/weight_update/update_weighter.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import os +from typing import Any, cast + +from torch.distributed.device_mesh import DeviceMesh + +from xtuner.v1.rl.rollout.worker import RolloutConfig + +from .data import ( + DeviceMeshRaw, + RolloutBackend, + RolloutWeightUpdateInfo, + ServiceUrlMap, + TrainRolloutMode, +) +from .transport import IPCWeightTransport, NCCLWeightTransport, WeightTransport +from .weight_iterator import WeightIterator + + +class UpdateWeighter: + def __init__(self, *, rank: int, logger: Any, config: Any, engine: Any): + self.rank = rank + self.logger = logger + self.config = config + self._engine = engine + # Used to update weight to rollout engine. + self.rollout_info = RolloutWeightUpdateInfo() + self._global_hf_keys_mapping_cache: dict[str, list[str]] = {} + # Transport is initialized after update_rollout_info() is called. + self._transport: WeightTransport | None = None + # Used to detect changes in rollout metadata that require resetting the transport. + self._transport_signature: tuple[Any, ...] | None = None + + @staticmethod + def _normalize_rollout_backend(rollout_config: RolloutConfig) -> RolloutBackend: + # Backend selection follows rollout launcher precedence: explicit SGLang/vLLM env vars win, + # otherwise the LMDeploy backend decides between pytorch and turbomind. + if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": + backend = "sglang" + elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": + backend = "vllm" + else: + backend = (rollout_config.extra_rollout_config or dict()).get("lmdeploy_backend", "pytorch") + + backend = backend.lower() + if backend not in ("sglang", "vllm", "pytorch", "turbomind"): + raise ValueError( + f"Unsupported rollout backend: {backend!r}. Expected 'sglang', 'vllm', 'pytorch' or 'turbomind'." + ) + return cast(RolloutBackend, backend) + + def update_rollout_info( + self, + engine_rank_mesh_array: DeviceMeshRaw, + server_url_dict: ServiceUrlMap, + rollout_config: RolloutConfig, + worker_server_urls_status: dict[str, bool], + train_rollout_mode: TrainRolloutMode, + weight_update_host: str | None = None, + weight_update_port: int | None = None, + worker_session_url_dict: ServiceUrlMap | None = None, + worker_session_urls_status: dict[str, bool] | None = None, + ): + """Update the rollout information for the training worker.""" + + self.rollout_info.backend = self._normalize_rollout_backend(rollout_config) + self.set_train_rollout_mode(train_rollout_mode=train_rollout_mode) + + # Common rollout metadata. + tp = rollout_config.tensor_parallel_size + ep = rollout_config.expert_parallel_size + assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." + self.rollout_info.tp = tp + self.rollout_info.ep = ep + self.rollout_info.api_key = rollout_config.api_key + rollout_server_url = server_url_dict.get(self.rank, "") + if not worker_server_urls_status.get(rollout_server_url, False): + self.logger.error(f"Rollout server url {rollout_server_url} is not available.") + self.rollout_info.rollout_url = None + else: + self.rollout_info.rollout_url = rollout_server_url + + if self.rollout_info.transport_type == "ipc": + # Colocated rollout metadata. + # rollout_device_mesh is created after train_rollout_mode is set. + self.rollout_info.rollout_engine_rank_mesh_array = [ + [int(rank) for rank in ranks] for ranks in engine_rank_mesh_array + ] + self._ensure_rollout_device_mesh() + elif self.rollout_info.transport_type == "nccl": + # Disaggregated rollout metadata. + self.rollout_info.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()} + self.rollout_info.worker_server_urls_status = worker_server_urls_status + self.rollout_info.weight_update_host = weight_update_host + self.rollout_info.weight_update_port = weight_update_port if weight_update_port is not None else 30000 + + new_transport_signature = self._build_transport_signature( + engine_rank_mesh_array=engine_rank_mesh_array, + server_url_dict=server_url_dict, + worker_server_urls_status=worker_server_urls_status, + train_rollout_mode=train_rollout_mode, + backend=self.rollout_info.backend, + tp=tp, + ep=ep, + ) + # Weight transports may cache resources derived from rollout metadata. + # Since rollout workers can fail and recover with new URL/status/mesh metadata, + # reset the cached transport whenever that metadata changes. + if self._transport_signature is not None and new_transport_signature != self._transport_signature: + self.logger.info("Rollout metadata changed, reset weight transport.") + self._reset_transport() + self._transport_signature = new_transport_signature + + self.weight_iterator = WeightIterator( + config=self.config, + engine=self._engine, + rollout_info=self.rollout_info, + global_hf_keys_mapping_cache=self._global_hf_keys_mapping_cache, + ) + if self._transport is None: + self._set_transport() + + def _ensure_rollout_device_mesh(self): + if self.rollout_info.rollout_device_mesh is None: + # 非共卡 SGLang 不使用这个 mesh;只有共卡/旧权重同步路径需要 + # 用 rollout rank 构造 torch DeviceMesh。 + self.rollout_info.rollout_device_mesh = DeviceMesh( + "cpu", + mesh=self.rollout_info.rollout_engine_rank_mesh_array, + mesh_dim_names=("engine_instance", "engine_parallel"), + ) + + def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode | str): + assert train_rollout_mode is not None, "update_rollout_info() must set train_rollout_mode." + + if self.rollout_info.backend is None: + raise RuntimeError("rollout backend is not set. Please set rollout backend in update_rollout_info().") + + mode = train_rollout_mode.lower() + if mode not in ("colocate", "disaggregated"): + raise ValueError( + f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." + ) + mode = cast(TrainRolloutMode, mode) + self.rollout_info.train_rollout_mode = mode + if mode == "colocate": + self.rollout_info.transport_type = "ipc" + elif mode == "disaggregated": + self.rollout_info.transport_type = "nccl" + + backend = self.rollout_info.backend + if backend == "vllm" or backend == "turbomind": + raise NotImplementedError(f"Disaggregated train-rollout mode is not supported for {backend} backend.") + + def update_weights(self): + """Update the model weights.""" + + assert self._transport is not None, ( + f"Weight transport is not initialized. transport_type={self.rollout_info.transport_type!r}, " + f"backend={self.rollout_info.backend!r}." + ) + self._transport.update(self.weight_iterator) + + def _set_transport(self) -> None: + if self.rollout_info.transport_type == "ipc": + self._transport = IPCWeightTransport( + rank=self.rank, + logger=self.logger, + config=self.config, + rollout_info=self.rollout_info, + ) + elif self.rollout_info.transport_type == "nccl": + self._transport = NCCLWeightTransport(rank=self.rank, logger=self.logger, rollout_info=self.rollout_info) + else: + raise NotImplementedError + + def _build_transport_signature( + self, + *, + engine_rank_mesh_array: DeviceMeshRaw, + server_url_dict: ServiceUrlMap, + worker_server_urls_status: dict[str, bool], + train_rollout_mode: TrainRolloutMode, + backend: RolloutBackend, + tp: int, + ep: int, + ) -> tuple[Any, ...]: + mesh = tuple(tuple(int(rank) for rank in ranks) for ranks in engine_rank_mesh_array) + + active_urls = tuple( + sorted( + (int(rank), url) + for rank, url in server_url_dict.items() + if url and worker_server_urls_status.get(url, False) + ) + ) + + return ( + train_rollout_mode, + backend, + tp, + ep, + mesh, + active_urls, + ) + + def _reset_transport(self) -> None: + if self._transport is not None: + self._transport.teardown() + self._transport = None diff --git a/xtuner/v1/rl/weight_update/weight_iterator.py b/xtuner/v1/rl/weight_update/weight_iterator.py new file mode 100644 index 0000000000..9e8c5b783d --- /dev/null +++ b/xtuner/v1/rl/weight_update/weight_iterator.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +from itertools import chain +from typing import Any, cast + +import torch +import torch.distributed as dist +import tqdm +from torch.distributed.tensor import DTensor + +from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration +from xtuner.v1.model.moe.moe import MoE +from xtuner.v1.utils import get_device, get_torch_device_module +from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec + +from .data import RolloutWeightUpdateInfo, WeightUpdateBatch + + +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() + + +class WeightIterator: + def __init__( + self, + *, + config: Any, + engine: Any, + rollout_info: RolloutWeightUpdateInfo, + global_hf_keys_mapping_cache: dict[str, list[str]], + ): + self.config = config + self._engine = engine + self.rollout_info = rollout_info + self._global_hf_keys_mapping_cache = global_hf_keys_mapping_cache + + def iter_batch_groups(self): + # Export path depends on rollout protocol: turbomind consumes layer-wise batches, + # compose models update submodules in order, and plain models use HF-style batches. + if self.rollout_info.train_rollout_mode == "colocate" and self.rollout_info.backend == "turbomind": + yield self.iter_layer_batches() + return + + if isinstance(self.config.model_cfg, BaseComposeConfig): + # Only the last compose submodule sends the final update marker. + submodules = ( + ("language_model", False), + ("vision_tower", False), + ("multi_modal_projector", True), + ) + for submodule, final_update in submodules: + yield self.iter_hf_batches(submodule=submodule, final_update=final_update) + return + + yield self.iter_hf_batches(final_update=True) + + def _get_hf_params( + self, + model, + model_ep_size: int, + target_ep_size: int, + target_ep_rank: int, + fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]], + should_gather_train_ep_shards: bool, + ) -> tuple[list[torch.Tensor], list[str]]: + hf_keys_list: list[str] = [] + hf_tensor_list: list[torch.Tensor] = [] + + for fsdp_tensor, load_spec in fsdp_tensor_list: + hf_keys = load_spec.hf_keys + if model_ep_size > 1 and model.ep_mesh is not None: + # Each train EP rank owns only part of the HF key list; gather the global + # mapping once so rollout EP ranks can receive the right slice. + if load_spec.name not in self._global_hf_keys_mapping_cache: + global_hf_keys: list[list[str] | None] = [None] * model_ep_size + dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group()) + global_hf_keys_gathered = cast(list[list[str]], global_hf_keys) + self._global_hf_keys_mapping_cache[load_spec.name] = list( + chain.from_iterable(global_hf_keys_gathered) + ) + hf_keys = self._global_hf_keys_mapping_cache[load_spec.name] + + fused_full_tensor = fsdp_tensor.bfloat16() + if isinstance(fused_full_tensor, DTensor): + fused_full_tensor = fused_full_tensor.full_tensor() + # FUSED load specs pack multiple HF tensors along load_spec.dim; split them + # back into HF tensors before selecting the target rollout EP shard. + dim = cast(int, load_spec.dim) + + if should_gather_train_ep_shards and model_ep_size > 1: + assert model.ep_mesh is not None + ep_group = model.ep_mesh.get_group() + + output = torch.empty( + *fused_full_tensor.shape[:dim], + fused_full_tensor.shape[dim] * model_ep_size, + *fused_full_tensor.shape[dim + 1 :], + dtype=fused_full_tensor.dtype, + device=fused_full_tensor.device, + ) + dist.all_gather_into_tensor(output, fused_full_tensor.contiguous(), group=ep_group) + fused_full_tensor = output + + num_split = len(hf_keys) + hf_tensor_size = fused_full_tensor.shape[dim] / num_split + assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer" + hf_tensor_size = int(hf_tensor_size) + + hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim) + assert num_split % target_ep_size == 0, ( + f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}" + ) + start_idx = (num_split // target_ep_size) * target_ep_rank + end_idx = (num_split // target_ep_size) * (target_ep_rank + 1) + + hf_keys_list.extend(hf_keys[start_idx:end_idx]) + hf_tensor_list.extend(hf_tensor[start_idx:end_idx]) + + hf_tensor_list = [ + model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list) + ] + + return hf_tensor_list, hf_keys_list + + def _rl_get_fused_ep_hf_param( + self, + model: MoE, + target_ep_rank: int, + target_ep_size: int, + bucket_size: int, + should_gather_train_ep_shards: bool, + ): + fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED) + model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size + if not fused_param_groups: + return + + safetensor_size = 0 + dtype = torch.bfloat16 + tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] + + for param, load_spec in fused_param_groups: + tensor_size = dtype.itemsize * param.numel() // target_ep_size + if safetensor_size + tensor_size > bucket_size and tensor_list: + hf_params, name_list = self._get_hf_params( + model, + model_ep_size=model_ep_size, + target_ep_size=target_ep_size, + target_ep_rank=target_ep_rank, + fsdp_tensor_list=tensor_list, + should_gather_train_ep_shards=should_gather_train_ep_shards, + ) + yield name_list, hf_params + safetensor_size = tensor_size + # Kept to mirror the legacy generator layout; the next iteration rebuilds + # name_list from tensor_list before yielding. + name_list = load_spec.hf_keys.copy() + tensor_list = [(param, load_spec)] + continue + safetensor_size += tensor_size + tensor_list.append((param, load_spec)) + + if tensor_list: + hf_params, name_list = self._get_hf_params( + model=model, + model_ep_size=model_ep_size, + target_ep_size=target_ep_size, + target_ep_rank=target_ep_rank, + fsdp_tensor_list=tensor_list, + should_gather_train_ep_shards=should_gather_train_ep_shards, + ) + yield name_list, hf_params + + @torch.no_grad() + def iter_hf_batches(self, submodule=None, final_update=False): + """Update the model weights.""" + + model = self._engine.model + if submodule: + model = getattr(model, submodule) + + dtype = torch.bfloat16 + bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) + same_gen = model._get_same_hf_param( + model._group_param_by_load_spec(LoadEnum.SAME), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + ) + + train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 + should_gather_train_ep_shards = self.rollout_info.train_rollout_mode == "disaggregated" and train_enable_ep + + if train_enable_ep: + if self.rollout_info.train_rollout_mode == "colocate" and self.rollout_info.ep > 1: + rollout_device_mesh = self.rollout_info.rollout_device_mesh + assert rollout_device_mesh is not None + # Colocated IPC can send only the expert slice needed by the local rollout + # EP rank + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=rollout_device_mesh["engine_parallel"].get_coordinate()[0], + target_ep_size=rollout_device_mesh["engine_parallel"].size(), + bucket_size=bucket_size, + should_gather_train_ep_shards=should_gather_train_ep_shards, + ) + else: + # Disaggregated NCCL uses one trainer-side broadcast for all rollout ranks. + # Gather train EP shards first, then send the full expert tensor instead of + # slicing by rollout EP rank. + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=0, + target_ep_size=1, + bucket_size=bucket_size, + should_gather_train_ep_shards=should_gather_train_ep_shards, + ) + else: + fused_gen = model._get_fused_hf_param( + model._group_param_by_load_spec(LoadEnum.FUSED), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + update_weights_for_rl=True, + ) + shard_gen = model._get_shard_hf_param( + model._group_param_by_load_spec(LoadEnum.SHARD), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + ) + + for name_list, fused_param_list in fused_gen: + state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} + yield WeightUpdateBatch(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, fused_param_list + + for name_list, param_list in chain(same_gen, shard_gen): + state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} + yield WeightUpdateBatch(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, param_list + + # pytorch and vLLM use an empty final update as an end marker; SGLang and + # turbomind do not consume this marker. + if self.rollout_info.backend in ("pytorch", "vllm") and final_update: + yield WeightUpdateBatch({}, train_enable_ep=train_enable_ep, finished=True) + + DEVICE_MODULE.empty_cache() + + @torch.no_grad() + def iter_layer_batches(self): + """Update the model weights.""" + assert self.rollout_info.rollout_device_mesh is not None + + model = self._engine.model + DEVICE_MODULE.empty_cache() + + if isinstance(model.config, BaseComposeConfig): + # TODO: support float8 for vision compose model. + dtype = torch.bfloat16 + else: + if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): + dtype = torch.float8_e4m3fn + else: + dtype = torch.bfloat16 + + def get_params(tensor_list, name_list, save_dtype): + _tensor_list, _spec_list = list(zip(*tensor_list)) + fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) + if save_dtype == torch.float8_e4m3fn: + fsdp_unshard_tensor_list, name_list = model._to_float8( + fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype + ) + return fsdp_unshard_tensor_list, name_list + + saved_list = [] + is_qwen3vl = False + if isinstance(model.config, BaseComposeConfig): + language_model = model.language_model + if isinstance(model, Qwen3VLForConditionalGeneration): + is_qwen3vl = True + else: + language_model = model + + if is_qwen3vl: + vision_hf_prefix = "model.visual." + projector_hf_prefix = "model.visual." + else: + vision_hf_prefix = "model.vision_tower." + projector_hf_prefix = "model.multi_modal_projector." + + for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): + tensor_list = [] + name_list = [] + for sub_name, param in layer.state_dict().items(): + if isinstance(model.config, BaseComposeConfig): + saved_list.append(f"language_model.layers.{i}.{sub_name}") + else: + saved_list.append(f"layers.{i}.{sub_name}") + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") + + if isinstance(model.config, BaseComposeConfig): + name = f"model.language_model.layers.{i}.{sub_name}" + else: + name = f"model.layers.{i}.{sub_name}" + + if ".experts." in name and ".mlp.experts." not in name: + name = name.replace(".experts.", ".mlp.experts.") + if ".gate." in name and ".mlp.gate." not in name: + name = name.replace(".gate.", ".mlp.gate.") + name_list.append(name) + tensor_list.append((local_tensor, load_spec)) + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + yield WeightUpdateBatch(state_dict) + + for name, param in model.state_dict().items(): + if name in saved_list: + continue + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = model.load_spec_mapping.get(name) + + if isinstance(model.config, BaseComposeConfig): + if "vision_tower." in name: + name = name.replace("vision_tower.", vision_hf_prefix) + elif "multi_modal_projector." in name: + name = name.replace("multi_modal_projector.", projector_hf_prefix) + elif name == "language_model.norm.weight": + name = "model.language_model.norm.weight" + elif name == "language_model.embed_tokens.weight": + name = "model.language_model.embed_tokens.weight" + elif name == "language_model.lm_head.weight": + name = "lm_head.weight" + else: + if name == "norm.weight": + name = "model.norm.weight" + elif name == "embed_tokens.weight": + name = "model.embed_tokens.weight" + tensor_list = [(local_tensor, load_spec)] + name_list = [name] + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + yield WeightUpdateBatch(state_dict) + + if self.rollout_info.backend in ("pytorch", "vllm"): + yield WeightUpdateBatch({}, finished=True) + + DEVICE_MODULE.empty_cache() diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index da1b9eaf87..a0192eaaf5 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -52,6 +52,7 @@ set_cpu_resource_manager, sort_rollout_state_for_deterministic, ) +from xtuner.v1.rl.weight_update.data import TrainRolloutMode from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, set_deterministic, timer from xtuner.v1.utils.device import get_device, get_torch_device_module @@ -108,13 +109,21 @@ def check_fa3(): def bind_train_rollout( train_controller: TrainingController, rollout_controller: RolloutControllerProxy, + train_rollout_mode: TrainRolloutMode | str, + weight_update_host: str | None = None, + weight_update_port: int | None = None, ) -> None: """Bind the training and rollout workers for update weights.""" info_dict = ray.get( rollout_controller.get_rollout_metadata.remote(), # type: ignore[attr-defined] timeout=RL_TRAINER_RAY_GET_TIMEOUT, ) - train_controller.update_rollout_info(info_dict) + train_controller.update_rollout_info( + info_dict, + train_rollout_mode=train_rollout_mode, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, + ) return @@ -1549,14 +1558,17 @@ def __init__(self, cfg: RLColocateTrainerConfig): self.train_controller.offload(target="all") self.rollout_controller = self._rollout_config.build(self._pg) - bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) + bind_train_rollout( + train_controller=self.train_controller, + rollout_controller=self.rollout_controller, + train_rollout_mode="colocate", + ) replay_buffer = cfg.replay_buffer_config.build() self._build_agent_loop_components(cfg, replay_buffer) if checkpoint_path is not None: asyncio_run(self._resume_agent_loop_manager(checkpoint_path)) - self.train_controller.set_train_rollout_mode("colocate") self._cpu_resource_manager.log_registered_summary() if self._rollout_config.skip_load_weights: @@ -1702,6 +1714,7 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, + train_rollout_mode="colocate", ) ray.get( self.rollout_controller.onload_weights.remote(), @@ -1745,8 +1758,13 @@ def __init__(self, cfg: RLDisaggregatedTrainerConfig): "In disaggregated mode, should_continue_fn must be default, " "because it does not allow early stopping in production." ) - bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) - self.train_controller.set_train_rollout_mode("disaggregated") + bind_train_rollout( + train_controller=self.train_controller, + rollout_controller=self.rollout_controller, + train_rollout_mode="disaggregated", + weight_update_host=self._rollout_config.weight_update_host, + weight_update_port=self._rollout_config.weight_update_port, + ) if self._load_checkpoint_cfg.checkpoint_path is not None: self._resume_from_checkpoint(self._load_checkpoint_cfg.checkpoint_path) @@ -1931,7 +1949,11 @@ async def _sync_weights_and_save(self, model_step: int, step_timer_dict: dict): # TODO: 非共卡需要额外加健康检查恢复worker的逻辑,共卡是在训练之前恢复,但是非共卡不需要在训练之前恢复,挂掉就恢复或者更新权重前恢复,需要评估一下哪种方式更合理。 with timer("sync_weight", step_timer_dict): - bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) + bind_train_rollout( + train_controller=self.train_controller, + rollout_controller=self.rollout_controller, + train_rollout_mode="disaggregated", + ) self.update_weights() def update_weights(self):