@@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
1287
1287
// max_num_partitions, head_size]
1288
1288
const int * __restrict__ context_lens, // [num_seqs]
1289
1289
const int * __restrict__ query_start_loc_ptr, // [num_seqs]
1290
- const int max_num_partitions) {
1290
+ const int max_num_partitions, const float * __restrict__ fp8_out_scale_ptr ) {
1291
1291
const auto num_heads = gridDim .x ;
1292
1292
const auto head_idx = blockIdx .x ;
1293
1293
const auto seq_idx = blockIdx .y ;
@@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
1465
1465
1466
1466
const float inv_global_exp_sum =
1467
1467
__fdividef (1 .0f , shared_global_exp_sum + 1e-6f );
1468
+ const float out_scale =
1469
+ (fp8_out_scale_ptr != nullptr ) ? 1 .0f / (*fp8_out_scale_ptr) : 1 .0f ;
1468
1470
acc *= inv_global_exp_sum;
1469
-
1471
+ acc *= out_scale;
1470
1472
const int64_t query_start_off = static_cast <int64_t >(
1471
1473
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
1472
1474
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
@@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
1548
1550
const scalar_t * __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
1549
1551
const int * __restrict__ context_lens, // [num_seqs]
1550
1552
const int * __restrict__ query_start_loc_ptr, // [num_seqs]
1551
- const int max_num_partitions) {
1553
+ const int max_num_partitions, const float * __restrict__ fp8_out_scale_ptr ) {
1552
1554
UNREACHABLE_CODE
1553
1555
}
1554
1556
// clang-format on
@@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
1582
1584
PARTITION_SIZE, NPAR_LOOPS> \
1583
1585
<<<reduce_grid, reduce_block, 0 , stream>>> ( \
1584
1586
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
1585
- context_lens_ptr, query_start_loc_ptr, max_num_partitions);
1587
+ context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
1588
+ fp8_out_scale_ptr);
1586
1589
1587
1590
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
1588
1591
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1594,7 +1597,7 @@ void paged_attention_custom_launcher(
1594
1597
torch::Tensor& block_tables, torch::Tensor& context_lens,
1595
1598
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
1596
1599
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
1597
- torch::Tensor& v_scale) {
1600
+ torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale ) {
1598
1601
int num_seqs = block_tables.size (0 );
1599
1602
int num_heads = query.size (1 );
1600
1603
int head_size = query.size (2 );
@@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher(
1626
1629
int * context_lens_ptr = context_lens.data_ptr <int >();
1627
1630
const float * k_scale_ptr = reinterpret_cast <const float *>(k_scale.data_ptr ());
1628
1631
const float * v_scale_ptr = reinterpret_cast <const float *>(v_scale.data_ptr ());
1632
+ // NOTE: fp8_out_scale is optional.
1633
+ const auto fp8_out_scale_ptr =
1634
+ fp8_out_scale
1635
+ ? static_cast <const float *>(fp8_out_scale.value ().data_ptr ())
1636
+ : nullptr ;
1629
1637
OUTT* out_ptr = reinterpret_cast <OUTT*>(out.data_ptr ());
1630
1638
1631
1639
const int max_ctx_blocks = DIVIDE_ROUND_UP (max_context_len, BLOCK_SIZE);
@@ -1736,33 +1744,54 @@ void paged_attention_custom_launcher(
1736
1744
}
1737
1745
}
1738
1746
1739
- #define CALL_CUSTOM_LAUNCHER (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \
1740
- ALIBI_ENABLED) \
1741
- paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
1742
- PSIZE, ALIBI_ENABLED>( \
1743
- out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
1744
- num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
1745
- max_context_len, alibi_slopes, k_scale, v_scale);
1746
-
1747
- #define CALL_CUSTOM_LAUNCHER_ALIBI (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
1748
- PSIZE) \
1749
- if (alibi_slopes) { \
1750
- CALL_CUSTOM_LAUNCHER (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true ); \
1751
- } else { \
1752
- CALL_CUSTOM_LAUNCHER (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false ); \
1747
+ #define CALL_CUSTOM_LAUNCHER (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
1748
+ PSIZE, ALIBI_ENABLED) \
1749
+ paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
1750
+ PSIZE, ALIBI_ENABLED>( \
1751
+ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
1752
+ num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
1753
+ max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
1754
+
1755
+ #define CALL_CUSTOM_LAUNCHER_ALIBI (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
1756
+ OUTT, PSIZE) \
1757
+ if (alibi_slopes) { \
1758
+ CALL_CUSTOM_LAUNCHER (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
1759
+ true ); \
1760
+ } else { \
1761
+ CALL_CUSTOM_LAUNCHER (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
1762
+ false ); \
1753
1763
}
1754
1764
1755
- #define CALL_CUSTOM_LAUNCHER_BLK (T, KVT, KV_DTYPE, HEAD_SIZE ) \
1756
- switch (block_size) { \
1757
- case 16 : \
1758
- CALL_CUSTOM_LAUNCHER_ALIBI (T, KVT, KV_DTYPE, 16 , HEAD_SIZE, 256 ); \
1759
- break ; \
1760
- case 32 : \
1761
- CALL_CUSTOM_LAUNCHER_ALIBI (T, KVT, KV_DTYPE, 32 , HEAD_SIZE, 256 ); \
1762
- break ; \
1763
- default : \
1764
- TORCH_CHECK (false , " Unsupported block size: " , block_size); \
1765
- break ; \
1765
+ #if defined(__HIPCC__) && defined(__gfx90a__)
1766
+ #define CALL_CUSTOM_LAUNCHER_OUT (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE ) \
1767
+ if (fp8_out_scale) { \
1768
+ TORCH_CHECK (false , " fp8 out scale unsupported for gfx90a" ); \
1769
+ } else { \
1770
+ CALL_CUSTOM_LAUNCHER_ALIBI (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
1771
+ 256 ); \
1772
+ }
1773
+ #else
1774
+ #define CALL_CUSTOM_LAUNCHER_OUT (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE ) \
1775
+ if (fp8_out_scale) { \
1776
+ CALL_CUSTOM_LAUNCHER_ALIBI (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
1777
+ uint8_t , 256 ); \
1778
+ } else { \
1779
+ CALL_CUSTOM_LAUNCHER_ALIBI (T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
1780
+ 256 ); \
1781
+ }
1782
+ #endif
1783
+
1784
+ #define CALL_CUSTOM_LAUNCHER_BLK (T, KVT, KV_DTYPE, HEAD_SIZE ) \
1785
+ switch (block_size) { \
1786
+ case 16 : \
1787
+ CALL_CUSTOM_LAUNCHER_OUT (T, KVT, KV_DTYPE, 16 , HEAD_SIZE); \
1788
+ break ; \
1789
+ case 32 : \
1790
+ CALL_CUSTOM_LAUNCHER_OUT (T, KVT, KV_DTYPE, 32 , HEAD_SIZE); \
1791
+ break ; \
1792
+ default : \
1793
+ TORCH_CHECK (false , " Unsupported block size: " , block_size); \
1794
+ break ; \
1766
1795
}
1767
1796
1768
1797
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD (T, KVT, KV_DTYPE ) \
@@ -1795,7 +1824,8 @@ void paged_attention(
1795
1824
int64_t block_size, int64_t max_context_len,
1796
1825
const std::optional<torch::Tensor>& alibi_slopes,
1797
1826
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
1798
- torch::Tensor& v_scale) {
1827
+ torch::Tensor& v_scale,
1828
+ const c10::optional<torch::Tensor>& fp8_out_scale) {
1799
1829
// clang-format on
1800
1830
const int head_size = query.size (2 );
1801
1831
if (kv_cache_dtype == " auto" ) {
0 commit comments