@@ -325,7 +325,7 @@ def __init__(
325
325
326
326
self .step_fn = (self .step if self .batch_queue is None else
327
327
self .step_with_batch_queue )
328
- self .global_unfinished_reqs = False
328
+ self .engines_running = False
329
329
330
330
# Background Threads and Queues for IO. These enable us to
331
331
# overlap ZMQ socket IO with GPU since they release the GIL,
@@ -410,19 +410,15 @@ def _process_input_queue(self):
410
410
"""Exits when an engine step needs to be performed."""
411
411
412
412
waited = False
413
- while not self .global_unfinished_reqs and not (
414
- self .scheduler .has_requests ()):
413
+ while not self .engines_running and not (self .scheduler .has_requests ()):
415
414
if logger .isEnabledFor (DEBUG ) and self .input_queue .empty ():
416
415
logger .debug ("EngineCore waiting for work." )
417
416
waited = True
418
417
req = self .input_queue .get ()
419
418
self ._handle_client_request (* req )
420
419
421
420
if waited :
422
- logger .debug (
423
- "EngineCore loop active - local unfinished: %s, finished: %s." ,
424
- self .scheduler .has_unfinished_requests (),
425
- self .scheduler .has_finished_requests ())
421
+ logger .debug ("EngineCore loop active." )
426
422
427
423
# Handle any more client requests.
428
424
while not self .input_queue .empty ():
@@ -446,10 +442,6 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
446
442
self .add_request (request )
447
443
elif request_type == EngineCoreRequestType .ABORT :
448
444
self .abort_requests (request )
449
- elif request_type == EngineCoreRequestType .START_DP :
450
- if not self .global_unfinished_reqs :
451
- logger .debug ("EngineCore starting idle loop." )
452
- self .global_unfinished_reqs = True
453
445
elif request_type == EngineCoreRequestType .UTILITY :
454
446
call_id , method_name , args = request
455
447
output = UtilityOutput (call_id )
@@ -548,9 +540,6 @@ def process_output_socket(self, output_path: str, engine_index: int):
548
540
socket .send_multipart (buffers , copy = False )
549
541
550
542
551
- ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs (engine_paused = True )
552
-
553
-
554
543
class DPEngineCoreProc (EngineCoreProc ):
555
544
"""ZMQ-wrapper for running EngineCore in background process
556
545
in a data parallel context."""
@@ -587,7 +576,9 @@ def __init__(
587
576
for i in range (local_dp_rank * tp_size , (local_dp_rank + 1 ) *
588
577
tp_size ))
589
578
579
+ self .local_dp_rank = local_dp_rank
590
580
self .dp_group = vllm_config .parallel_config .stateless_init_dp_group ()
581
+ self .current_wave = 0
591
582
592
583
# Initialize the engine after setting up environment.
593
584
super ().__init__ (input_path , output_path , vllm_config , executor_class ,
@@ -602,6 +593,31 @@ def shutdown(self):
602
593
if dp_group := getattr (self , "dp_group" , None ):
603
594
stateless_destroy_torch_distributed_process_group (dp_group )
604
595
596
+ def add_request (self , request : EngineCoreRequest ):
597
+ if request .current_wave != self .current_wave :
598
+ if request .current_wave > self .current_wave :
599
+ self .current_wave = request .current_wave
600
+ elif not self .engines_running :
601
+ # Request received for an already-completed wave, notify
602
+ # front-end that we need to start the next one.
603
+ self .output_queue .put_nowait (
604
+ EngineCoreOutputs (start_wave = self .current_wave ))
605
+
606
+ super ().add_request (request )
607
+
608
+ def _handle_client_request (self , request_type : EngineCoreRequestType ,
609
+ request : Any ) -> None :
610
+ if request_type == EngineCoreRequestType .START_DP_WAVE :
611
+ new_wave : int = request
612
+ if new_wave >= self .current_wave :
613
+ self .current_wave = new_wave
614
+ if not self .engines_running :
615
+ logger .debug ("EngineCore starting idle loop for wave %d." ,
616
+ new_wave )
617
+ self .engines_running = True
618
+ else :
619
+ super ()._handle_client_request (request_type , request )
620
+
605
621
def run_busy_loop (self ):
606
622
"""Core busy loop of the EngineCore for data parallel case."""
607
623
@@ -628,7 +644,7 @@ def run_busy_loop(self):
628
644
# up-to-date state is returned in the engine outputs.
629
645
self ._process_engine_step ()
630
646
631
- if not self .global_unfinished_reqs :
647
+ if not self .engines_running :
632
648
# All engines are idle.
633
649
continue
634
650
@@ -637,18 +653,23 @@ def run_busy_loop(self):
637
653
self .execute_dummy_batch ()
638
654
639
655
# 3) All-reduce operation to determine global unfinished reqs.
640
- self .global_unfinished_reqs = self ._has_global_unfinished_reqs (
656
+ self .engines_running = self ._has_global_unfinished_reqs (
641
657
local_unfinished_reqs )
642
658
643
- if not self .global_unfinished_reqs :
644
- # Notify client that we are pausing the loop.
645
- self .output_queue .put_nowait (ENGINE_PAUSED_OUTPUTS )
659
+ if not self .engines_running :
660
+ if self .local_dp_rank == 0 :
661
+ # Notify client that we are pausing the loop.
662
+ logger .debug ("Wave %d finished, pausing engine loop." ,
663
+ self .current_wave )
664
+ self .output_queue .put_nowait (
665
+ EngineCoreOutputs (wave_complete = self .current_wave ))
666
+ self .current_wave += 1
646
667
647
668
def _has_global_unfinished_reqs (self , local_unfinished : bool ) -> bool :
648
669
649
- # Optimization - only perform finish-sync all-reduce every 16 steps.
670
+ # Optimization - only perform finish-sync all-reduce every 24 steps.
650
671
self .counter += 1
651
- if self .counter != 16 :
672
+ if self .counter != 24 :
652
673
return True
653
674
self .counter = 0
654
675
0 commit comments