[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention #17139


Merged · 5 commits · May 7, 2025
92 changes: 61 additions & 31 deletions csrc/rocm/attention.cu
@@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
@@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
const float out_scale =
(fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
acc *= inv_global_exp_sum;

acc *= out_scale;
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
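For reference, a minimal host-side sketch of the epilogue introduced in this hunk: the softmax normalization and the optional FP8 output scale fold into two multiplies on the accumulator before the final store. The `epilogue` helper name and the `float` return are inventions of this sketch; in the kernel the value uses `__fdividef` and is converted to `OUTT` when written through the output pointer.

```cpp
// Illustrative only: mirrors the scaling order in the reduce kernel.
// The final cast to the fp8 output type is elided here.
float epilogue(float acc, float shared_global_exp_sum,
               const float* fp8_out_scale_ptr) {
  const float inv_global_exp_sum = 1.0f / (shared_global_exp_sum + 1e-6f);
  // The kernel stores value * (1/scale); the consumer of the fp8 output
  // multiplies by the same scale to dequantize back to the real range.
  const float out_scale =
      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
  return acc * inv_global_exp_sum * out_scale;
}
```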
@@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
}
// clang-format on
@@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions);
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);

template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1594,7 +1597,7 @@ void paged_attention_custom_launcher(
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher(
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
const auto fp8_out_scale_ptr =
fp8_out_scale
? static_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
Contributor:
Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere

Contributor:

Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?

Collaborator (author):

> Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere

This is ensured at https://github.com/vllm-project/vllm/pull/17139/files#diff-79b8261aa73f07cc7450e48c8e14150576656f19ccfb42ba972860092c1f5949R1779-R1786

> Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?

No, it should be the same type as the query; it is used in the internal calculations.
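To make the type relationships in this thread concrete, here is a hedged sketch of the dispatch that the macros below implement. The names `launch_paged_attention` and `dispatch_output_type` are invented for illustration (the real launcher carries many more template parameters): the output type flips to fp8 byte storage only when the scale is supplied, while `tmp_out` stays in the query type so the partition reduction runs in full precision.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for paged_attention_custom_launcher; only the
// T/OUTT template pair matters for this sketch.
template <typename T, typename OUTT>
void launch_paged_attention(const float* fp8_out_scale_ptr) { /* ... */ }

template <typename T>
void dispatch_output_type(const std::optional<float>& fp8_out_scale) {
  if (fp8_out_scale) {
    // FP8 path: outputs are stored as raw fp8 bytes (uint8_t) plus an
    // out-of-band scale; tmp_out remains T inside the launcher.
    launch_paged_attention<T, uint8_t>(&fp8_out_scale.value());
  } else {
    launch_paged_attention<T, T>(nullptr);
  }
}
```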


const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
@@ -1736,33 +1744,54 @@ void paged_attention_custom_launcher(
}
}

#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \
ALIBI_ENABLED) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale);

#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
if (alibi_slopes) { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \
} else { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);

#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE) \
if (alibi_slopes) { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
true); \
} else { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
false); \
}

#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
}
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
uint8_t, 256); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
}
#endif

#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}

#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
@@ -1795,7 +1824,8 @@ void paged_attention(
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
torch::Tensor& v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale) {
// clang-format on
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
20 changes: 9 additions & 11 deletions csrc/rocm/ops.h
@@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);

void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);
void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale);
3 changes: 2 additions & 1 deletion csrc/rocm/torch_bindings.cpp
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
" Tensor k_scale, Tensor v_scale,"
" Tensor? fp8_out_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

3 changes: 2 additions & 1 deletion vllm/_custom_ops.py
@@ -117,13 +117,14 @@ def paged_attention_rocm(
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale)
v_scale, fp8_out_scale)


def mla_decode_kvcache_cpu(