
Commit ed7a29d

[NVIDIA] Support Cutlass MLA for Blackwell GPUs (#16032)
Signed-off-by: kaixih <kaixih@nvidia.com>
1 parent: 756848e

8 files changed: +403 -5 lines

CMakeLists.txt (+24 -4)

@@ -251,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
   # Please keep this in sync with FetchContent_Declare line below.
-  set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -269,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
         # Please keep this in sync with CUTLASS_REVISION line above.
-        GIT_TAG v3.8.0
+        GIT_TAG v3.9.0
         GIT_PROGRESS TRUE
 
         # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
@@ -290,7 +290,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-    "csrc/cutlass_extensions/common.cpp")
+    "csrc/cutlass_extensions/common.cpp"
+    "csrc/attention/mla/cutlass_mla_entry.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -463,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(FP4_ARCHS)
   endif()
 
-  #
+  # CUTLASS MLA Archs and flags
+  cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
+    set(SRCS
+      "csrc/attention/mla/cutlass_mla_kernels.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${MLA_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
+    # Add MLA-specific include directories only to MLA source files
+    set_source_files_properties(${SRCS}
+      PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common")
+    message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
+  else()
+    message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
+    # clear MLA_ARCHS
+    set(MLA_ARCHS)
+  endif()
+
   # CUTLASS MoE kernels
 
   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
csrc/attention/mla/cutlass_mla_entry.cu (new file, +38)

@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
+void cutlass_mla_decode_sm100a(torch::Tensor const& out,
+                               torch::Tensor const& q_nope,
+                               torch::Tensor const& q_pe,
+                               torch::Tensor const& kv_c_and_k_pe_cache,
+                               torch::Tensor const& seq_lens,
+                               torch::Tensor const& page_table, double scale);
+#endif
+
+void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
+                        torch::Tensor const& q_pe,
+                        torch::Tensor const& kv_c_and_k_pe_cache,
+                        torch::Tensor const& seq_lens,
+                        torch::Tensor const& page_table, double scale) {
+#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
+  return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache,
+                                   seq_lens, page_table, scale);
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
+}
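The entry point above compiles to a stub when the build does not define `ENABLE_CUTLASS_MLA`, so calling it on an unsupported build raises the "No compiled cutlass MLA" error at runtime. A minimal caller-side guard, sketched with standard ATen device queries (the helper name `can_use_cutlass_mla` is hypothetical and not part of this commit), might look like:

```cpp
#include <ATen/cuda/CUDAContext.h>

// Hypothetical helper (not in this diff): only route decode calls to
// cutlass_mla_decode() on SM100-class (Blackwell) devices, mirroring the
// "10.0a" arch gate in CMakeLists.txt, so the TORCH_CHECK_NOT_IMPLEMENTED
// fallback above is never hit on older GPUs or non-MLA builds.
inline bool can_use_cutlass_mla() {
  const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  return prop != nullptr && prop->major == 10;
}
```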
csrc/attention/mla/cutlass_mla_kernels.cu (new file, +225)

@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cute/tensor.hpp"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/kernel_hardware_info.h"
+
+#include "cutlass_extensions/common.hpp"
+
+#include "device/sm100_mla.hpp"
+#include "kernel/sm100_mla_tile_scheduler.hpp"
+
+using namespace cute;
+using namespace cutlass::fmha::kernel;
+
+template <typename T, bool PersistenceOption = true>
+struct MlaSm100 {
+  using Element = T;
+  using ElementAcc = float;
+  using ElementOut = T;
+
+  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
+  using TileShapeH = cute::tuple_element_t<0, TileShape>;
+  using TileShapeD = cute::tuple_element_t<2, TileShape>;
+
+  // H K (D_latent D_rope) B
+  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
+
+  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
+  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
+  using StrideO = StrideK;                            // H D B
+  using StrideLSE = cute::tuple<_1, int>;             // H B
+
+  using TileScheduler =
+      std::conditional_t<PersistenceOption, Sm100MlaPersistentTileScheduler,
+                         Sm100MlaIndividualTileScheduler>;
+
+  using FmhaKernel =
+      cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
+          TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
+          /*kIsCpAsync=*/true>;
+  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
+};
+
+template <typename T>
+typename T::Fmha::Arguments args_from_options(
+    at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe,
+    at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
+    at::Tensor const& page_table, double scale) {
+  cutlass::KernelHardwareInfo hw_info;
+  hw_info.device_id = q_nope.device().index();
+  hw_info.sm_count =
+      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+          hw_info.device_id);
+
+  int batches = q_nope.sizes()[0];
+  int page_count_per_seq = page_table.sizes()[1];
+  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
+  int page_size = kv_c_and_k_pe_cache.sizes()[1];
+  int max_seq_len = page_size * page_count_per_seq;
+  using TileShapeH = typename T::TileShapeH;
+  using TileShapeD = typename T::TileShapeD;
+  auto problem_shape =
+      cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
+
+  auto [H, K, D, B] = problem_shape;
+  auto [D_latent, D_rope] = D;
+
+  using StrideQ = typename T::StrideQ;
+  using StrideK = typename T::StrideK;
+  using StrideO = typename T::StrideO;
+  using StrideLSE = typename T::StrideLSE;
+
+  StrideQ stride_Q_latent = cute::make_tuple(
+      static_cast<int64_t>(D_latent), _1{}, static_cast<int64_t>(H * D_latent));
+  StrideQ stride_Q_rope = cute::make_tuple(static_cast<int64_t>(D_rope), _1{},
+                                           static_cast<int64_t>(H * D_rope));
+  StrideK stride_C =
+      cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
+                       static_cast<int64_t>(page_size * (D_latent + D_rope)));
+  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
+  StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast<int>(H));
+  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(D_latent), _1{},
+                                      static_cast<int64_t>(H * D_latent));
+
+  using Element = typename T::Element;
+  using ElementOut = typename T::ElementOut;
+  using ElementAcc = typename T::ElementAcc;
+  auto Q_latent_ptr = static_cast<Element*>(q_nope.data_ptr());
+  auto Q_rope_ptr = static_cast<Element*>(q_pe.data_ptr());
+  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
+  auto scale_f = static_cast<float>(scale);
+  typename T::Fmha::Arguments arguments{
+      problem_shape,
+      {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr,
+       stride_C, C_ptr + D_latent, stride_C,
+       static_cast<int*>(seq_lens.data_ptr()),
+       static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
+       page_size},
+      {static_cast<ElementOut*>(out.data_ptr()), stride_O,
+       static_cast<ElementAcc*>(nullptr), stride_LSE},
+      hw_info,
+      -1,       // split_kv
+      nullptr,  // is_var_split_kv
+  };
+  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
+  // split_kv automatically based on batch size and sequence length to balance
+  // workload across available SMs. Consider using var_split_kv for manual
+  // control if needed.
+  T::Fmha::set_split_kv(arguments);
+  return arguments;
+}
+
+template <typename Element>
+void runMla(at::Tensor const& out, at::Tensor const& q_nope,
+            at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache,
+            at::Tensor const& seq_lens, at::Tensor const& page_table,
+            float scale, cudaStream_t stream) {
+  using MlaSm100Type = MlaSm100<Element>;
+  typename MlaSm100Type::Fmha fmha;
+  auto arguments = args_from_options<MlaSm100Type>(
+      out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale);
+  size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  CUTLASS_CHECK(fmha.can_implement(arguments));
+
+  CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
+
+  CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
+}
+
+void cutlass_mla_decode_sm100a(torch::Tensor const& out,
+                               torch::Tensor const& q_nope,
+                               torch::Tensor const& q_pe,
+                               torch::Tensor const& kv_c_and_k_pe_cache,
+                               torch::Tensor const& seq_lens,
+                               torch::Tensor const& page_table, double scale) {
+  TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA");
+  TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor");
+  TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor");
+  TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3,
+              "kv_c_and_k_pe_cache must be a 3D tensor");
+  TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor");
+  TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor");
+  TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor");
+
+  auto B_q_nope = q_nope.size(0);
+  auto H_q_nope = q_nope.size(1);
+  auto D_q_nope = q_nope.size(2);
+  auto B_q_pe = q_pe.size(0);
+  auto H_q_pe = q_pe.size(1);
+  auto D_q_pe = q_pe.size(2);
+  auto B_pt = page_table.size(0);
+  auto PAGE_NUM = page_table.size(1);
+  auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1);
+  auto D_ckv = kv_c_and_k_pe_cache.size(2);
+  auto B_o = out.size(0);
+  auto H_o = out.size(1);
+  auto D_o = out.size(2);
+
+  TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512");
+  TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64");
+  TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576");
+  TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128,
+              "H_q_nope, H_q_pe, and H_o must be equal to 128");
+  TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0,
+              "PAGE_SIZE must be a power of 2");
+  TORCH_CHECK(
+      B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o,
+      "Batch dims must be same for page_table, q_nope and q_pe, and out");
+  TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0,
+              "PAGE_NUM must be divisible by 128 / PAGE_SIZE");
+  TORCH_CHECK(D_o == 512, "D_o must be equal to 512");
+
+  TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half ||
+                  q_nope.dtype() == at::ScalarType::BFloat16 ||
+                  q_nope.dtype() == at::ScalarType::Float8_e4m3fn,
+              "q_nope must be a half, bfloat16, or float8_e4m3fn tensor");
+  TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() &&
+                  q_nope.dtype() == q_pe.dtype(),
+              "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type");
+  TORCH_CHECK(seq_lens.dtype() == torch::kInt32,
+              "seq_lens must be a 32-bit integer tensor");
+  TORCH_CHECK(page_table.dtype() == torch::kInt32,
+              "page_table must be a 32-bit integer tensor");
+
+  auto in_dtype = q_nope.dtype();
+  at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
+  const cudaStream_t stream =
+      at::cuda::getCurrentCUDAStream(q_nope.get_device());
+  if (in_dtype == at::ScalarType::Half) {
+    runMla<cutlass::half_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
+                            page_table, scale, stream);
+  } else if (in_dtype == at::ScalarType::BFloat16) {
+    runMla<cutlass::bfloat16_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
+                                seq_lens, page_table, scale, stream);
+  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
+    runMla<cutlass::float_e4m3_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
+                                  seq_lens, page_table, scale, stream);
+  } else {
+    TORCH_CHECK(false, "Unsupported input data type of MLA");
+  }
+}
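For reference, the shape and dtype contract enforced by the TORCH_CHECKs above can be exercised with a small host-side sketch. The concrete values below (batch size, page geometry, bfloat16, and the softmax scale) are illustrative assumptions, not requirements beyond the checks shown:

```cpp
#include <cmath>

#include <torch/all.h>

#include "ops.h"  // declares cutlass_mla_decode (csrc/ops.h)

// Illustrative only: builds tensors that satisfy the checks in
// cutlass_mla_decode_sm100a and runs one decode step on the current device.
void cutlass_mla_smoke_test() {
  const int64_t batch = 4, heads = 128, d_latent = 512, d_rope = 64;
  const int64_t page_size = 128, pages_per_seq = 8;  // power-of-2 page size
  const int64_t total_pages = batch * pages_per_seq;

  auto bf16 = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
  auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);

  auto q_nope = torch::randn({batch, heads, d_latent}, bf16);  // [B, 128, 512]
  auto q_pe = torch::randn({batch, heads, d_rope}, bf16);      // [B, 128, 64]
  auto kv_c_and_k_pe_cache = torch::randn(
      {total_pages, page_size, d_latent + d_rope}, bf16);  // [pages, page_size, 576]
  auto seq_lens = torch::full({batch}, page_size, i32);    // int32, 1D
  auto page_table =
      torch::arange(total_pages, i32).view({batch, pages_per_seq});  // int32, 2D
  auto out = torch::empty({batch, heads, d_latent}, bf16);  // written in place

  const double scale = 1.0 / std::sqrt(static_cast<double>(d_latent + d_rope));
  cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
                     page_table, scale);
}
```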

csrc/ops.h (+6)

@@ -128,6 +128,12 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
+                        torch::Tensor const& q_pe,
+                        torch::Tensor const& kv_c_and_k_pe_cache,
+                        torch::Tensor const& seq_lens,
+                        torch::Tensor const& page_table, double scale);
+
 torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
 
 #ifndef USE_ROCM

csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu (+1 -1)

@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
   using StrideB = typename T::StrideB;
   using StrideD = typename T::StrideD;
   using Sm100BlkScaledConfig =
-      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
+      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
 
   int m = static_cast<int>(M);
   int n = static_cast<int>(N);
csrc/torch_bindings.cpp (+7)

@@ -130,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
+  // Compute MLA decode using cutlass.
+  ops.def(
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
+      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
+      " Tensor page_table, float scale) -> ()");
+  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(