Skip to content

[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention #17139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 7, 2025

Conversation

gshtras
Copy link
Collaborator

@gshtras gshtras commented Apr 24, 2025

An option to apply fp8 output scale in ROCm custom paged attention and output FP8 tensor
In case a non-None scale tensor is passed to the kernel, the output tensor is expected to be in the current_platform.fp8_dtype() type (float8_fnuz or float8_fn), and the scale is applied to it before storing into an 8-bit type

…d output FP8 tensor

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

gshtras added a commit to ROCm/vllm that referenced this pull request Apr 24, 2025
…uantizing in the flash attention kernel for V1

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Copy link
Contributor

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 nits, and could we add this case to tests?

// NOTE: fp8_out_scale is optional.
const float* fp8_out_scale_ptr =
fp8_out_scale
? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: static cast?

const float* fp8_out_scale_ptr =
fp8_out_scale
? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere

This is ensured at https://github.com/vllm-project/vllm/pull/17139/files#diff-79b8261aa73f07cc7450e48c8e14150576656f19ccfb42ba972860092c1f5949R1779-R1786

Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?

No, it should be the same type as query, it is used in the internal calculations

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@houseroad houseroad added the rocm Related to AMD ROCm label Apr 26, 2025
@gshtras gshtras requested a review from ProExpertProg April 28, 2025 15:08
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label May 1, 2025
ProExpertProg added a commit to neuralmagic/vllm that referenced this pull request May 1, 2025
commit 9f733ff
Author: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Date:   Fri Apr 25 22:10:58 2025 +0000

    Using static cast

    Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

commit 2d7dba5
Author: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Date:   Thu Apr 24 21:37:16 2025 +0000

    An option to apply fp8 output scale in ROCm custom paged attention and output FP8 tensor

    Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
@@ -1238,6 +1240,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(

// final write to tmp_out after vout accumulation
if (warpid == 0) {
const float out_scale =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering where out_scale is used here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually used in the reduction kernel launched after either of the attention kernels.
The dereferencing here is indeed not needed, but it'll get optimized out. I'll make a note to clean it up

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you just remove it in this PR?

gshtras added 2 commits May 1, 2025 20:17
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) May 5, 2025 17:07
@vllm-bot vllm-bot merged commit 32aa74c into vllm-project:main May 7, 2025
77 of 80 checks passed
princepride pushed a commit to princepride/vllm that referenced this pull request May 10, 2025
vllm-project#17139)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
vllm-project#17139)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
vllm-project#17139)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants