Skip to content

Commit d664d05

Browse files
njhill authored and adobrzyn committed
[V1][DP] More robust DP/EP dummy request coordination (vllm-project#16277)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
1 parent d4a8c54 commit d664d05

File tree

4 files changed

+94
-57
lines changed

4 files changed

+94
-57
lines changed

tests/v1/test_async_llm_dp.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
101101
# the engines only synchronize stopping every N steps so
102102
# allow a small amount of time here.
103103
for _ in range(10):
104-
if core_client.num_engines_running == 0:
104+
if not core_client.engines_running:
105105
break
106106
await asyncio.sleep(0.5)
107107

108-
assert core_client.num_engines_running == 0
108+
assert not core_client.engines_running
109109
assert not core_client.reqs_in_flight

vllm/v1/engine/__init__.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ class EngineCoreRequest(
6161
arrival_time: float
6262
lora_request: Optional[LoRARequest]
6363

64+
# Used in DP case to indicate which wave of requests this is expected to
65+
# belong to, to cover a race condition where the request is sent before
66+
# a wave finished notification is received.
67+
current_wave: int = 0
68+
6469

6570
class EngineCoreEventType(enum.IntEnum):
6671
"""The type of engine core request event."""
@@ -139,8 +144,12 @@ class EngineCoreOutputs(
139144
utility_output: Optional[UtilityOutput] = None
140145
finished_requests: Optional[set[str]] = None
141146

142-
# In DP case, used to signal that the engine is paused.
143-
engine_paused: bool = False
147+
# In DP case, used to signal that the current wave of requests
148+
# has finished and the engines are paused.
149+
wave_complete: Optional[int] = None
150+
# In DP case, used to signal that a request was received for an
151+
# "old" wave, so the next wave needs to be started in other engines.
152+
start_wave: Optional[int] = None
144153

145154
def __post_init__(self):
146155
if self.timestamp == 0.0:
@@ -154,7 +163,7 @@ class EngineCoreRequestType(enum.Enum):
154163
"""
155164
ADD = b'\x00'
156165
ABORT = b'\x01'
157-
START_DP = b'\x02'
166+
START_DP_WAVE = b'\x02'
158167
UTILITY = b'\x03'
159168
# Sentinel used within EngineCoreProc.
160169
EXECUTOR_FAILED = b'\x04'

vllm/v1/engine/core.py

+42-21
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def __init__(
325325

326326
self.step_fn = (self.step if self.batch_queue is None else
327327
self.step_with_batch_queue)
328-
self.global_unfinished_reqs = False
328+
self.engines_running = False
329329

330330
# Background Threads and Queues for IO. These enable us to
331331
# overlap ZMQ socket IO with GPU since they release the GIL,
@@ -410,19 +410,15 @@ def _process_input_queue(self):
410410
"""Exits when an engine step needs to be performed."""
411411

412412
waited = False
413-
while not self.global_unfinished_reqs and not (
414-
self.scheduler.has_requests()):
413+
while not self.engines_running and not (self.scheduler.has_requests()):
415414
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
416415
logger.debug("EngineCore waiting for work.")
417416
waited = True
418417
req = self.input_queue.get()
419418
self._handle_client_request(*req)
420419

421420
if waited:
422-
logger.debug(
423-
"EngineCore loop active - local unfinished: %s, finished: %s.",
424-
self.scheduler.has_unfinished_requests(),
425-
self.scheduler.has_finished_requests())
421+
logger.debug("EngineCore loop active.")
426422

427423
# Handle any more client requests.
428424
while not self.input_queue.empty():
@@ -446,10 +442,6 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
446442
self.add_request(request)
447443
elif request_type == EngineCoreRequestType.ABORT:
448444
self.abort_requests(request)
449-
elif request_type == EngineCoreRequestType.START_DP:
450-
if not self.global_unfinished_reqs:
451-
logger.debug("EngineCore starting idle loop.")
452-
self.global_unfinished_reqs = True
453445
elif request_type == EngineCoreRequestType.UTILITY:
454446
call_id, method_name, args = request
455447
output = UtilityOutput(call_id)
@@ -548,9 +540,6 @@ def process_output_socket(self, output_path: str, engine_index: int):
548540
socket.send_multipart(buffers, copy=False)
549541

550542

551-
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
552-
553-
554543
class DPEngineCoreProc(EngineCoreProc):
555544
"""ZMQ-wrapper for running EngineCore in background process
556545
in a data parallel context."""
@@ -587,7 +576,9 @@ def __init__(
587576
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
588577
tp_size))
589578

579+
self.local_dp_rank = local_dp_rank
590580
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
581+
self.current_wave = 0
591582

592583
# Initialize the engine after setting up environment.
593584
super().__init__(input_path, output_path, vllm_config, executor_class,
@@ -602,6 +593,31 @@ def shutdown(self):
602593
if dp_group := getattr(self, "dp_group", None):
603594
stateless_destroy_torch_distributed_process_group(dp_group)
604595

596+
def add_request(self, request: EngineCoreRequest):
597+
if request.current_wave != self.current_wave:
598+
if request.current_wave > self.current_wave:
599+
self.current_wave = request.current_wave
600+
elif not self.engines_running:
601+
# Request received for an already-completed wave, notify
602+
# front-end that we need to start the next one.
603+
self.output_queue.put_nowait(
604+
EngineCoreOutputs(start_wave=self.current_wave))
605+
606+
super().add_request(request)
607+
608+
def _handle_client_request(self, request_type: EngineCoreRequestType,
609+
request: Any) -> None:
610+
if request_type == EngineCoreRequestType.START_DP_WAVE:
611+
new_wave: int = request
612+
if new_wave >= self.current_wave:
613+
self.current_wave = new_wave
614+
if not self.engines_running:
615+
logger.debug("EngineCore starting idle loop for wave %d.",
616+
new_wave)
617+
self.engines_running = True
618+
else:
619+
super()._handle_client_request(request_type, request)
620+
605621
def run_busy_loop(self):
606622
"""Core busy loop of the EngineCore for data parallel case."""
607623

@@ -628,7 +644,7 @@ def run_busy_loop(self):
628644
# up-to-date state is returned in the engine outputs.
629645
self._process_engine_step()
630646

631-
if not self.global_unfinished_reqs:
647+
if not self.engines_running:
632648
# All engines are idle.
633649
continue
634650

@@ -637,18 +653,23 @@ def run_busy_loop(self):
637653
self.execute_dummy_batch()
638654

639655
# 3) All-reduce operation to determine global unfinished reqs.
640-
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
656+
self.engines_running = self._has_global_unfinished_reqs(
641657
local_unfinished_reqs)
642658

643-
if not self.global_unfinished_reqs:
644-
# Notify client that we are pausing the loop.
645-
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
659+
if not self.engines_running:
660+
if self.local_dp_rank == 0:
661+
# Notify client that we are pausing the loop.
662+
logger.debug("Wave %d finished, pausing engine loop.",
663+
self.current_wave)
664+
self.output_queue.put_nowait(
665+
EngineCoreOutputs(wave_complete=self.current_wave))
666+
self.current_wave += 1
646667

647668
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
648669

649-
# Optimization - only perform finish-sync all-reduce every 16 steps.
670+
# Optimization - only perform finish-sync all-reduce every 24 steps.
650671
self.counter += 1
651-
if self.counter != 16:
672+
if self.counter != 24:
652673
return True
653674
self.counter = 0
654675

vllm/v1/engine/core_client.py

+38-31
Original file line numberDiff line numberDiff line change
@@ -792,15 +792,12 @@ class DPAsyncMPClient(AsyncMPClient):
792792
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
793793
log_stats: bool):
794794

795-
self.num_engines_running = 0
795+
self.current_wave = 0
796+
self.engines_running = False
796797
self.reqs_in_flight: dict[str, CoreEngine] = {}
797798

798799
super().__init__(vllm_config, executor_class, log_stats)
799800

800-
# Control message used for triggering dp idle mode loop.
801-
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
802-
*self.encoder.encode(None))
803-
804801
assert len(self.core_engines) > 1
805802

806803
def _init_core_engines(
@@ -829,23 +826,23 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
829826
# NOTE: text prompt is not needed in the core engine as it has been
830827
# tokenized.
831828
request.prompt = None
832-
833-
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
829+
request.current_wave = self.current_wave
834830

835831
chosen_engine = self.get_core_engine_for_request()
836832
self.reqs_in_flight[request.request_id] = chosen_engine
837833
chosen_engine.num_reqs_in_flight += 1
838-
if self.num_engines_running >= len(self.core_engines):
839-
await self._send_input_message(msg, chosen_engine)
840-
else:
834+
835+
to_await = self._send_input(EngineCoreRequestType.ADD, request,
836+
chosen_engine)
837+
if not self.engines_running:
841838
# Send request to chosen engine and dp start loop
842839
# control message to all other engines.
843-
self.num_engines_running += len(self.core_engines)
844-
await asyncio.gather(*[
845-
self._send_input_message(
846-
msg if engine is chosen_engine else self.start_dp_msg,
847-
engine) for engine in self.core_engines
848-
])
840+
self.engines_running = True
841+
to_await = asyncio.gather(
842+
to_await, # type: ignore[assignment]
843+
*self._start_wave_coros(exclude_index=chosen_engine.index))
844+
845+
await to_await
849846

850847
self._ensure_output_queue_task()
851848

@@ -860,21 +857,31 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
860857
if engine := self.reqs_in_flight.pop(req_id, None):
861858
engine.num_reqs_in_flight -= 1
862859

863-
if outputs.engine_paused:
864-
assert self.num_engines_running >= 1
865-
self.num_engines_running -= 1
866-
if not self.num_engines_running and self.reqs_in_flight:
867-
# If there are requests in flight here, they must have
868-
# been sent after the engines paused. We must make
869-
# sure to start the other engines:
870-
self.num_engines_running = len(self.core_engines)
871-
coros = [
872-
self._send_input_message(self.start_dp_msg, engine)
873-
for engine in self.core_engines
874-
if not engine.num_reqs_in_flight
875-
]
876-
if coros:
877-
await asyncio.gather(*coros)
860+
if outputs.wave_complete is not None:
861+
# Current wave is complete, move to next wave number
862+
# and mark engines as paused.
863+
if self.current_wave <= outputs.wave_complete:
864+
self.current_wave = outputs.wave_complete + 1
865+
self.engines_running = False
866+
867+
elif outputs.start_wave is not None and (
868+
outputs.start_wave > self.current_wave or
869+
(outputs.start_wave == self.current_wave
870+
and not self.engines_running)):
871+
# Engine received request for a non-current wave so we must ensure
872+
# that other engines progress to the next wave.
873+
self.current_wave = outputs.start_wave
874+
self.engines_running = True
875+
await asyncio.gather(*self._start_wave_coros(
876+
exclude_index=outputs.engine_index))
877+
878+
def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
879+
logger.debug("Sending start DP wave %d.", self.current_wave)
880+
return [
881+
self._send_input(EngineCoreRequestType.START_DP_WAVE,
882+
self.current_wave, engine)
883+
for engine in self.core_engines if engine.index != exclude_index
884+
]
878885

879886
async def abort_requests_async(self, request_ids: list[str]) -> None:
880887
if not request_ids:

0 commit comments

Comments (0)