From 985034c459883218db2cc4470f504662ba943935 Mon Sep 17 00:00:00 2001 From: kaixih Date: Tue, 1 Apr 2025 18:29:02 +0000 Subject: [PATCH 1/5] Support cutlass MLA and Upgrade cutlass to 3.9 Signed-off-by: kaixih --- CMakeLists.txt | 28 ++- csrc/attention/mla/cutlass_mla_entry.cu | 37 +++ csrc/attention/mla/cutlass_mla_kernels.cu | 220 ++++++++++++++++++ csrc/ops.h | 6 + .../fp4/nvfp4_scaled_mm_kernels.cu | 2 +- csrc/torch_bindings.cpp | 7 + tests/kernels/test_cutlass_mla_decode.py | 89 +++++++ vllm/_custom_ops.py | 15 ++ 8 files changed, 399 insertions(+), 5 deletions(-) create mode 100644 csrc/attention/mla/cutlass_mla_entry.cu create mode 100644 csrc/attention/mla/cutlass_mla_kernels.cu create mode 100644 tests/kernels/test_cutlass_mla_decode.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 21464a0560d..3314f05fd2a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -251,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. # Please keep this in sync with FetchContent_Declare line below. - set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -269,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG v3.8.0 + GIT_TAG v3.9.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -290,7 +290,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" - "csrc/cutlass_extensions/common.cpp") + "csrc/cutlass_extensions/common.cpp" + "csrc/attention/mla/cutlass_mla_entry.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -463,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(FP4_ARCHS) endif() - # + # CUTLASS MLA Archs and flags + cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS) + set(SRCS + "csrc/attention/mla/cutlass_mla_kernels.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${MLA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") + # Add MLA-specific include directories only to MLA source files + set_source_files_properties(${SRCS} + PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") + message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") + else() + message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") + # clear MLA_ARCHS + set(MLA_ARCHS) + endif() + # CUTLASS MoE kernels # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu new file mode 100644 index 00000000000..6a8b3df250a --- /dev/null +++ b/csrc/attention/mla/cutlass_mla_entry.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA +void cutlass_mla_decode_sm100a(torch::Tensor const& out, + torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table); +#endif + +void cutlass_mla_decode(torch::Tensor const& out, + torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table) { +#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA + return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); +} diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu new file mode 100644 index 00000000000..f33f8ec9098 --- /dev/null +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass_extensions/common.hpp" + +#include "device/sm100_mla.hpp" +#include "kernel/sm100_mla_tile_scheduler.hpp" + +using namespace cute; +using namespace cutlass::fmha::kernel; + +template +struct IsPersistent { + static const bool value = v; +}; + +template > +struct MlaSm100 { + using Element = T; + using ElementAcc = float; + using ElementOut = T; + + using TileShape = Shape<_128, _128, Shape<_512, _64>>; + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple; + + using StrideQ = cute::tuple; // H D B + using StrideK = cute::tuple; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = std::conditional_t; + + using FmhaKernel = + cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, + /*kIsCpAsync=*/true>; + using Fmha = cutlass::fmha::device::MLA; +}; + +template +typename T::Fmha::Arguments args_from_options( + at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, + at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, + at::Tensor const& page_table) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = q_nope_and_q_pe.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + int batches = q_nope_and_q_pe.sizes()[0]; + int page_count_per_seq = page_table.sizes()[1]; + int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; + int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int max_seq_len = page_size * page_count_per_seq; + using TileShapeH = typename T::TileShapeH; + using TileShapeD = typename T::TileShapeD; + auto problem_shape = + cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + // the scale is based on the non-absorbed sizes, change as appropriate + // we can't determine this parameter from the info we have, it's an input + int D_non_latent = 128; + float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope)); + + using StrideQ = typename T::StrideQ; + using StrideK = typename T::StrideK; + using StrideO = typename T::StrideO; + using StrideLSE = typename T::StrideLSE; + + StrideQ stride_Q = + cute::make_tuple(static_cast(D_latent + D_rope), _1{}, + static_cast(H * (D_latent + D_rope))); + StrideK stride_C = + cute::make_tuple(static_cast(D_latent + D_rope), _1{}, + static_cast(page_size * (D_latent + D_rope))); + StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); + StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast(H)); + StrideO stride_O = cute::make_tuple(static_cast(D_latent), _1{}, + static_cast(H * D_latent)); + + using Element = typename T::Element; + using ElementOut = typename T::ElementOut; + using ElementAcc = typename T::ElementAcc; + auto Q_ptr = static_cast(q_nope_and_q_pe.data_ptr()); + auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); + typename T::Fmha::Arguments arguments{ + problem_shape, + {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C, + C_ptr + D_latent, stride_C, static_cast(seq_lens.data_ptr()), + static_cast(page_table.data_ptr()), stride_PT, page_count_total, + page_size}, + 
{static_cast(out.data_ptr()), stride_O, + static_cast(nullptr), stride_LSE}, + hw_info, + -1, // split_kv + nullptr, // is_var_split_kv + }; + // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute + // split_kv automatically based on batch size and sequence length to balance + // workload across available SMs. Consider using var_split_kv for manual + // control if needed. + T::Fmha::set_split_kv(arguments); + return arguments; +} + +template +void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, + at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, + at::Tensor const& page_table, cudaStream_t stream) { + using MlaSm100Type = MlaSm100; + typename MlaSm100Type::Fmha fmha; + auto arguments = args_from_options( + out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table); + size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); + auto const workspace_options = torch::TensorOptions() + .dtype(torch::kUInt8) + .device(q_nope_and_q_pe.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(fmha.can_implement(arguments)); + + CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); +} + +void cutlass_mla_decode_sm100a(torch::Tensor const& out, + torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table) { + TORCH_CHECK(q_nope_and_q_pe.device().is_cuda(), + "q_nope_and_q_pe must be on CUDA"); + TORCH_CHECK(q_nope_and_q_pe.dim() == 3, + "q_nope_and_q_pe must be a 3D tensor"); + TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3, + "kv_c_and_k_pe_cache must be a 3D tensor"); + TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); + TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); + + auto B_q = q_nope_and_q_pe.size(0); + auto H = q_nope_and_q_pe.size(1); + auto D_q = q_nope_and_q_pe.size(2); + auto B_pt = page_table.size(0); + auto PAGE_NUM = page_table.size(1); + auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); + auto D_ckv = kv_c_and_k_pe_cache.size(2); + + TORCH_CHECK(D_q == D_ckv && D_q == 576, + "D_q must be equal to D_ckv and D_q must be equal to 576"); + TORCH_CHECK(H == 128, "H must be 128"); + TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, + "PAGE_SIZE must be a power of 2"); + TORCH_CHECK(B_q == B_pt, + "Batch dims must be same for page_table and q_nope_and_q_pe"); + TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, + "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); + + TORCH_CHECK( + q_nope_and_q_pe.dtype() == at::ScalarType::Half || + q_nope_and_q_pe.dtype() == at::ScalarType::BFloat16 || + q_nope_and_q_pe.dtype() == at::ScalarType::Float8_e4m3fn, + "q_nope_and_q_pe must be a half, bfloat16, or float8_e4m3fn tensor"); + TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope_and_q_pe.dtype(), + "kv_c_and_k_pe_cache must be the same type as q_nope_and_q_pe"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, + "seq_lens must be a 32-bit integer tensor"); + TORCH_CHECK(page_table.dtype() == torch::kInt32, + "page_table must be a 32-bit integer tensor"); + + auto in_dtype = q_nope_and_q_pe.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device()); + if (in_dtype == at::ScalarType::Half) { + runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, + 
page_table, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, stream); + } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, stream); + } else { + TORCH_CHECK(false, "Unsupported input data type of MLA"); + } +} diff --git a/csrc/ops.h b/csrc/ops.h index 86039a26041..fda2fbb967e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,6 +128,12 @@ void advance_step_flashinfer( torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); +void cutlass_mla_decode(torch::Tensor const& out, + torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table); + torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); #ifndef USE_ROCM diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 6e14de0c780..97c0e0da7b1 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options( using StrideB = typename T::StrideB; using StrideD = typename T::StrideD; using Sm100BlkScaledConfig = - typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; int m = static_cast(M); int n = static_cast(N); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b6ff6a006c0..ccfb18cd576 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -130,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ") -> ()"); ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); + // Compute MLA decode using cutlass. + ops.def( + "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe," + " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," + " Tensor page_table) -> ()"); + ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); + // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. 
ops.def( diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py new file mode 100644 index 00000000000..ce0d5c05cf0 --- /dev/null +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch +import torch.nn.functional as F +from torch import Tensor + +import vllm._custom_ops as ops +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Cutlass MLA Requires compute capability of 10 or above.", + allow_module_level=True) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[ + block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, + head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, + kv, + v, + scale=scale, + enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize("block_size", [16, 64, 128]) +def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, + varlen: bool, block_size: int): + torch.set_default_dtype(dtype) + torch.set_default_device('cuda') + torch.manual_seed(42) + + d = 576 + h_q = 128 + dv = 512 + + q_nope_dim = 128 + q_pe_dim = 64 + scale = (q_nope_dim + q_pe_dim)**(-0.5) + if varlen: + seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) + seq_lens = seq_lens.clip(2).to(torch.int32) + else: + seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) + max_seq_len = seq_lens.max().item() + block_num = (max_seq_len + block_size - 1) // block_size + + # Pad block_num so that small blocks can be packed into full 128-sized + # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small + # blocks. 
+ pack_factor = 128 // block_size + block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor + + q = torch.randn(bs, h_q, d) + block_table = torch.randint(0, + bs * block_num, (bs, block_num), + dtype=torch.int32) + + kv_cache = torch.randn(block_table.numel(), block_size, d) + + out_ref = q.new_zeros(bs, h_q, dv) + ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) + out = ops.cutlass_mla_decode(q, kv_cache, seq_lens, block_table) + + torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 11297d3b9f5..041ed2049f0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1525,3 +1525,18 @@ def flash_mla_with_kvcache( num_splits, ) return out, softmax_lse + + +def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor) -> torch.Tensor: + B_q, H, _ = q_nope_and_q_pe.shape + + out = torch.empty((B_q, H, 512), + device=q_nope_and_q_pe.device, + dtype=q_nope_and_q_pe.dtype) + + torch.ops._C.cutlass_mla_decode(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table) + return out From 223e9345f91d86549577fecf349883f24252fecc Mon Sep 17 00:00:00 2001 From: kaixih Date: Sat, 26 Apr 2025 09:10:13 +0000 Subject: [PATCH 2/5] Address comments Signed-off-by: kaixih --- csrc/attention/mla/cutlass_mla_entry.cu | 8 ++-- csrc/attention/mla/cutlass_mla_kernels.cu | 46 +++++++++++------------ csrc/ops.h | 3 +- csrc/torch_bindings.cpp | 2 +- tests/kernels/test_cutlass_mla_decode.py | 5 ++- vllm/_custom_ops.py | 12 +++--- 6 files changed, 38 insertions(+), 38 deletions(-) diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu index 6a8b3df250a..eae253ea40a 100644 --- a/csrc/attention/mla/cutlass_mla_entry.cu +++ b/csrc/attention/mla/cutlass_mla_entry.cu @@ -21,17 +21,19 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table); + torch::Tensor const& page_table, + double scale); #endif void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table) { + torch::Tensor const& page_table, + double scale) { #if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table); + seq_lens, page_table, scale); #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); } diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index f33f8ec9098..a3009451a8a 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -32,12 +32,7 @@ using namespace cute; using namespace cutlass::fmha::kernel; -template -struct IsPersistent { - static const bool value = v; -}; - -template > +template struct MlaSm100 { using Element = T; using ElementAcc = float; @@ -55,7 +50,7 @@ struct MlaSm100 { using StrideO = StrideK; // H D B using StrideLSE = cute::tuple<_1, int>; // H B - using TileScheduler = std::conditional_t; @@ -70,7 +65,7 @@ template typename T::Fmha::Arguments args_from_options( at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, - at::Tensor const& page_table) { + 
at::Tensor const& page_table, double scale) { cutlass::KernelHardwareInfo hw_info; hw_info.device_id = q_nope_and_q_pe.device().index(); hw_info.sm_count = @@ -90,11 +85,6 @@ typename T::Fmha::Arguments args_from_options( auto [H, K, D, B] = problem_shape; auto [D_latent, D_rope] = D; - // the scale is based on the non-absorbed sizes, change as appropriate - // we can't determine this parameter from the info we have, it's an input - int D_non_latent = 128; - float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope)); - using StrideQ = typename T::StrideQ; using StrideK = typename T::StrideK; using StrideO = typename T::StrideO; @@ -116,9 +106,10 @@ typename T::Fmha::Arguments args_from_options( using ElementAcc = typename T::ElementAcc; auto Q_ptr = static_cast(q_nope_and_q_pe.data_ptr()); auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); + auto scale_f = static_cast(scale); typename T::Fmha::Arguments arguments{ problem_shape, - {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C, + {scale_f, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C, C_ptr + D_latent, stride_C, static_cast(seq_lens.data_ptr()), static_cast(page_table.data_ptr()), stride_PT, page_count_total, page_size}, @@ -139,11 +130,11 @@ typename T::Fmha::Arguments args_from_options( template void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, - at::Tensor const& page_table, cudaStream_t stream) { + at::Tensor const& page_table, float scale, cudaStream_t stream) { using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; auto arguments = args_from_options( - out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table); + out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); auto const workspace_options = torch::TensorOptions() .dtype(torch::kUInt8) @@ -161,7 +152,8 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table) { + torch::Tensor const& page_table, + double scale) { TORCH_CHECK(q_nope_and_q_pe.device().is_cuda(), "q_nope_and_q_pe must be on CUDA"); TORCH_CHECK(q_nope_and_q_pe.dim() == 3, @@ -170,24 +162,30 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, "kv_c_and_k_pe_cache must be a 3D tensor"); TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); + TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); auto B_q = q_nope_and_q_pe.size(0); - auto H = q_nope_and_q_pe.size(1); + auto H_q = q_nope_and_q_pe.size(1); auto D_q = q_nope_and_q_pe.size(2); auto B_pt = page_table.size(0); auto PAGE_NUM = page_table.size(1); auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); auto D_ckv = kv_c_and_k_pe_cache.size(2); + auto B_o = out.size(0); + auto H_o = out.size(1); + auto D_o = out.size(2); TORCH_CHECK(D_q == D_ckv && D_q == 576, "D_q must be equal to D_ckv and D_q must be equal to 576"); - TORCH_CHECK(H == 128, "H must be 128"); + TORCH_CHECK(H_q == H_o && H_q == 128, + "H_q must be equal to H_o and H_q must be 128"); TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, "PAGE_SIZE must be a power of 2"); - TORCH_CHECK(B_q == B_pt, - "Batch dims must be same for page_table and q_nope_and_q_pe"); + TORCH_CHECK(B_q == B_pt && B_q == B_o, + "Batch dims must be 
same for page_table, q_nope_and_q_pe, and out"); TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); + TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); TORCH_CHECK( q_nope_and_q_pe.dtype() == at::ScalarType::Half || @@ -207,13 +205,13 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device()); if (in_dtype == at::ScalarType::Half) { runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, - page_table, stream); + page_table, scale, stream); } else if (in_dtype == at::ScalarType::BFloat16) { runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, stream); + seq_lens, page_table, scale, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, stream); + seq_lens, page_table, scale, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } diff --git a/csrc/ops.h b/csrc/ops.h index fda2fbb967e..10514f29587 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -132,7 +132,8 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table); + torch::Tensor const& page_table, + double scale); torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ccfb18cd576..2e77fbad78c 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -134,7 +134,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe," " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table) -> ()"); + " Tensor page_table, float scale) -> ()"); ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); // Layernorm diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py index ce0d5c05cf0..65b44506cd7 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -84,6 +84,7 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, out_ref = q.new_zeros(bs, h_q, dv) ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) - out = ops.cutlass_mla_decode(q, kv_cache, seq_lens, block_table) + out_ans = torch.zeros_like(out_ref) + ops.cutlass_mla_decode(out_ans, q, kv_cache, seq_lens, block_table, scale) - torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 041ed2049f0..0120abec5d2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1527,16 +1527,14 @@ def flash_mla_with_kvcache( return out, softmax_lse -def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor, +def cutlass_mla_decode(out: torch.Tensor, + q_nope_and_q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, - page_table: torch.Tensor) -> torch.Tensor: + page_table: torch.Tensor, + scale: float) -> torch.Tensor: B_q, H, _ = q_nope_and_q_pe.shape - out = torch.empty((B_q, H, 512), - device=q_nope_and_q_pe.device, - dtype=q_nope_and_q_pe.dtype) - torch.ops._C.cutlass_mla_decode(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table) + seq_lens, page_table, scale) return out From d094f200cf6fe7646d44158c39c18e3bf1d4a50f Mon Sep 17 00:00:00 2001 From: kaixih Date: 
Sat, 26 Apr 2025 09:21:03 +0000 Subject: [PATCH 3/5] Address format Signed-off-by: kaixih --- csrc/attention/mla/cutlass_mla_entry.cu | 6 ++---- csrc/attention/mla/cutlass_mla_kernels.cu | 14 +++++++------- csrc/ops.h | 3 +-- vllm/_custom_ops.py | 6 ++---- 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu index eae253ea40a..14835f2947a 100644 --- a/csrc/attention/mla/cutlass_mla_entry.cu +++ b/csrc/attention/mla/cutlass_mla_entry.cu @@ -21,16 +21,14 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - double scale); + torch::Tensor const& page_table, double scale); #endif void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - double scale) { + torch::Tensor const& page_table, double scale) { #if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index a3009451a8a..d3fbce26624 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -50,9 +50,9 @@ struct MlaSm100 { using StrideO = StrideK; // H D B using StrideLSE = cute::tuple<_1, int>; // H B - using TileScheduler = std::conditional_t; + using TileScheduler = + std::conditional_t; using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< @@ -152,8 +152,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - double scale) { + torch::Tensor const& page_table, double scale) { TORCH_CHECK(q_nope_and_q_pe.device().is_cuda(), "q_nope_and_q_pe must be on CUDA"); TORCH_CHECK(q_nope_and_q_pe.dim() == 3, @@ -181,8 +180,9 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, "H_q must be equal to H_o and H_q must be 128"); TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, "PAGE_SIZE must be a power of 2"); - TORCH_CHECK(B_q == B_pt && B_q == B_o, - "Batch dims must be same for page_table, q_nope_and_q_pe, and out"); + TORCH_CHECK( + B_q == B_pt && B_q == B_o, + "Batch dims must be same for page_table, q_nope_and_q_pe, and out"); TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); diff --git a/csrc/ops.h b/csrc/ops.h index 10514f29587..1578150a1e9 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -132,8 +132,7 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - double scale); + torch::Tensor const& page_table, double scale); torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0120abec5d2..501006d90f1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1527,11 +1527,9 @@ def flash_mla_with_kvcache( return out, softmax_lse -def cutlass_mla_decode(out: torch.Tensor, - q_nope_and_q_pe: torch.Tensor, 
+def cutlass_mla_decode(out: torch.Tensor, q_nope_and_q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - seq_lens: torch.Tensor, - page_table: torch.Tensor, + seq_lens: torch.Tensor, page_table: torch.Tensor, scale: float) -> torch.Tensor: B_q, H, _ = q_nope_and_q_pe.shape From 00a0c6a72b06a5859a955fe6cfce3552ac31b4cb Mon Sep 17 00:00:00 2001 From: kaixih Date: Sat, 26 Apr 2025 23:02:21 +0000 Subject: [PATCH 4/5] Support separate q_nope and q_pe Signed-off-by: kaixih --- csrc/attention/mla/cutlass_mla_entry.cu | 8 +- csrc/attention/mla/cutlass_mla_kernels.cu | 94 +++++++++++++---------- csrc/ops.h | 3 +- csrc/torch_bindings.cpp | 2 +- tests/kernels/test_cutlass_mla_decode.py | 5 +- vllm/_custom_ops.py | 7 +- 6 files changed, 67 insertions(+), 52 deletions(-) diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu index 14835f2947a..758fbe5653f 100644 --- a/csrc/attention/mla/cutlass_mla_entry.cu +++ b/csrc/attention/mla/cutlass_mla_entry.cu @@ -18,19 +18,21 @@ #if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, double scale); #endif void cutlass_mla_decode(torch::Tensor const& out, - torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, double scale) { #if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA - return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, + return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index d3fbce26624..6e6f4ec0a83 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -63,16 +63,16 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( - at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, + at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, at::Tensor const& page_table, double scale) { cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = q_nope_and_q_pe.device().index(); + hw_info.device_id = q_nope.device().index(); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count( hw_info.device_id); - int batches = q_nope_and_q_pe.sizes()[0]; + int batches = q_nope.sizes()[0]; int page_count_per_seq = page_table.sizes()[1]; int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; int page_size = kv_c_and_k_pe_cache.sizes()[1]; @@ -90,9 +90,11 @@ typename T::Fmha::Arguments args_from_options( using StrideO = typename T::StrideO; using StrideLSE = typename T::StrideLSE; - StrideQ stride_Q = - cute::make_tuple(static_cast(D_latent + D_rope), _1{}, - static_cast(H * (D_latent + D_rope))); + StrideQ stride_Q_latent = + cute::make_tuple(static_cast(D_latent), _1{}, + static_cast(H * D_latent)); + StrideQ stride_Q_rope = cute::make_tuple(static_cast(D_rope), _1{}, + static_cast(H * D_rope)); StrideK stride_C = cute::make_tuple(static_cast(D_latent + D_rope), _1{}, static_cast(page_size * 
(D_latent + D_rope))); @@ -104,13 +106,15 @@ typename T::Fmha::Arguments args_from_options( using Element = typename T::Element; using ElementOut = typename T::ElementOut; using ElementAcc = typename T::ElementAcc; - auto Q_ptr = static_cast(q_nope_and_q_pe.data_ptr()); + auto Q_latent_ptr = static_cast(q_nope.data_ptr()); + auto Q_rope_ptr = static_cast(q_pe.data_ptr()); auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); auto scale_f = static_cast(scale); typename T::Fmha::Arguments arguments{ problem_shape, - {scale_f, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C, - C_ptr + D_latent, stride_C, static_cast(seq_lens.data_ptr()), + {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, + C_ptr, stride_C, C_ptr + D_latent, stride_C, + static_cast(seq_lens.data_ptr()), static_cast(page_table.data_ptr()), stride_PT, page_count_total, page_size}, {static_cast(out.data_ptr()), stride_O, @@ -128,17 +132,18 @@ typename T::Fmha::Arguments args_from_options( } template -void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, - at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, - at::Tensor const& page_table, float scale, cudaStream_t stream) { +void runMla(at::Tensor const& out, at::Tensor const& q_nope, + at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, at::Tensor const& page_table, + float scale, cudaStream_t stream) { using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; auto arguments = args_from_options( - out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); auto const workspace_options = torch::TensorOptions() .dtype(torch::kUInt8) - .device(q_nope_and_q_pe.device()); + .device(q_nope.device()); auto workspace = torch::empty(workspace_size, workspace_options); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -149,23 +154,26 @@ void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe, } void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, double scale) { - TORCH_CHECK(q_nope_and_q_pe.device().is_cuda(), - "q_nope_and_q_pe must be on CUDA"); - TORCH_CHECK(q_nope_and_q_pe.dim() == 3, - "q_nope_and_q_pe must be a 3D tensor"); + TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA"); + TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor"); + TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor"); TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3, "kv_c_and_k_pe_cache must be a 3D tensor"); TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); - auto B_q = q_nope_and_q_pe.size(0); - auto H_q = q_nope_and_q_pe.size(1); - auto D_q = q_nope_and_q_pe.size(2); + auto B_q_nope = q_nope.size(0); + auto H_q_nope = q_nope.size(1); + auto D_q_nope = q_nope.size(2); + auto B_q_pe = q_pe.size(0); + auto H_q_pe = q_pe.size(1); + auto D_q_pe = q_pe.size(2); auto B_pt = page_table.size(0); auto PAGE_NUM = page_table.size(1); auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); @@ -174,44 +182,46 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, auto H_o = 
out.size(1); auto D_o = out.size(2); - TORCH_CHECK(D_q == D_ckv && D_q == 576, - "D_q must be equal to D_ckv and D_q must be equal to 576"); - TORCH_CHECK(H_q == H_o && H_q == 128, - "H_q must be equal to H_o and H_q must be 128"); + TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512"); + TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64"); + TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576"); + TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128, + "H_q_nope, H_q_pe, and H_o must be equal to 128"); TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, "PAGE_SIZE must be a power of 2"); TORCH_CHECK( - B_q == B_pt && B_q == B_o, - "Batch dims must be same for page_table, q_nope_and_q_pe, and out"); + B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o, + "Batch dims must be same for page_table, q_nope and q_pe, and out"); TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); TORCH_CHECK( - q_nope_and_q_pe.dtype() == at::ScalarType::Half || - q_nope_and_q_pe.dtype() == at::ScalarType::BFloat16 || - q_nope_and_q_pe.dtype() == at::ScalarType::Float8_e4m3fn, - "q_nope_and_q_pe must be a half, bfloat16, or float8_e4m3fn tensor"); - TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope_and_q_pe.dtype(), - "kv_c_and_k_pe_cache must be the same type as q_nope_and_q_pe"); + q_nope.dtype() == at::ScalarType::Half || + q_nope.dtype() == at::ScalarType::BFloat16 || + q_nope.dtype() == at::ScalarType::Float8_e4m3fn, + "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); + TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() && + q_nope.dtype() == q_pe.dtype(), + "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"); TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be a 32-bit integer tensor"); TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must be a 32-bit integer tensor"); - auto in_dtype = q_nope_and_q_pe.dtype(); - at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()}; + auto in_dtype = q_nope.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device()); + at::cuda::getCurrentCUDAStream(q_nope.get_device()); if (in_dtype == at::ScalarType::Half) { - runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale, stream); } else if (in_dtype == at::ScalarType::BFloat16) { - runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, + page_table, scale, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, + page_table, scale, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } diff --git a/csrc/ops.h b/csrc/ops.h index 1578150a1e9..fa8944e434e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -129,7 +129,8 @@ void advance_step_flashinfer( torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); void cutlass_mla_decode(torch::Tensor const& out, - torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, 
double scale); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 2e77fbad78c..c9a120976b1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -132,7 +132,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute MLA decode using cutlass. ops.def( - "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe," + "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," " Tensor page_table, float scale) -> ()"); ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py index 65b44506cd7..87e4bd4b096 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -85,6 +85,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, out_ref = q.new_zeros(bs, h_q, dv) ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) out_ans = torch.zeros_like(out_ref) - ops.cutlass_mla_decode(out_ans, q, kv_cache, seq_lens, block_table, scale) + q_nope = q[:, :, :dv].clone() + q_pe = q[:, :, dv:].clone() + ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, + block_table, scale) torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 501006d90f1..ec9e81f0295 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1527,12 +1527,11 @@ def flash_mla_with_kvcache( return out, softmax_lse -def cutlass_mla_decode(out: torch.Tensor, q_nope_and_q_pe: torch.Tensor, +def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, scale: float) -> torch.Tensor: - B_q, H, _ = q_nope_and_q_pe.shape - - torch.ops._C.cutlass_mla_decode(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, + torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale) return out From 1ed6dc31c72cb500a415fcb65ad860a01770a978 Mon Sep 17 00:00:00 2001 From: kaixih Date: Sat, 26 Apr 2025 23:14:08 +0000 Subject: [PATCH 5/5] Format Signed-off-by: kaixih --- csrc/attention/mla/cutlass_mla_entry.cu | 3 +-- csrc/attention/mla/cutlass_mla_kernels.cu | 31 ++++++++++------------- csrc/ops.h | 3 +-- vllm/_custom_ops.py | 3 +-- 4 files changed, 17 insertions(+), 23 deletions(-) diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu index 758fbe5653f..0319d1daf30 100644 --- a/csrc/attention/mla/cutlass_mla_entry.cu +++ b/csrc/attention/mla/cutlass_mla_entry.cu @@ -25,8 +25,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, torch::Tensor const& page_table, double scale); #endif -void cutlass_mla_decode(torch::Tensor const& out, - torch::Tensor const& q_nope, +void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index 6e6f4ec0a83..6743af0cf2d 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -90,9 +90,8 @@ typename T::Fmha::Arguments args_from_options( using StrideO = typename T::StrideO; using StrideLSE = typename T::StrideLSE; - StrideQ stride_Q_latent = - cute::make_tuple(static_cast(D_latent), _1{}, - static_cast(H * D_latent)); + StrideQ 
stride_Q_latent = cute::make_tuple( + static_cast(D_latent), _1{}, static_cast(H * D_latent)); StrideQ stride_Q_rope = cute::make_tuple(static_cast(D_rope), _1{}, static_cast(H * D_rope)); StrideK stride_C = @@ -112,8 +111,8 @@ typename T::Fmha::Arguments args_from_options( auto scale_f = static_cast(scale); typename T::Fmha::Arguments arguments{ problem_shape, - {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, - C_ptr, stride_C, C_ptr + D_latent, stride_C, + {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr, + stride_C, C_ptr + D_latent, stride_C, static_cast(seq_lens.data_ptr()), static_cast(page_table.data_ptr()), stride_PT, page_count_total, page_size}, @@ -141,9 +140,8 @@ void runMla(at::Tensor const& out, at::Tensor const& q_nope, auto arguments = args_from_options( out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); - auto const workspace_options = torch::TensorOptions() - .dtype(torch::kUInt8) - .device(q_nope.device()); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device()); auto workspace = torch::empty(workspace_size, workspace_options); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -196,11 +194,10 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); - TORCH_CHECK( - q_nope.dtype() == at::ScalarType::Half || - q_nope.dtype() == at::ScalarType::BFloat16 || - q_nope.dtype() == at::ScalarType::Float8_e4m3fn, - "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); + TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half || + q_nope.dtype() == at::ScalarType::BFloat16 || + q_nope.dtype() == at::ScalarType::Float8_e4m3fn, + "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() && q_nope.dtype() == q_pe.dtype(), "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"); @@ -217,11 +214,11 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale, stream); } else if (in_dtype == at::ScalarType::BFloat16) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, - page_table, scale, stream); + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, - page_table, scale, stream); + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } diff --git a/csrc/ops.h b/csrc/ops.h index fa8944e434e..fe120af5d56 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,8 +128,7 @@ void advance_step_flashinfer( torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); -void cutlass_mla_decode(torch::Tensor const& out, - torch::Tensor const& q_nope, +void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ec9e81f0295..4c577c1c47e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1528,8 +1528,7 @@ def flash_mla_with_kvcache( def cutlass_mla_decode(out: 
torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, scale: float) -> torch.Tensor: torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
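
Usage sketch for the op introduced in this series, as it looks after the final patch (separate q_nope / q_pe inputs plus an explicit scale argument). It is adapted from tests/kernels/test_cutlass_mla_decode.py above; the batch size, sequence length, and block counts below are illustrative assumptions, while the head count (128), latent/rope dims (512/64), int32 index tensors, and power-of-two block size follow the TORCH_CHECKs in cutlass_mla_decode_sm100a. Running it assumes a compute-capability-10.0 GPU and a build with ENABLE_CUTLASS_MLA.

# Sketch only: sizes below are illustrative; shape/dtype constraints mirror the
# checks in cutlass_mla_decode_sm100a (H=128, D_latent=512, D_rope=64, int32
# seq_lens/page_table, power-of-two block size).
import torch
import vllm._custom_ops as ops

bs, h_q, d_latent, d_rope = 4, 128, 512, 64
block_size, blocks_per_seq = 128, 8   # block_size must be a power of 2
dtype, device = torch.bfloat16, "cuda"

q_nope = torch.randn(bs, h_q, d_latent, dtype=dtype, device=device)
q_pe = torch.randn(bs, h_q, d_rope, dtype=dtype, device=device)
kv_c_and_k_pe_cache = torch.randn(bs * blocks_per_seq, block_size,
                                  d_latent + d_rope,
                                  dtype=dtype, device=device)
seq_lens = torch.full((bs,), 1024, dtype=torch.int32, device=device)
# page_table width must be a multiple of (128 // block_size).
page_table = torch.randint(0, bs * blocks_per_seq, (bs, blocks_per_seq),
                           dtype=torch.int32, device=device)
scale = (128 + 64) ** -0.5            # softmax scale the test derives from the
                                      # non-absorbed q_nope/q_pe head dims

out = torch.zeros(bs, h_q, d_latent, dtype=dtype, device=device)
ops.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
                       seq_lens, page_table, scale)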