
Commit 2610029

njhill authored and wuisawesome committed
[Core] Remove prompt string from engine core data structures (vllm-project#17214)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 38236e9 commit 2610029

21 files changed, with 40 additions and 76 deletions
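
The gist of the change, as a rough before/after sketch (not literal vLLM code; processor, output_processor, and engine_core stand in for the components touched in the diffs below):

# Before (sketch): the text prompt rode on the request object and the core
# client nulled it out just before sending it to the engine-core process.
request = processor.process_inputs(
    request_id, prompt, params, arrival_time, lora_request,
    trace_headers, prompt_adapter_request, priority)
output_processor.add_request(request)        # reads request.prompt
engine_core.add_request(request)             # sets request.prompt = None

# After (sketch): the prompt string is returned alongside the request and
# never crosses into the engine-core process at all.
prompt_str, request = processor.process_inputs(
    request_id, prompt, params, arrival_time, lora_request,
    trace_headers, prompt_adapter_request, priority)
output_processor.add_request(request, prompt_str)
engine_core.add_request(request)             # token IDs only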

tests/tokenization/test_detokenize.py (+2, -2)

@@ -60,8 +60,8 @@ def _run_incremental_decode(tokenizer,
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
-                                params, None, 0.0, None)
+    request = EngineCoreRequest("", prompt_token_ids, None, None, None, params,
+                                None, 0.0, None)

     if fast is None:
         detokenizer = IncrementalDetokenizer.from_new_request(

tests/v1/core/test_kv_cache_utils.py (-1)

@@ -37,7 +37,6 @@ def make_request(request_id,

     return Request(
         request_id=request_id,
-        prompt=None,
         prompt_token_ids=prompt_token_ids,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,

tests/v1/core/test_prefix_caching.py (-1)

@@ -29,7 +29,6 @@ def make_request(request_id,

     return Request(
         request_id=request_id,
-        prompt=None,
         prompt_token_ids=prompt_token_ids,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,

tests/v1/core/test_scheduler.py (-1)

@@ -132,7 +132,6 @@ def create_requests(num_requests: int,
         mm_inputs = None
         request = Request(
             request_id=f"{i}",
-            prompt=None,
             prompt_token_ids=[i] * num_tokens,
             sampling_params=sampling_params,
             multi_modal_inputs=mm_inputs,

tests/v1/engine/test_engine_core.py (+1, -2)

@@ -31,8 +31,7 @@

 def make_request() -> EngineCoreRequest:
     return EngineCoreRequest(
-        request_id=uuid.uuid4(),
-        prompt=PROMPT,
+        request_id=str(uuid.uuid4()),
         prompt_token_ids=PROMPT_TOKENS,
         mm_inputs=None,
         mm_hashes=None,
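
Besides dropping prompt=PROMPT, this hunk also converts the request ID to a plain string, matching the request_id: str field on EngineCoreRequest (see vllm/v1/engine/__init__.py below). A trivial illustration:

import uuid

# EngineCoreRequest.request_id is annotated as str, so pass a string,
# not the uuid.UUID object itself.
request_id = str(uuid.uuid4())
assert isinstance(request_id, str)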

tests/v1/engine/test_engine_core_client.py (-1)

@@ -35,7 +35,6 @@
 def make_request(params: SamplingParams) -> EngineCoreRequest:
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
-        prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
         mm_inputs=None,
         mm_hashes=None,

tests/v1/engine/test_output_processor.py (+16, -26)

@@ -50,7 +50,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
     # Make N requests.
     requests = [
         EngineCoreRequest(request_id=f"request-{idx}",
-                          prompt=prompt,
                           prompt_token_ids=prompt_tokens,
                           arrival_time=0,
                           mm_inputs=None,

@@ -64,14 +63,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
                              output_kind=request_output_kind,
                              stop=[],
                              include_stop_str_in_output=False,
-                          )) for idx, (prompt, prompt_tokens) in enumerate(
-                              zip(dummy_test_vectors.prompt_strings,
-                                  dummy_test_vectors.prompt_tokens))
+                          ))
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add requests to the detokenizer.
-    for request in requests:
-        output_processor.add_request(request)
+    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
+        output_processor.add_request(request, prompt)

     gen_strings = {}
     gen_tokens = {}

@@ -398,7 +396,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     ]
     requests = [
         EngineCoreRequest(request_id=request_id_list[idx],
-                          prompt=prompt,
                           prompt_token_ids=prompt_tokens,
                           arrival_time=0,
                           mm_inputs=None,

@@ -414,14 +411,13 @@
                              include_stop_str_in_output=False,
                              logprobs=num_sample_logprobs,
                              prompt_logprobs=num_prompt_logprobs,
-                          )) for idx, (prompt, prompt_tokens) in enumerate(
-                              zip(dummy_test_vectors.prompt_strings,
-                                  dummy_test_vectors.prompt_tokens))
+                          ))
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add requests to the detokenizer.
-    for request in requests:
-        output_processor.add_request(request)
+    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
+        output_processor.add_request(request, prompt)

     gen_tokens = {}
     gen_logprobs = {}

@@ -562,7 +558,6 @@ def test_stop_token(include_stop_str_in_output: bool,
     request_id = "request-0"
     request = EngineCoreRequest(
         request_id=request_id,
-        prompt=prompt_string,
         prompt_token_ids=prompt_tokens,
         arrival_time=0,
         mm_inputs=None,

@@ -583,7 +578,7 @@
     ))

     # Add request to the detokenizer.
-    output_processor.add_request(request)
+    output_processor.add_request(request, prompt_string)

     # Loop over engine core steps; run output processor
     gen_string = ""

@@ -659,7 +654,6 @@ def test_stop_string(include_stop_str_in_output: bool,
     requests = [
         EngineCoreRequest(
             request_id=request_id_list[idx],
-            prompt=prompt,
             prompt_token_ids=prompt_tokens,
             arrival_time=0,
             mm_inputs=None,

@@ -675,14 +669,13 @@
                include_stop_str_in_output=include_stop_str_in_output,
                logprobs=num_sample_logprobs,
                prompt_logprobs=None,
-            )) for idx, (prompt, prompt_tokens) in enumerate(
-                zip(dummy_test_vectors.prompt_strings,
-                    dummy_test_vectors.prompt_tokens))
+            ))
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add requests to the detokenizer.
-    for request in requests:
-        output_processor.add_request(request)
+    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
+        output_processor.add_request(request, prompt)

     gen_strings = {}
     gen_tokens = {}

@@ -774,7 +767,6 @@ def test_iteration_stats(dummy_test_vectors):
     requests = [
         EngineCoreRequest(
             request_id=f"request-{idx}",
-            prompt=prompt,
             prompt_token_ids=prompt_tokens,
             arrival_time=0,
             mm_inputs=None,

@@ -783,15 +775,13 @@
             eos_token_id=None,
             lora_request=None,
             sampling_params=SamplingParams(),
-        ) for idx, (prompt, prompt_tokens) in enumerate(
-            zip(dummy_test_vectors.prompt_strings,
-                dummy_test_vectors.prompt_tokens))
+        ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add all requests except one to the OutputProcessor.
     num_active = len(dummy_test_vectors.generation_tokens) - 1
     for request in requests[:num_active]:
-        output_processor.add_request(request)
+        output_processor.add_request(request, None)
     inactive_request = requests[num_active]

     # First iteration has 2 prefills.

@@ -817,7 +807,7 @@
     assert iteration_stats.num_generation_tokens == num_active

     # Add a new request - prefill and 2 decodes in this step.
-    output_processor.add_request(inactive_request)
+    output_processor.add_request(inactive_request, None)
     num_active += 1
     outputs = engine_core.get_outputs()[:num_active]
     iteration_stats = IterationStats()
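
The pattern repeated across these test hunks: requests are built from prompt token IDs only, and the matching prompt string (or None when the text is irrelevant, as in test_iteration_stats) is passed when the request is handed to the output processor. A condensed sketch reusing the fixture names from the diff; the elided EngineCoreRequest fields are as shown above:

# Build requests from token IDs only.
requests = [
    EngineCoreRequest(
        request_id=f"request-{idx}",
        prompt_token_ids=prompt_tokens,
        arrival_time=0,
        # remaining fields (mm_inputs, eos_token_id, lora_request,
        # sampling_params, ...) exactly as in the hunks above
    )
    for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]

# Supply the prompt text separately when registering each request.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
    output_processor.add_request(request, prompt)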

tests/v1/tpu/worker/test_tpu_model_runner.py (-1)

@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
-            prompt="test",
             mm_inputs=[],
             mm_hashes=[],
             mm_positions=[],

tests/v1/worker/test_gpu_input_batch.py (-1)

@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
     return CachedRequestState(
         req_id=f"req_id_{req_id_suffix}",
         prompt_token_ids=prompt_token_ids,
-        prompt=None,
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],

tests/v1/worker/test_gpu_model_runner.py (-1)

@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
-            prompt="test",
             mm_inputs=[],
             mm_hashes=[],
             mm_positions=[],

vllm/v1/core/sched/output.py (-2)

@@ -22,7 +22,6 @@ class NewRequestData:

     req_id: str
     prompt_token_ids: list[int]
-    prompt: Optional[str]
     mm_inputs: list[MultiModalKwargs]
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]

@@ -40,7 +39,6 @@ def from_request(
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
-            prompt=request.prompt,
             mm_inputs=request.mm_inputs,
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,

vllm/v1/engine/__init__.py (-3)

@@ -49,9 +49,6 @@ class EngineCoreRequest(
     # due to circular imports and typing we have in data.py

     request_id: str
-    # NOTE(ywang96): original text prompt is needed when a request is added to
-    # Detokenizer, but set to None when it is added to EngineCoreClient.
-    prompt: Optional[str]
     prompt_token_ids: list[int]
     mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
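
For orientation, a minimal standalone sketch of the trimmed request shape; it is limited to the fields visible in this hunk (the real EngineCoreRequest has more fields and a serializable base class):

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional


@dataclass
class EngineCoreRequestSketch:
    """Illustrative subset of the post-change EngineCoreRequest fields."""
    request_id: str
    prompt_token_ids: list[int]           # tokenized prompt; no raw text field
    mm_inputs: Optional[Sequence] = None  # multimodal inputs, if any
    mm_hashes: Optional[list[str]] = None
    # sampling params, LoRA request, arrival time, etc. omitted here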

vllm/v1/engine/async_llm.py (+9, -8)

@@ -217,14 +217,12 @@ async def add_request(
         queue = RequestOutputCollector(output_kind=params.output_kind)

         # Convert Input --> Request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        prompt_str, request = self.processor.process_inputs(
+            request_id, prompt, params, arrival_time, lora_request,
+            trace_headers, prompt_adapter_request, priority)

         if params.n == 1:
-            await self._add_request(request, None, 0, queue)
+            await self._add_request(request, prompt_str, None, 0, queue)
             return queue

         # Fan out child requests (for n>1).

@@ -234,15 +232,18 @@ async def add_request(
             child_request = request if idx == params.n - 1 else copy(request)
             child_request.request_id = request_id
             child_request.sampling_params = params
-            await self._add_request(child_request, parent_request, idx, queue)
+            await self._add_request(child_request, prompt_str, parent_request,
+                                    idx, queue)
         return queue

     async def _add_request(self, request: EngineCoreRequest,
+                           prompt: Optional[str],
                            parent_req: Optional[ParentRequest], index: int,
                            queue: RequestOutputCollector):

         # Add the request to OutputProcessor (this process).
-        self.output_processor.add_request(request, parent_req, index, queue)
+        self.output_processor.add_request(request, prompt, parent_req, index,
+                                          queue)

         # Add the EngineCoreRequest to EngineCore (separate process).
         await self.engine_core.add_request_async(request)

vllm/v1/engine/core_client.py (-9)

@@ -583,9 +583,6 @@ def call_utility(self, method: str, *args) -> Any:
         return future.result()

     def add_request(self, request: EngineCoreRequest) -> None:
-        # NOTE: text prompt is not needed in the core engine as it has been
-        # tokenized.
-        request.prompt = None
         self._send_input(EngineCoreRequestType.ADD, request)

     def abort_requests(self, request_ids: list[str]) -> None:

@@ -772,9 +769,6 @@ async def _call_utility_async(self, method: str, *args,
         return await future

     async def add_request_async(self, request: EngineCoreRequest) -> None:
-        # NOTE: text prompt is not needed in the core engine as it has been
-        # tokenized.
-        request.prompt = None
         await self._send_input(EngineCoreRequestType.ADD, request)
         self._ensure_output_queue_task()

@@ -867,9 +861,6 @@ async def call_utility_async(self, method: str, *args) -> Any:
         ]))[0]

     async def add_request_async(self, request: EngineCoreRequest) -> None:
-        # NOTE: text prompt is not needed in the core engine as it has been
-        # tokenized.
-        request.prompt = None
         request.current_wave = self.current_wave

         chosen_engine = self.get_core_engine_for_request()

vllm/v1/engine/llm_engine.py (+6, -7)

@@ -180,17 +180,15 @@ def add_request(
         priority: int = 0,
     ) -> None:
         # Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        prompt_str, request = self.processor.process_inputs(
+            request_id, prompt, params, arrival_time, lora_request,
+            trace_headers, prompt_adapter_request, priority)

         n = params.n if isinstance(params, SamplingParams) else 1

         if n == 1:
             # Make a new RequestState and queue.
-            self.output_processor.add_request(request, None, 0)
+            self.output_processor.add_request(request, prompt_str, None, 0)
             # Add the request to EngineCore.
             self.engine_core.add_request(request)
             return

@@ -204,7 +202,8 @@ def add_request(
             child_request.sampling_params = params

             # Make a new RequestState and queue.
-            self.output_processor.add_request(child_request, parent_req, idx)
+            self.output_processor.add_request(child_request, prompt_str,
+                                              parent_req, idx)
             # Add the request to EngineCore.
             self.engine_core.add_request(child_request)

vllm/v1/engine/output_processor.py (+4, -1)

@@ -109,6 +109,7 @@ def from_new_request(
         cls,
         tokenizer: AnyTokenizer,
         request: EngineCoreRequest,
+        prompt: Optional[str],
         parent_req: Optional[ParentRequest],
         request_index: int,
         queue: Optional[RequestOutputCollector],

@@ -123,7 +124,7 @@ def from_new_request(
             lora_name=(request.lora_request.name
                        if request.lora_request is not None else None),
             output_kind=request.sampling_params.output_kind,
-            prompt=request.prompt,
+            prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
             logprobs_processor=LogprobsProcessor.from_new_request(
                 tokenizer=tokenizer,

@@ -267,6 +268,7 @@ def abort_requests(
     def add_request(
         self,
         request: EngineCoreRequest,
+        prompt: Optional[str],
         parent_req: Optional[ParentRequest] = None,
         request_index: int = 0,
         queue: Optional[RequestOutputCollector] = None,

@@ -278,6 +280,7 @@ def add_request(
         req_state = RequestState.from_new_request(
             tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
             request=request,
+            prompt=prompt,
             parent_req=parent_req,
             request_index=request_index,
             queue=queue,
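
Callers now hand the prompt text to the output processor directly; it lands on the per-request RequestState (prompt=prompt above) and is used when assembling outputs, while the EngineCoreRequest itself stays text-free. A usage sketch against the updated signature (output_processor, request, and prompt_str are assumed to be in scope):

# Register a request together with its (optional) prompt text.
output_processor.add_request(
    request,          # EngineCoreRequest: token IDs only
    prompt_str,       # Optional[str]: original text, or None
    parent_req=None,  # set when fanning out n > 1 child requests
    request_index=0,
)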

vllm/v1/engine/processor.py (+2, -3)

@@ -202,7 +202,7 @@ def process_inputs(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> EngineCoreRequest:
+    ) -> tuple[Optional[str], EngineCoreRequest]:

         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.

@@ -306,9 +306,8 @@
         else:
             sorted_mm_inputs = orig_sorted_mm_inputs

-        return EngineCoreRequest(
+        return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
-            prompt=decoder_inputs.get("prompt"),
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
             mm_inputs=sorted_mm_inputs,
             mm_hashes=sorted_mm_hashes,
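
Since process_inputs now returns the text next to the request rather than inside it, a caller that never surfaces prompt text can discard the string and pass None downstream, as the updated test_iteration_stats hunk does. A small sketch:

# Token-level bookkeeping only: the prompt text is simply dropped.
_, request = processor.process_inputs(
    request_id, prompt, params, arrival_time, lora_request,
    trace_headers, prompt_adapter_request, priority)
output_processor.add_request(request, None)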
