[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention #17139
Conversation
…d output FP8 tensor Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
…uantizing in the flash attention kernel for V1 Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2 nits, and could we add this case to tests?
csrc/rocm/attention.cu
Outdated
// NOTE: fp8_out_scale is optional.
const float* fp8_out_scale_ptr =
    fp8_out_scale
        ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
Nit: static cast?
const float* fp8_out_scale_ptr =
    fp8_out_scale
        ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
        : nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
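For reference, a minimal sketch of what the suggested change could look like; Tensor::data_ptr() returns a void*, so a static_cast is sufficient here and avoids reinterpreting an already-typed pointer:

```cpp
// Sketch of the reviewer's suggestion, not the PR's final code: data_ptr()
// yields a void*, so static_cast is enough to obtain a typed pointer.
const float* fp8_out_scale_ptr =
    fp8_out_scale
        ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
        : nullptr;
```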
Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?
Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere
This is ensured at https://github.com/vllm-project/vllm/pull/17139/files#diff-79b8261aa73f07cc7450e48c8e14150576656f19ccfb42ba972860092c1f5949R1779-R1786
Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?
No, it should be the same type as query; it is used in the internal calculations.
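For readers who don't have the diff open, a rough sketch of the kind of host-side guard the linked lines perform; the function name and exact checks below are illustrative, not the PR's actual code:

```cpp
// Hypothetical sketch of the dtype contract discussed above (names are
// placeholders; the real guard lives in csrc/rocm/attention.cu).
#include <optional>
#include <torch/all.h>

void check_cpa_output_dtypes(const torch::Tensor& query,
                             const torch::Tensor& out,
                             const torch::Tensor& tmp_out,
                             const std::optional<torch::Tensor>& fp8_out_scale) {
  if (fp8_out_scale.has_value()) {
    // With a scale present, the final output buffer must already be fp8.
    TORCH_CHECK(out.scalar_type() == at::ScalarType::Float8_e4m3fnuz ||
                    out.scalar_type() == at::ScalarType::Float8_e4m3fn,
                "fp8_out_scale was given but out is not an fp8 tensor");
  } else {
    TORCH_CHECK(out.scalar_type() == query.scalar_type(),
                "out must have the query dtype when no output scale is given");
  }
  // tmp_out holds intermediate accumulation results, so it keeps the query
  // dtype regardless of whether the final output is quantized to fp8.
  TORCH_CHECK(tmp_out.scalar_type() == query.scalar_type(),
              "tmp_out must have the query dtype");
}
```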
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
commit 9f733ff
Author: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Date: Fri Apr 25 22:10:58 2025 +0000

    Using static cast

    Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

commit 2d7dba5
Author: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Date: Thu Apr 24 21:37:16 2025 +0000

    An option to apply fp8 output scale in ROCm custom paged attention and output FP8 tensor

    Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
csrc/rocm/attention.cu
Outdated
@@ -1238,6 +1240,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(

  // final write to tmp_out after vout accumulation
  if (warpid == 0) {
    const float out_scale =
wondering where out_scale is used here?
It is actually used in the reduction kernel launched after either of the attention kernels.
The dereferencing here is indeed not needed, but it'll get optimized out. I'll make a note to clean it up
Could you just remove it in this PR?
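As background for this thread, a rough sketch of how a reduce kernel can consume the scale on its final store when the output type is fp8; the helper name, the fnuz type, and the divide-by-scale convention are assumptions, not the PR's actual code (the kernel may equally multiply by a precomputed reciprocal):

```cpp
#include <c10/util/Float8_e4m3fnuz.h>
#include <type_traits>

// Illustrative only: apply the output quantization scale while narrowing the
// float accumulator to the output type OUTT in the reduction kernel.
template <typename OUTT>
__device__ __forceinline__ OUTT write_scaled(float acc, float out_scale) {
  if constexpr (std::is_same_v<OUTT, c10::Float8_e4m3fnuz>) {
    // Quantize: divide by the tensor-wide scale before storing 8 bits.
    return c10::Float8_e4m3fnuz(acc / out_scale);
  } else {
    // No-scale path: plain narrowing conversion to the output type.
    return static_cast<OUTT>(acc);
  }
}
```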
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
vllm-project#17139) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
vllm-project#17139) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
vllm-project#17139) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
An option to apply fp8 output scale in ROCm custom paged attention and output FP8 tensor
If a non-None scale tensor is passed to the kernel, the output tensor is expected to be of the current_platform.fp8_dtype() type (float8_fnuz or float8_fn), and the scale is applied before each value is stored into the 8-bit type.
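A caller-side sketch of that contract, assuming the C++ extension is used directly and the platform fp8 format is float8_e4m3fnuz (as on MI300-class ROCm devices); the function below is illustrative, not vLLM's actual binding:

```cpp
#include <torch/all.h>

// Illustrative only: when an fp8 output scale is supplied, the caller
// allocates the output as an fp8 tensor up front and the kernel quantizes
// into it; without a scale the output keeps the query dtype as before.
torch::Tensor alloc_attention_output(const torch::Tensor& query,
                                     bool has_fp8_out_scale) {
  if (has_fp8_out_scale) {
    return torch::empty_like(
        query, query.options().dtype(at::ScalarType::Float8_e4m3fnuz));
  }
  return torch::empty_like(query);
}
```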