Commit b8a7d56

Support cutlass MLA

Signed-off-by: kaixih <kaixih@nvidia.com>
1 parent d4bfc23 commit b8a7d56

7 files changed: +385 -2 lines changed

CMakeLists.txt

+22 -2
@@ -289,7 +289,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    "csrc/quantization/fp4/nvfp4_quant_entry.cu"
    "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
    "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-   "csrc/cutlass_extensions/common.cpp")
+   "csrc/cutlass_extensions/common.cpp"
+   "csrc/attention/mla/cutlass_mla_entry.cu")

  set_gencode_flags_for_srcs(
    SRCS "${VLLM_EXT_SRC}"
@@ -462,7 +463,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    set(FP4_ARCHS)
  endif()

- #
+ # CUTLASS MLA Archs and flags
+ cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
+   set(SRCS
+     "csrc/attention/mla/cutlass_mla_kernels.cu")
+   set_gencode_flags_for_srcs(
+     SRCS "${SRCS}"
+     CUDA_ARCHS "${MLA_ARCHS}")
+   list(APPEND VLLM_EXT_SRC "${SRCS}")
+   list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
+   # Add MLA-specific include directories only to MLA source files
+   set_source_files_properties(${SRCS}
+     PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_INCLUDE_DIR}/../examples/77_blackwell_fmha;${CUTLASS_INCLUDE_DIR}/../examples/common")
+   message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
+ else()
+   message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
+   # clear MLA_ARCHS
+   set(MLA_ARCHS)
+ endif()
+
  # CUTLASS MoE kernels

  # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
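
Note: the gate above only compiles the MLA kernels when a 10.0a-capable arch is in CUDA_ARCHS and the CUDA compiler is newer than 12.8; the authoritative decision happens at build time. Purely as a reading aid, a rough Python-side mirror of the same conditions (using standard torch APIs only; this helper is not part of the commit):

import torch

def cutlass_mla_prereqs_met() -> bool:
    # Rough mirror of the CMake gate: an SM 10.0 device and CUDA toolkit > 12.8.
    # torch.version.cuda reflects the toolkit PyTorch was built with, which may
    # differ from the nvcc used to build vLLM, so treat this as a hint only.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, minor = torch.cuda.get_device_capability()
    cuda_major, cuda_minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    return (major, minor) >= (10, 0) and (cuda_major, cuda_minor) > (12, 8)

if __name__ == "__main__":
    print("CUTLASS MLA prerequisites met:", cutlass_mla_prereqs_met())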
csrc/attention/mla/cutlass_mla_entry.cu (new file)

+37
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                               torch::Tensor const& q_nope_and_q_pe,
                               torch::Tensor const& kv_c_and_k_pe_cache,
                               torch::Tensor const& seq_lens,
                               torch::Tensor const& page_table);
#endif

void cutlass_mla_decode(torch::Tensor const& out,
                        torch::Tensor const& q_nope_and_q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
                        torch::Tensor const& seq_lens,
                        torch::Tensor const& page_table) {
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
  return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
                                   seq_lens, page_table);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
}
csrc/attention/mla/cutlass_mla_kernels.cu (new file)

+201
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"

#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"

#define CUTLASS_CHECK(status)                           \
  {                                                     \
    cutlass::Status error = status;                     \
    TORCH_CHECK(error == cutlass::Status::kSuccess,     \
                cutlassGetStatusString(error));         \
  }

using namespace cute;
using namespace cutlass::fmha::kernel;

template<bool v>
struct IsPersistent {
  static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
  using Element = T;
  using ElementAcc = float;
  using ElementOut = T;

  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
  using TileShapeH = cute::tuple_element_t<0, TileShape>;
  using TileShapeD = cute::tuple_element_t<2, TileShape>;

  // H K (D_latent D_rope) B
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

  using TileScheduler = std::conditional_t<
      PersistenceOption::value,
      Sm100MlaPersistentTileScheduler,
      Sm100MlaIndividualTileScheduler>;

  using FmhaKernel =
      cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
          TileShape, Element, ElementAcc, ElementOut, ElementAcc,
          TileScheduler, /*kIsCpAsync=*/true>;
  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

template <typename T>
typename T::Fmha::Arguments args_from_options(at::Tensor const& out,
                                              at::Tensor const& q_nope_and_q_pe,
                                              at::Tensor const& kv_c_and_k_pe_cache,
                                              at::Tensor const& seq_lens,
                                              at::Tensor const& page_table) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = q_nope_and_q_pe.device().index();
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);

  int batches = q_nope_and_q_pe.sizes()[0];
  int page_count_per_seq = page_table.sizes()[1];
  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
  int page_size = kv_c_and_k_pe_cache.sizes()[1];
  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
  auto problem_shape = cute::make_tuple(
      TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  // the scale is based on the non-absorbed sizes, change as appropriate
  // we can't determine this parameter from the info we have, it's an input
  int D_non_latent = 128;
  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));

  using StrideQ = typename T::StrideQ;
  using StrideK = typename T::StrideK;
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

  StrideQ stride_Q = cute::make_tuple(
      static_cast<int64_t>(0 + D_latent + D_rope),
      _1{},
      static_cast<int64_t>(H * (0 + D_latent + D_rope)));
  StrideK stride_C = cute::make_tuple(
      static_cast<int64_t>(0 + D_latent + D_rope),
      _1{},
      static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
  StrideO stride_O = cute::make_tuple(
      static_cast<int64_t>(0 + D_latent),
      _1{},
      static_cast<int64_t>(0 + H * D_latent));

  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
  typename T::Fmha::Arguments arguments{
      problem_shape,
      { scale,
        Q_ptr, stride_Q,
        Q_ptr + D_latent, stride_Q,
        C_ptr, stride_C,
        C_ptr + D_latent, stride_C,
        static_cast<int*>(seq_lens.data_ptr()),
        static_cast<int*>(page_table.data_ptr()), stride_PT,
        page_count_total, page_size},
      { static_cast<ElementOut*>(out.data_ptr()), stride_O,
        // static_cast<ElementAcc*>(lse.data_ptr()), stride_LSE},
        static_cast<ElementAcc*>(nullptr), stride_LSE},
      hw_info,
      -1,       // split_kv
      nullptr,  // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
  // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);
  return arguments;
}

template <typename Element>
void runMla(at::Tensor const& out,
            at::Tensor const& q_nope_and_q_pe,
            at::Tensor const& kv_c_and_k_pe_cache,
            at::Tensor const& seq_lens,
            at::Tensor const& page_table,
            cudaStream_t stream) {
  using MlaSm100Type = MlaSm100<Element>;
  typename MlaSm100Type::Fmha fmha;
  auto arguments =
      args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
                                      seq_lens, page_table);
  size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(q_nope_and_q_pe.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(fmha.can_implement(arguments));

  CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}

void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                               torch::Tensor const& q_nope_and_q_pe,
                               torch::Tensor const& kv_c_and_k_pe_cache,
                               torch::Tensor const& seq_lens,
                               torch::Tensor const& page_table) {
  auto in_dtype = q_nope_and_q_pe.dtype();
  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
  const cudaStream_t stream =
      at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
  if (in_dtype == at::ScalarType::Half) {
    runMla<cutlass::half_t>(
        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    runMla<cutlass::bfloat16_t>(
        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
    runMla<cutlass::float_e4m3_t>(
        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
  } else {
    TORCH_CHECK(false, "Unsupported input data type of MLA");
  }
}
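
As a reading aid for args_from_options: it assumes packed layouts where queries are (batch, 128 heads, 512 latent + 64 rope dims) and the paged KV cache is (num_pages, page_size, 512 + 64), with the rope slice addressed by offsetting the same pointer by D_latent (Q_ptr + D_latent, C_ptr + D_latent). A small Python sketch of the same stride arithmetic, with illustrative sizes that are not taken from the commit:

import torch

D_LATENT, D_ROPE, H = 512, 64, 128  # matches TileShape <_128, _128, (_512, _64)>
B, num_pages, page_size = 2, 16, 128

q = torch.empty(B, H, D_LATENT + D_ROPE)                    # q_nope_and_q_pe
kv = torch.empty(num_pages, page_size, D_LATENT + D_ROPE)   # kv_c_and_k_pe_cache

# stride_Q is ordered (H, D, B) in the kernel; torch reports (B, H, D) — same
# three values: (H*(D_latent+D_rope), D_latent+D_rope, 1).
assert q.stride() == (H * (D_LATENT + D_ROPE), D_LATENT + D_ROPE, 1)

# stride_C per page: (page_size*(D_latent+D_rope), D_latent+D_rope, 1).
assert kv.stride() == (page_size * (D_LATENT + D_ROPE), D_LATENT + D_ROPE, 1)

# The rope slice is the same buffer offset by D_latent, as with Q_ptr + D_latent.
q_pe = q[..., D_LATENT:]
assert q_pe.data_ptr() == q.data_ptr() + D_LATENT * q.element_size()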

csrc/ops.h

+6
@@ -119,6 +119,12 @@ void advance_step_flashinfer(
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

+void cutlass_mla_decode(torch::Tensor const& out,
+                        torch::Tensor const& q_nope_and_q_pe,
+                        torch::Tensor const& kv_c_and_k_pe_cache,
+                        torch::Tensor const& seq_lens,
+                        torch::Tensor const& page_table);
+
 torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

#ifndef USE_ROCM

csrc/torch_bindings.cpp

+7
@@ -115,6 +115,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      ") -> ()");
  ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

+  // Compute MLA decode using cutlass.
+  ops.def(
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe,"
+      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
+      " Tensor page_table) -> ()");
+  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
+
  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
New test file

+77
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10 or above.",
        allow_module_level=True)


def ref_mla(
        out: Tensor,  # (bs, num_heads, v_head_dim)
        query: Tensor,  # (bs, num_heads, head_dim)
        kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
        scale: float,
        block_tables: Tensor,  # (bs, max_num_blocks)
        seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[
            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1,
                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q,
                                           kv,
                                           v,
                                           scale=scale,
                                           enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [16, 128])
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
                            varlen: bool, block_size: int):
    torch.set_default_dtype(dtype)
    torch.set_default_device('cuda')
    torch.manual_seed(42)

    d = 576
    h_q = 128
    dv = 512

    q_nope_dim = 128
    q_pe_dim = 64
    scale = (q_nope_dim + q_pe_dim)**(-0.5)
    if varlen:
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    q = torch.randn(bs, h_q, d)
    block_table = torch.randint(0,
                                bs * block_num, (bs, block_num),
                                dtype=torch.int32)

    kv_cache = torch.randn(block_table.numel(), block_size, d)

    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
    out = ops.cutlass_mla_decode(q, kv_cache, seq_lens, block_table)

    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
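
The test calls ops.cutlass_mla_decode with four arguments and receives the output tensor, so the Python-side wrapper (presumably one of the seven changed files, not rendered in this view) must allocate out itself before invoking the bound C++ op. A hypothetical sketch of such a wrapper, assuming the op is registered under torch.ops._C as in the bindings above; names and the 512-dim output head size are taken from the kernel code, not from the missing file:

import torch

def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor,
                       kv_c_and_k_pe_cache: torch.Tensor,
                       seq_lens: torch.Tensor,
                       page_table: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper; the real one is not part of the files shown here.
    bs, num_heads, _ = q_nope_and_q_pe.shape
    v_head_dim = 512  # latent dimension written by the kernel
    out = q_nope_and_q_pe.new_zeros(bs, num_heads, v_head_dim)
    # Schema is "(Tensor! out, ...) -> ()": out is written in place, nothing returned.
    torch.ops._C.cutlass_mla_decode(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
                                    seq_lens, page_table)
    return out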
