
Commit b525b34

gshtras authored and mawong-amd committed
[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention (vllm-project#17139)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent 08859a7 commit b525b34

File tree: 4 files changed, +74 −44 lines


csrc/rocm/attention.cu (61 additions, 31 deletions)

@@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     // max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
@@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  const float out_scale =
+      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
-
+  acc *= out_scale;
   const int64_t query_start_off = static_cast<int64_t>(
       query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
@@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 // clang-format on
@@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       PARTITION_SIZE, NPAR_LOOPS>                                             \
       <<<reduce_grid, reduce_block, 0, stream>>>(                             \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                 \
-          context_lens_ptr, query_start_loc_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions,          \
+          fp8_out_scale_ptr);
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1594,7 +1597,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
   int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+  // NOTE: fp8_out_scale is optional.
+  const auto fp8_out_scale_ptr =
+      fp8_out_scale
+          ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
+          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
 
   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
@@ -1736,33 +1744,54 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE,     \
-                             ALIBI_ENABLED)                                    \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,    \
-                                  PSIZE, ALIBI_ENABLED>(                       \
-      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
-      num_kv_heads, scale, block_tables, context_lens, query_start_loc,        \
-      max_context_len, alibi_slopes, k_scale, v_scale);
-
-#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
-                                   PSIZE)                                      \
-  if (alibi_slopes) {                                                          \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true);  \
-  } else {                                                                     \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT,      \
+                             PSIZE, ALIBI_ENABLED)                             \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                                  PSIZE, ALIBI_ENABLED>(                       \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
+      num_kv_heads, scale, block_tables, context_lens, query_start_loc,        \
+      max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
+
+#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
+                                   OUTT, PSIZE)                                \
+  if (alibi_slopes) {                                                          \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
+                         true);                                                \
+  } else {                                                                     \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
+                         false);                                               \
   }
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                  \
-  switch (block_size) {                                                        \
-    case 16:                                                                   \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256);        \
-      break;                                                                   \
-    case 32:                                                                   \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256);        \
-      break;                                                                   \
-    default:                                                                   \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
-      break;                                                                   \
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)      \
+    if (fp8_out_scale) {                                                       \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");              \
+    } else {                                                                   \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,     \
+                                 256);                                         \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)      \
+    if (fp8_out_scale) {                                                       \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,        \
+                                 uint8_t, 256);                                \
+    } else {                                                                   \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,     \
+                                 256);                                         \
+    }
+#endif
+
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                  \
+  switch (block_size) {                                                        \
+    case 16:                                                                   \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);               \
+      break;                                                                   \
+    case 32:                                                                   \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);               \
+      break;                                                                   \
+    default:                                                                   \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
+      break;                                                                   \
   }
 
 #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE)                        \
@@ -1795,7 +1824,8 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale) {
   // clang-format on
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
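The heart of the change is in the reduce kernel above: when an fp8_out_scale is supplied, its reciprocal is folded into the same per-token normalization that already divides by the global softmax sum, so quantizing the attention output to FP8 adds no extra pass over the data (and the path is rejected on gfx90a, which has no FP8 support). As a rough reference for that math only, not the vLLM code path, here is a minimal PyTorch sketch; the tensor shapes and the torch.float8_e4m3fnuz output dtype are assumptions for illustration:

    from typing import Optional

    import torch


    def reduce_and_quantize(acc: torch.Tensor,             # [num_tokens, head_size]
                            global_exp_sum: torch.Tensor,  # [num_tokens, 1]
                            fp8_out_scale: Optional[torch.Tensor] = None):
        """Reference of the fused reduce step: normalize by the global softmax
        sum, optionally fold in 1/fp8_out_scale, then cast to an FP8 dtype."""
        inv_global_exp_sum = 1.0 / (global_exp_sum + 1e-6)
        out_scale = 1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0
        acc = acc * inv_global_exp_sum * out_scale
        if fp8_out_scale is not None:
            # The kernel writes uint8-typed FP8; float8_e4m3fnuz is an assumed stand-in.
            return acc.to(torch.float8_e4m3fnuz)
        return acc

A consumer would recover the unscaled value as out.float() * fp8_out_scale.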

csrc/rocm/ops.h (9 additions, 11 deletions)

@@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
 
-void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
-                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                     torch::Tensor& query, torch::Tensor& key_cache,
-                     torch::Tensor& value_cache, int64_t num_kv_heads,
-                     double scale, torch::Tensor& block_tables,
-                     torch::Tensor& context_lens,
-                     const std::optional<torch::Tensor>& query_start_loc,
-                     int64_t block_size, int64_t max_context_len,
-                     const std::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-                     torch::Tensor& v_scale);
+void paged_attention(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
+    int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale);

csrc/rocm/torch_bindings.cpp (2 additions, 1 deletion)

@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       " int max_context_len,"
       " Tensor? alibi_slopes,"
      " str kv_cache_dtype,"
-      " Tensor k_scale, Tensor v_scale) -> ()");
+      " Tensor k_scale, Tensor v_scale,"
+      " Tensor? fp8_out_scale) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }

vllm/_custom_ops.py (2 additions, 1 deletion)

@@ -117,13 +117,14 @@ def paged_attention_rocm(
     kv_cache_dtype: str,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
+    fp8_out_scale: Optional[torch.Tensor] = None,
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       query_start_loc, block_size, max_seq_len,
                                       alibi_slopes, kv_cache_dtype, k_scale,
-                                      v_scale)
+                                      v_scale, fp8_out_scale)
 
 
 def mla_decode_kvcache_cpu(
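On the Python side the new argument is simply threaded through to the registered op and defaults to None, so existing callers are unaffected. Below is a hedged sketch of how a caller might construct such a per-tensor output scale; make_fp8_out_scale, the calibrated absmax, and the float8_e4m3fnuz dtype choice are illustrative assumptions, not part of vLLM:

    from typing import Optional

    import torch


    def make_fp8_out_scale(expected_absmax: float,
                           fp8_dtype: torch.dtype = torch.float8_e4m3fnuz) -> torch.Tensor:
        """Hypothetical helper: map a calibrated absolute maximum of the attention
        output onto the representable range of the target FP8 format."""
        fp8_max = torch.finfo(fp8_dtype).max
        return torch.tensor(expected_absmax / fp8_max, dtype=torch.float32)


    fp8_out_scale: Optional[torch.Tensor] = make_fp8_out_scale(expected_absmax=4.0)
    # ops.paged_attention_rocm(..., k_scale, v_scale, fp8_out_scale=fp8_out_scale)
    # would then produce FP8 output that dequantizes as out.float() * fp8_out_scale.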
