[Kernel] Adding basic Triton JitCache for triton_attn #16606
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 …
(I also removed the jitcache from …)
Testing this PR on MI-300X:

```shell
VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_ROCM_CUSTOM_PAGED_ATTN=0 lm_eval --model vllm --model_args pretrained=/models/llama-3.1-8b/instruct/ --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
```

Performance:

```shell
VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_ROCM_CUSTOM_PAGED_ATTN=0 python benchmarks/benchmark_serving.py --model /models/llama-3.1-8b/instruct/ --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --ignore-eos
```

The triton launch overhead does not appear to be a bottleneck on MI-300X. We are currently investigating this and may follow up in a separate PR.
The infra looks good. I just want to make sure I'm understanding the key generation logic correctly.
vllm/triton_utils/jit_cache.py
Outdated
```
:param assume_const: A list of parameters that are NOT marked as
                     tl.constexpr but should be treated as constants in
                     this kernel launch.
:param assume_const: list[str]
```
Can you use type hints here instead of the "type" comments?
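For illustration, a minimal sketch of what that could look like on the decorator signature quoted elsewhere in this thread (the parameter names are taken from the snippets under review; the Optional default is an assumption):

```python
from typing import Optional

def jitcache(
    check_keys: list[str],
    assume_const: Optional[list[str]] = None,
    cache_launch_grid: bool = False,
):
    """Relaxed-check JIT cache decorator (sketch).

    The annotation on assume_const replaces the
    ':param assume_const: list[str]' type comment in the docstring.
    """
    ...
```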
thanks, addressed.
```
@@ -324,8 +355,9 @@ def chunked_prefill_paged_decode(
            v_scale=v_scale,
        )
    else:
        assert num_seqs <= 4096
```
Could you add a comment explaining what's going on here and why this max value was set?
thanks, addressed.
Nit: minor misspelling of "assume"
vllm/triton_utils/jit_cache.py
Outdated
```python
    cache_launch_grid=False,
    assume_const=None,
):
    # we depend on the triton version, right now, 3.0 -- 3.2 are supported
```
Can you make sure this doesn't crash if someone uses Triton 3.3?
OK, so you would prefer that the cache just "disables itself" and does nothing if the triton version is not supported (yet)? I could do that. To be clear, for Triton 3.3 it would then just do nothing and call the regular JIT...
(In general I think we can support triton 3.3 as well, but that's not a focus right now)
Yeah, I think that's the best solution for now. Crashing the whole process just because this caching system isn't supported yet seems like too much.
ok, I implemented it
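A minimal sketch of the fallback behavior agreed on here (hypothetical names; the supported version range comes from the code comment quoted above, and the cache wrapper is passed in rather than assumed):

```python
import triton

# version range taken from the comment in the snippet above (prefix match, sketch only)
SUPPORTED_TRITON_PREFIXES = ("3.0", "3.1", "3.2")

def maybe_jitcache(jitted_fn, cache_wrapper, **cache_kwargs):
    """Wrap a @triton.jit function with the cache only if the installed
    Triton version has been validated for it; otherwise return the
    function unchanged, i.e. regular JIT behaviour without caching."""
    if not triton.__version__.startswith(SUPPORTED_TRITON_PREFIXES):
        return jitted_fn
    return cache_wrapper(jitted_fn, **cache_kwargs)
```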
How hard is 3.3 support? Since we've updated to torch 2.7 it would be good to get Triton 3.3 support here. Hopefully it just works?
```python
self.base_fn = fn
while not inspect.isfunction(self.base_fn):
    self.base_fn = self.base_fn.fn
self.cache_lock = cache_lock
```
Where is the cache_lock locked?
Right now, it is not used (only _dynamic is used in this PR), see also below.
Since only _dynamic is used, can we remove this static vs dynamic distinction and simplify the code?
```python
        return prepared_kernel

    def _run_static(self, *args, **kwargs):
```
Some high level comments on the difference between static and dynamic would be good here. Looks like static basically means that we only add to the cache if we have acquired this cache lock? Dynamic will just add the generated kernel to the cache if it doesn't already exist?
yes, that is correct. I can add some comments to the functions...if that is the best place to document it?
I added some more explanations in the docstr of the decorator.
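To make the distinction discussed above concrete, a rough sketch (illustrative only, not the PR's code): dynamic mode fills the cache on any miss, while static mode only adds entries while the cache lock is held and otherwise treats a miss as an error.

```python
import threading

class CacheSketch:
    def __init__(self):
        self.kernels = {}                    # cache key -> prepared kernel
        self.cache_lock = threading.Lock()   # stands in for the PR's cache_lock

    def run_dynamic(self, key, compile_fn, *args):
        # dynamic: on a miss, compile, insert, then launch
        if key not in self.kernels:
            self.kernels[key] = compile_fn()
        return self.kernels[key](*args)

    def run_static(self, key, compile_fn, *args):
        # static: only grow the cache while the lock is held; otherwise the
        # cache is considered frozen and a miss is a usage error
        if key not in self.kernels:
            if not self.cache_lock.locked():
                raise RuntimeError("unexpected kernel variant for frozen cache")
            self.kernels[key] = compile_fn()
        return self.kernels[key](*args)
```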
```python
self.assume_const = assume_const
self.kernel_cache: dict[str, PreparedKernel] = {}

def calc_cache_index(kwargs):
```
Could you talk a bit about the key generation logic? I'm somewhat confused. It looks like, in this case, we would be constructing a key from the values of "USE_ALIBI_SLOPES", "SLIDING_WINDOW", and "filter_by_query_len", since those are the strings in the check_keys argument to the decorator. How were these arguments selected?
Yes, you are right :). The developer needs to select these arguments based on their knowledge of the application. Put very simply, the jitcache trades the full safety checks (and high launch overhead) of the default JIT launcher for a low launch overhead with reduced/relaxed safety checks that are valid only for an application-specific use case. It is then the job of the developer to ensure that the relaxed safety checks still hold for the particular application.
In the case of the paged_attention_2d kernel, we assume that only the arguments "USE_ALIBI_SLOPES", "SLIDING_WINDOW", and "filter_by_query_len" change during the lifetime of one vLLM instance (one model within vLLM, to be precise). At most these arguments change; I actually think they barely change at all and we could reduce the list further. Other arguments like num_heads don't change during a model's lifetime (at least to my knowledge), hence we don't need to check them every time only to realize "oh, it didn't change... I just spent some microseconds to confirm this again".
Said differently, if we used this kernel in an application that uses the same Python process / same kernel instance to serve multiple different LLMs (something like attention-as-a-service... just making things up ;) ), then we would need to extend the check_keys list so that the assumptions also hold in that scenario (or not use the jitcache).
I can try to mention this more explicitly in the doc strings.
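To illustrate the trade-off described above, a small sketch of the key idea (a hypothetical helper, not the actual calc_cache_index): only the values listed in check_keys influence the cache key, so everything else is trusted to stay constant.

```python
def make_cache_key(kwargs: dict, check_keys: list[str]) -> str:
    # concatenate only the watched argument values into the lookup key
    return "|".join(f"{name}={kwargs[name]}" for name in check_keys)

key = make_cache_key(
    {"USE_ALIBI_SLOPES": False, "SLIDING_WINDOW": 0,
     "filter_by_query_len": True, "num_heads": 32},
    check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
)
# key == "USE_ALIBI_SLOPES=False|SLIDING_WINDOW=0|filter_by_query_len=True"
# num_heads deliberately does not appear: it is assumed constant per model.
```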
I added some more explanations in the docstr of the decorator.
Should we be more explicit with the JIT cache here and properly scope it to the model?
I do think that we need to scope these caches to the model. Users are allowed to run multiple models without terminating the process.
vllm/triton_utils/jit_cache.py
Outdated
```python
        self.update_args_index[arg_n] = i
        self.non_const_vals_lst.append("dummy_value")
    else:
        self.non_const_vals_lst.append(assume_const_vals_dict[arg_n])
```
Assert that assume_const_vals_dict contains arg_n?
This is correct by construction in vllm/triton_utils/jit_cache.py (line 234 in f8c6610):

```python
assume_const_vals_dict = {
```
```python
self.non_const_vals_lst = []
self.update_args_index = {}
for i, arg_n in enumerate(self.non_const_arg_names):
    if arg_n in update_only_arg_names:
```
I think some comments would be good here. Something like:

```python
# If the argument can change each time the kernel is called, store a dummy value
# that will be set each time __call__ is called
if arg_n in update_only_arg_names:
    self.update_args_index[arg_n] = i
    self.non_const_vals_lst.append("dummy_value")
# else the argument is assumed to be constant and we can just store its initial value
else:
    self.non_const_vals_lst.append(assume_const_vals_dict[arg_n])
```
thanks, I added something similar.
Is the main value of the non_const_vals_lst that it lets us do fewer copies in the call function? It doesn't affect the caching at all, right?
that is correct, yes
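A small sketch of the idea confirmed here (illustrative names only): the positional argument list is built once, and each call only overwrites the slots of arguments that actually change; the cache key is not derived from this list.

```python
class PreparedLaunchSketch:
    def __init__(self, kernel, baked_vals, update_index):
        self.kernel = kernel
        self.vals = list(baked_vals)        # constants already baked in
        self.update_index = update_index    # arg name -> slot to overwrite

    def __call__(self, **runtime_args):
        # only the changing arguments are written before launch
        for name, idx in self.update_index.items():
            self.vals[idx] = runtime_args[name]
        return self.kernel(*self.vals)
```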
```python
self.cache_index_func = calc_cache_index
if len(check_keys) == 0:
    self.cache_index_func = lambda ignore: "_default_"
```
This seems like a scary default. Wouldn't this just result in lots of false hits to the cache? I.e., all functions mapping to the first function added?
Yes, that is correct. But if check_keys is empty, it means that the developer assumes there is only ever one kernel variant, so the cache will only ever have one entry and the name of this entry doesn't matter.
I can add some explanations. I could also rename the string to something like "_only_one_variant_"... but does it really matter?
I think the current name is fine, but we should be very clear about our expectations of users of the jit cache. Naively, I would somewhat expect key generation to be "transparent" to the users, meaning we could just build a key from the values of all constexpr and assumed_const args. Though maybe there is a performance concern with building a key from so many values? In that scenario the developer could use the check_keys field to override the default behavior.
> meaning we can just make a key given the values of all constexpr and assumed_const args

I think this would be possible, but then the difference to the native triton launcher would be minimal (although I didn't measure it). So my current thinking is: if someone doesn't want to deal with check_keys, they maybe shouldn't use the jitcache.
This pull request has merge conflicts that must be resolved before it can be merged.
This reverts commit 450770c.
Some updates on this PR:

From this PR, two todos remain:
Correctness: …
Performance: …

Tracing
Analyzing vLLM profiles shows that the jitcache functions as expected and reduces the triton launch overhead from 148us to 26us.

H100
With pytorch 2.7 and triton 3.3, the performance of the triton attention with this PR on an H100 drops to: (results omitted)
(so roughly only a 2% increase in total performance due to the jitcache).

MI300
With pytorch 2.7 and triton 3.3, the performance of the triton attention with this PR on MI300 drops to: (results omitted)
In this PR we improve the performance of the V1 triton_attn backend. We do not change the triton kernels; instead, we reduce the well-known launch overhead of triton kernels by caching the JIT artifacts of the decode kernel.

Performance

All of the results below are for meta-llama/Llama-3.1-8B-Instruct on an NVIDIA H100 GPU. All experiments are done with --no-enable-prefix-caching and …

current upstream: (results omitted)

with this PR: (results omitted)

to compare, using the V1 FlashAttention backend: (results omitted)

So, this PR improves the performance of the V1 triton_attn backend by 27% and outperforms the FlashAttention-3 baseline by 5% for the serving benchmark. We are in the process of evaluating the performance on AMD GPUs.

Correctness

Using the jitcache still produces correct results.

Using FlashAttention on H100 we see: (results omitted)

with this PR, we see: (results omitted)
Details / How did we achieve this performance improvement

The launch overhead of triton kernels is a well-known problem (see e.g. 1, 2, 3). Part of the launch overhead comes from the fact that the triton JIT checks very carefully whether an existing binary is safe to use.
In many scenarios, these checks can be relaxed to cover only a subset of the parameters.
This PR adds such a cache with relaxed checks, implemented by jitcache. It is implemented as a decorator that can be used in front of the triton.jit decorator:
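As an illustration, a hedged sketch of how the decorator is meant to be applied, based on the parameters discussed in the review above (the import path, the exact kernel signature, and the assume_const choice are assumptions):

```python
import triton
import triton.language as tl

# import path assumed from the file added in this PR (vllm/triton_utils/jit_cache.py)
from vllm.triton_utils.jit_cache import jitcache

@jitcache(
    # only these arguments are checked on every launch (per the review discussion)
    check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
    # runtime argument assumed constant for the lifetime of the model (illustrative)
    assume_const=["num_query_heads"],
    cache_launch_grid=False,
)
@triton.jit
def kernel_paged_attention_2d(
        output_ptr,
        query_ptr,
        num_query_heads,
        USE_ALIBI_SLOPES: tl.constexpr,
        SLIDING_WINDOW: tl.constexpr,
        filter_by_query_len: tl.constexpr,
):
    # ... kernel body unchanged ...
    pass
```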
As a short description: the jitcache follows the steps of the triton JIT compiler to produce a binary, but does this only once for each variant (indicated by check_keys). For all consecutive invocations, only the non-constant arguments are updated (and copied to the GPU), skipping most parts of the compiler. A detailed usage description can be found here.
This reduces the launch overhead of the paged_attention_2d kernel of the triton backend from ~186us down to ~24us.

Discussion
We have added this new jit cache to vllm/triton_utils/ because we expect that the vLLM community prefers to have this as an internal tool rather than as an external dependency. However, we also published it as part of our triton-dejavu framework, so vLLM could also import it from there if that is preferred.

Ideally, something like the jitcache could be added as a feature to triton itself, but we expect that to be a lengthier process. In any case, we would need to update the jitcache for every new triton release.