
[Kernel] Adding basic Triton JitCache for triton_attn #16606


Open
wants to merge 40 commits into base: main

Conversation

@bringlein bringlein (Contributor) commented Apr 14, 2025

In this PR, we improve the performance of the V1 triton_attn backend. We do not change the Triton kernels themselves; instead, we reduce the well-known launch overhead of Triton kernels by caching the JIT artifacts of the decode kernel.

Performance

All of the below results are for meta-llama/Llama-3.1-8B-Instruct on an NVIDIA H100 GPU. All experiments are run with --no-enable-prefix-caching.

current upstream:

$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json 
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  25.54     
Total input tokens:                      215196    
Total generated tokens:                  197122    
Request throughput (req/s):              39.15     
Output token throughput (tok/s):         7718.08   
Total Token throughput (tok/s):          16143.83  
---------------Time to First Token----------------
Mean TTFT (ms):                          3572.95   
Median TTFT (ms):                        3427.32   
P99 TTFT (ms):                           6817.08   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          92.99     
Median TPOT (ms):                        53.98     
P99 TPOT (ms):                           239.31    
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.60     
Median ITL (ms):                         32.69     
P99 ITL (ms):                            243.61    
==================================================

with this PR

$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  20.08     
Total input tokens:                      215196    
Total generated tokens:                  197466    
Request throughput (req/s):              49.80     
Output token throughput (tok/s):         9833.53   
Total Token throughput (tok/s):          20549.99  
---------------Time to First Token----------------
Mean TTFT (ms):                          3613.66   
Median TTFT (ms):                        3360.28   
P99 TTFT (ms):                           6966.82   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          88.58     
Median TPOT (ms):                        44.06     
P99 TPOT (ms):                           282.93    
---------------Inter-token Latency----------------
Mean ITL (ms):                           35.33     
Median ITL (ms):                         19.82     
P99 ITL (ms):                            234.20    
==================================================

For comparison, using the V1 FlashAttention backend:

$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  21.17     
Total input tokens:                      215196    
Total generated tokens:                  198001    
Request throughput (req/s):              47.23     
Output token throughput (tok/s):         9351.00   
Total Token throughput (tok/s):          19514.07  
---------------Time to First Token----------------
Mean TTFT (ms):                          3332.44   
Median TTFT (ms):                        3174.85   
P99 TTFT (ms):                           6246.38   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          83.21     
Median TPOT (ms):                        48.40     
P99 TPOT (ms):                           216.60    
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.41     
Median ITL (ms):                         27.33     
P99 ITL (ms):                            220.79    
==================================================

So, this PR improves the throughput of the V1 triton_attn backend by roughly 27% and outperforms the FlashAttention-3 baseline by about 5% in this serving benchmark.

We are in the process of evaluating the performance on AMD GPUs.

Correctness

Using the jitcache still produces correct results.

Using FlashAttention on H100 we see:

$ VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.792|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.770|±  |0.0188|

with this PR, we see

$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.79|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  | 0.77|±  |0.0188|

Details / How did we achieve this performance improvement

The launch overhead of Triton kernels is a well-known problem (see e.g. 1, 2, 3). Part of this overhead comes from the fact that the Triton JIT checks very carefully whether an existing binary is safe to reuse.

In many scenarios, these checks can be relaxed to cover only a subset of the parameters.
This PR adds such a cache with relaxed checks, called jitcache. It is implemented as a decorator that is placed in front of the triton.jit decorator:

@jitcache(
    check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
)
@triton.jit
def kernel_paged_attention_.... 

In short, the jitcache follows the steps of the Triton JIT compiler to produce a binary, but does so only once per kernel variant (identified by check_keys). For all subsequent invocations, only the non-constant arguments are updated (and copied to the GPU), skipping most of the compiler.
A detailed usage description can be found here.
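
As a rough illustration of this cache-and-replay idea, here is a minimal sketch in plain Python (illustrative only; the real jitcache additionally binds the constants, updates and copies the non-constant arguments, and can optionally cache the launch grid):

import functools
from typing import Callable

def jitcache_sketch(check_keys: list[str]) -> Callable:
    """Cache one prepared kernel per combination of the values in check_keys."""
    def decorator(jit_fn: Callable) -> Callable:
        cache: dict[tuple, Callable] = {}

        @functools.wraps(jit_fn)
        def launcher(*args, **kwargs):
            # Only the arguments named in check_keys distinguish cached variants;
            # everything else is assumed not to invalidate the compiled binary.
            key = tuple(kwargs[k] for k in check_keys)
            prepared = cache.get(key)
            if prepared is None:
                # First call for this variant: go through the full JIT path once
                # (compilation, constant binding, ...) and keep the result.
                prepared = jit_fn  # stand-in for "compile and prepare"
                cache[key] = prepared
            # Later calls reuse the prepared kernel and only forward the
            # non-constant arguments, skipping the JIT's per-launch checks.
            return prepared(*args, **kwargs)

        return launcher
    return decorator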

This reduces the launch overhead of the paged_attention_2d kernel of the triton backend from ~186us down to ~24us.

Discussion

We have added this new JIT cache to vllm/triton_utils/ because we expect that the vLLM community prefers to have it as an internal tool rather than as an external dependency. However, we have also published it as part of our triton-dejavu framework, so vLLM could also import it from there if that is preferred.

Ideally, something like the jitcache could be added as a feature to Triton itself, but we expect this to be a lengthier process. In any case, the jitcache would need to be updated for every new Triton release.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which starts a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Apr 14, 2025
@bringlein bringlein changed the title Adding basic Triton JitCache for triton_attn [Kernel] Adding basic Triton JitCache for triton_attn Apr 15, 2025
@bringlein bringlein marked this pull request as ready for review April 16, 2025 07:34
@bringlein
Contributor Author

CC @SageMoore @tdoublep

(I also removed the jitcache from prefix_prefill in this PR, to not conflict with #13305)

@Sara-KS

Sara-KS commented Apr 17, 2025

Testing this PR on MI-300X
Correctness:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_ROCM_CUSTOM_PAGED_ATTN=0 lm_eval --model vllm --model_args pretrained=/models/llama-3.1-8b/instruct/ --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.790|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.772|±  |0.0188|

Performance:
VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_ROCM_CUSTOM_PAGED_ATTN=0 vllm serve /models/llama-3.1-8b/instruct/ --disable-log-requests --no-enable-prefix-caching

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_ROCM_CUSTOM_PAGED_ATTN=0 python benchmarks/benchmark_serving.py --model /models/llama-3.1-8b/instruct/ --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --ignore-eos

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  27.98     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              35.73     
Output token throughput (tok/s):         7087.69   
Total Token throughput (tok/s):          14777.61  
---------------Time to First Token----------------
Mean TTFT (ms):                          8536.68   
Median TTFT (ms):                        7890.95   
P99 TTFT (ms):                           19628.92  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          31.62     
Median TPOT (ms):                        29.14     
P99 TPOT (ms):                           76.65     
---------------Inter-token Latency----------------
Mean ITL (ms):                           26.71     
Median ITL (ms):                         19.01     
P99 ITL (ms):                            85.35     
==================================================

The Triton launch overhead does not appear to be a bottleneck on MI-300X. We are currently investigating this and may follow up in a separate PR.

Contributor

@SageMoore SageMoore left a comment

The infra looks good. I just want to make sure I'm understanding the key generation logic correctly.

:param assume_const: A list of parameters that are NOT marked as
tl.constexpr but should be treated as constants in
this kernel launch.
:type assume_const: list[str]
Contributor

Can you use type hints here instead of the "type" comments?

Contributor Author

thanks, addressed.

@@ -324,8 +355,9 @@ def chunked_prefill_paged_decode(
v_scale=v_scale,
)
else:
assert num_seqs <= 4096
Contributor

Could you add a comment explaining what's going on here and why this max value was set?

Contributor Author

thanks, addressed.

Contributor

Nit: minor misspelling "assume"

cache_launch_grid=False,
assume_const=None,
):
# we depend on the triton version, right now, 3.0 -- 3.2 are supported
Contributor

Can you make sure this doesn't crash if someone uses Triton 3.3?

Contributor Author

Ok, so you prefer that the cache just "disables itself" and does nothing if the triton version is not supported (yet)? I could do that. Also, I mean that for Triton 3.3 it would just do nothing and call the regular JIT...

Contributor Author

(In general I think we can support triton 3.3 as well, but that's not a focus right now)

Contributor

Yeah, I think that's the best solution for now. Crashing the whole process just because this caching system isn't supported yet seems like too much.

Contributor Author

ok, I implemented it
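
A rough sketch of this "disable itself" fallback (the supported-version list and the import path of the jitcache decorator are assumptions for illustration, not the PR's exact implementation):

import triton

from vllm.triton_utils import jitcache  # decorator added by this PR; exact import path assumed

_JITCACHE_SUPPORTED = ("3.0", "3.1", "3.2")  # illustrative supported-version list

def jitcache_or_noop(check_keys):
    """Use the relaxed-check cache only on supported Triton versions; otherwise
    return a pass-through decorator so the kernel runs through the regular
    @triton.jit launch path instead of crashing."""
    major_minor = ".".join(triton.__version__.split(".")[:2])
    if major_minor not in _JITCACHE_SUPPORTED:
        return lambda jit_fn: jit_fn  # no-op: regular JIT, no caching, no crash
    return jitcache(check_keys=check_keys)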

Contributor

How hard is 3.3 support? Since we've updated to torch 2.7 it would be good to get Triton 3.3 support here. Hopefully it just works?

self.base_fn = fn
while not inspect.isfunction(self.base_fn):
self.base_fn = self.base_fn.fn
self.cache_lock = cache_lock
Contributor

Where is the cache_lock locked?

Contributor Author

right now, it is not used (only _dynamic is used in this PR), see also below.

Contributor

Since only _dynamic is used, can we remove this static vs dynamic distinction and simplify the code?


return prepared_kernel

def _run_static(self, *args, **kwargs):
Contributor

Some high level comments on the difference between static and dynamic would be good here. Looks like static basically means that we only add to the cache if we have acquired this cache lock? Dynamic will just add the generated kernel to the cache if it doesn't already exist?

Contributor Author

yes, that is correct. I can add some comments to the functions...if that is the best place to document it?

Contributor Author

I added some more explanations in the docstr of the decorator.
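
To make the static vs. dynamic distinction concrete, here is a toy sketch of the behavior confirmed above (illustrative only; the PR's _run_static/_run_dynamic operate on real prepared kernels and the cache_lock object):

class SketchKernelCache:
    """Toy model of the static vs. dynamic cache behavior (not the PR's code)."""

    def __init__(self, lock_is_held):
        self.kernel_cache = {}
        self.lock_is_held = lock_is_held  # callable: is the cache_lock currently acquired?

    def _prepare_kernel(self, key):
        # stand-in for the expensive JIT compile-and-prepare step
        return lambda *args, **kwargs: f"ran variant {key}"

    def run_dynamic(self, key, *args, **kwargs):
        # dynamic: grow the cache whenever a variant is missing
        if key not in self.kernel_cache:
            self.kernel_cache[key] = self._prepare_kernel(key)
        return self.kernel_cache[key](*args, **kwargs)

    def run_static(self, key, *args, **kwargs):
        # static: only grow the cache while the cache_lock is held;
        # once the lock is released, an unknown variant is an error
        if key not in self.kernel_cache:
            assert self.lock_is_held(), "cache is frozen; unknown kernel variant"
            self.kernel_cache[key] = self._prepare_kernel(key)
        return self.kernel_cache[key](*args, **kwargs)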

self.assume_const = assume_const
self.kernel_cache: dict[str, PreparedKernel] = {}

def calc_cache_index(kwargs):
Contributor

Could you talk a bit about the key generation logic? I'm somewhat confused. It looks like, in this case, we would be constructing a key with the values of the "USE_ALIBI_SLOPES", "SLIDING_WINDOW", and "filter_by_query_len". Since those are the strings in the "check_keys" argument to the decorator. How were these arguments selected?

Contributor Author

@bringlein bringlein Apr 24, 2025

Yes, you are right :). The developer needs to select these arguments based on their knowledge of the application. Put very simply, the jitcache trades the JIT's safety in all scenarios (and its high launch overhead) for a low launch overhead with reduced/relaxed safety checks that are only valid for an application-specific use. It is then the job of the developer to ensure that the relaxed safety checks still hold for the particular application.

In the case of the paged_attention_2d kernel, we assume that only the arguments "USE_ALIBI_SLOPES", "SLIDING_WINDOW", and "filter_by_query_len" change during the lifetime of one vLLM instance (one model within vLLM, to be precise). At most these arguments, that is; I actually think they barely change and we could reduce the list further. Other arguments like num_heads don't change during a model's lifetime (at least, to my knowledge), hence we don't need to check them on every launch only to realize "oh, it didn't change... I just spent some microseconds to confirm this again".

Said differently, if we used this kernel in an application that uses the same python process / same kernel instance to serve multiple different LLMs (something like attention-as-a-service... just making things up ;) ), then we would need to extend the check_keys list to ensure it also holds in this scenario (or not use the jitcache).

I can try to mention this more explicitly in the doc strings.
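
To illustrate this, a hedged sketch of the key computation (the argument values are made up for the example; the PR's calc_cache_index may differ in detail):

check_keys = ["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"]

def calc_cache_index_sketch(kwargs: dict) -> tuple:
    # Only the kwargs listed in check_keys contribute to the cache key;
    # all other arguments are assumed not to require a different binary.
    return tuple(kwargs[k] for k in check_keys)

# A launch with these kwargs maps to the variant key (False, 0, True);
# changing e.g. num_heads would NOT create a new cache entry.
example_kwargs = {"USE_ALIBI_SLOPES": False, "SLIDING_WINDOW": 0,
                  "filter_by_query_len": True, "num_heads": 32}
assert calc_cache_index_sketch(example_kwargs) == (False, 0, True)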

Contributor Author

I added some more explanations in the docstr of the decorator.

Contributor

Should we be more explicit with the JIT cache here and properly scope it to the model?

Contributor

I do think that we need to scope these caches to the model. Users are allowed to run multiple models without terminating the process.

self.update_args_index[arg_n] = i
self.non_const_vals_lst.append("dummy_value")
else:
self.non_const_vals_lst.append(assume_const_vals_dict[arg_n])
Contributor

Assert that assume_const_vals_dict contains arg_n?

Contributor Author

This is correct by construction in

assume_const_vals_dict = {

self.non_const_vals_lst = []
self.update_args_index = {}
for i, arg_n in enumerate(self.non_const_arg_names):
if arg_n in update_only_arg_names:
Contributor

I think some comments would be good here. Something like

            # If the argument can change each time the kernel is called, store a dummy value 
            # that will be set each time __call__ is called
            if arg_n in update_only_arg_names:
                self.update_args_index[arg_n] = i
                self.non_const_vals_lst.append("dummy_value")
            # else the argument is assumed to be constant and we can just store its initial value 
            else:
                self.non_const_vals_lst.append(assume_const_vals_dict[arg_n])

Contributor Author

thanks, I added something similar.

Contributor

Is the main value of the non_const_vals_lst that it lets us do fewer copies in the call function? It doesn't affect the caching at all, right?

Contributor Author

that is correct, yes


self.cache_index_func = calc_cache_index
if len(check_keys) == 0:
self.cache_index_func = lambda ignore: "_default_"
Contributor

This seems like a scary default. Wouldn't this just result in lots of false hits in the cache, i.e., all launches mapping to the first kernel added?

Contributor Author

Yes, that is correct. But if check_keys is empty, it means that the developer assumes there is only ever one kernel variant. Meaning, the cache will only ever have one entry, and the name of this entry doesn't matter.
I can add some explanations. I can also rename the string to something like "_only_one_variant_"... but does it really matter?

Contributor

@SageMoore SageMoore Apr 24, 2025

I think the current name is fine, but we should be very clear about our expectations of users of the jit cache. Naively, I would somewhat expect key generation to be "transparent" to the user, meaning we could just build a key from the values of all constexpr and assumed_const args. Though maybe there is a performance concern with building a key from so many values? In that scenario, the developer could use the check_keys field to override the default behavior.

Contributor Author

meaning we can just make a key given the values of all constexpr and assumed_const args

I think this would be possible, but then the difference to the native triton launcher would be minimal (although I didn't measure it). So my current thinking is: if someone doesn't want to deal with check_keys, they maybe shouldn't use the jitcache.

bringlein added 6 commits May 7, 2025 23:36

mergify bot commented May 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @bringlein.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 9, 2025
@mergify mergify bot removed the needs-rebase label May 13, 2025
bringlein added 2 commits May 13, 2025 05:38
bringlein added 5 commits May 13, 2025 06:47
@bringlein
Contributor Author

some updates on this PR:

  • The jitcache in this PR is now compatible with Triton 3.3 (and only 3.3; for backward compatibility, please use triton-dejavu).
  • The jitcache is disabled by default, to enable it the environment variable VLLM_TRITON_ENABLE_JITCACHE=1 needs to be set.
  • I updated the unified triton attention kernel from PR [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode #16828 to use the updated jitcache.
  • I fixed the broken attention backend on the CUDA platform (which was broken due to the dependency on the FlashAttentionMetadataBuilder).

From this PR, two todos remain:

  1. With triton 3.3, we see a performance regression for the triton backend (details below). However, this regression is independent of the jitcache and is therefore addressed in another PR.
  2. The triton backend still depends on the attention metadata from the V1 flash attention backend; it should be made fully independent (possibly by moving some of the flash attention metadata functionality into a utility file). However, this too is independent of the jitcache and should be addressed in another PR.

Correctness:

$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_TRITON_ENABLE_JITCACHE=1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.782|±  |0.0185|

Performance:

Tracing

When analyzing vLLM profiles, the jitcache functions as expected and reduces the Triton launch overhead from 148us to 26us:
before: [profiler trace screenshot]
after: [profiler trace screenshot]

H100

With pytorch 2.7 and triton 3.3, the performance of the triton attention with this PR on an H100 drops to:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_TRITON_ENABLE_JITCACHE=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --no-enable-prefix-caching

VLLM_USE_V1=1 python benchmarks/benchmark_serving.py --model meta-llama/Llama-3.1-8B-Instruct --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  30.49     
Total input tokens:                      215196    
Total generated tokens:                  197999    
Request throughput (req/s):              32.80     
Output token throughput (tok/s):         6493.77   
Total Token throughput (tok/s):          13551.56  
---------------Time to First Token----------------
Mean TTFT (ms):                          4738.41   
Median TTFT (ms):                        4668.27   
P99 TTFT (ms):                           9096.74   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          119.62    
Median TPOT (ms):                        65.38     
P99 TPOT (ms):                           322.74    
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.69     
Median ITL (ms):                         36.63     
P99 ITL (ms):                            330.55    
==================================================

with VLLM_TRITON_ENABLE_JITCACHE=0 the performance is:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  31.05     
Total input tokens:                      215196    
Total generated tokens:                  198064    
Request throughput (req/s):              32.21     
Output token throughput (tok/s):         6378.75   
Total Token throughput (tok/s):          13309.25  
---------------Time to First Token----------------
Mean TTFT (ms):                          4927.07   
Median TTFT (ms):                        4718.52   
P99 TTFT (ms):                           9635.06   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          124.71    
Median TPOT (ms):                        67.28     
P99 TPOT (ms):                           365.14    
---------------Inter-token Latency----------------
Mean ITL (ms):                           55.10     
Median ITL (ms):                         37.23     
P99 ITL (ms):                            338.55    
==================================================

(so the jitcache yields only a roughly 2% increase in total throughput here).

MI300

With pytorch 2.7 and triton 3.3, the performance of the triton attention with this PR on an MI300 is:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_TRITON_ENABLE_JITCACHE=1 vllm serve mistralai/Mistral-Small-24B-Instruct-2501 --disable-log-requests --no-enable-prefix-caching

VLLM_USE_V1=1 python benchmarks/benchmark_serving.py --model mistralai/Mistral-Small-24B-Instruct-2501 --dataset-name random --request-rate 9 --random-input-len 1000 --random-output-len 100 --num-prompts 1080 --ignore-eos --seed 9

============ Serving Benchmark Result ============
Successful requests:                     1080      
Benchmark duration (s):                  154.60    
Total input tokens:                      1078920   
Total generated tokens:                  108000    
Request throughput (req/s):              6.99      
Output token throughput (tok/s):         698.57    
Total Token throughput (tok/s):          7677.32   
---------------Time to First Token----------------
Mean TTFT (ms):                          10223.16  
Median TTFT (ms):                        9194.06   
P99 TTFT (ms):                           22111.30  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          710.73    
Median TPOT (ms):                        801.62    
P99 TPOT (ms):                           1004.43   
---------------Inter-token Latency----------------
Mean ITL (ms):                           710.73    
Median ITL (ms):                         1016.49   
P99 ITL (ms):                            1070.81   
==================================================

with VLLM_TRITON_ENABLE_JITCACHE=0 the performance is:

============ Serving Benchmark Result ============
Successful requests:                     1080      
Benchmark duration (s):                  154.65    
Total input tokens:                      1078920   
Total generated tokens:                  108000    
Request throughput (req/s):              6.98      
Output token throughput (tok/s):         698.34    
Total Token throughput (tok/s):          7674.80   
---------------Time to First Token----------------
Mean TTFT (ms):                          10475.19  
Median TTFT (ms):                        9509.31   
P99 TTFT (ms):                           22325.98  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          708.41    
Median TPOT (ms):                        798.45    
P99 TPOT (ms):                           1001.22   
---------------Inter-token Latency----------------
Mean ITL (ms):                           708.41    
Median ITL (ms):                         1013.82   
P99 ITL (ms):                            1067.09   
==================================================

cyang49 and others added 2 commits May 14, 2025 12:19