
Commit 596c81a

Support cutlass MLA and Upgrade cutlass to 3.9
Signed-off-by: kaixih <kaixih@nvidia.com>
1 parent c53e073 commit 596c81a


8 files changed (+399 -5 lines)


CMakeLists.txt (+24 -4)
@@ -251,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
   # Please keep this in sync with FetchContent_Declare line below.
-  set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -269,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
        cutlass
        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
        # Please keep this in sync with CUTLASS_REVISION line above.
-       GIT_TAG v3.8.0
+       GIT_TAG v3.9.0
        GIT_PROGRESS TRUE
 
        # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
@@ -290,7 +290,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-    "csrc/cutlass_extensions/common.cpp")
+    "csrc/cutlass_extensions/common.cpp"
+    "csrc/attention/mla/cutlass_mla_entry.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -463,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(FP4_ARCHS)
   endif()
 
-  #
+  # CUTLASS MLA Archs and flags
+  cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
+    set(SRCS
+      "csrc/attention/mla/cutlass_mla_kernels.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${MLA_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
+    # Add MLA-specific include directories only to MLA source files
+    set_source_files_properties(${SRCS}
+      PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_INCLUDE_DIR}/../examples/77_blackwell_fmha;${CUTLASS_INCLUDE_DIR}/../examples/common")
+    message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
+  else()
+    message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
+    # clear MLA_ARCHS
+    set(MLA_ARCHS)
+  endif()
+
   # CUTLASS MoE kernels
 
   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works

csrc/attention/mla/cutlass_mla_entry.cu (new file, +37)

@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
+void cutlass_mla_decode_sm100a(torch::Tensor const& out,
+                               torch::Tensor const& q_nope_and_q_pe,
+                               torch::Tensor const& kv_c_and_k_pe_cache,
+                               torch::Tensor const& seq_lens,
+                               torch::Tensor const& page_table);
+#endif
+
+void cutlass_mla_decode(torch::Tensor const& out,
+                        torch::Tensor const& q_nope_and_q_pe,
+                        torch::Tensor const& kv_c_and_k_pe_cache,
+                        torch::Tensor const& seq_lens,
+                        torch::Tensor const& page_table) {
+#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
+  return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
+                                   seq_lens, page_table);
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
+}

csrc/attention/mla/cutlass_mla_kernels.cu (new file, +220)

@@ -0,0 +1,220 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cute/tensor.hpp"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/kernel_hardware_info.h"
+
+#include "cutlass_extensions/common.hpp"
+
+#include "device/sm100_mla.hpp"
+#include "kernel/sm100_mla_tile_scheduler.hpp"
+
+using namespace cute;
+using namespace cutlass::fmha::kernel;
+
+template <bool v>
+struct IsPersistent {
+  static const bool value = v;
+};
+
+template <typename T, typename PersistenceOption = IsPersistent<true>>
+struct MlaSm100 {
+  using Element = T;
+  using ElementAcc = float;
+  using ElementOut = T;
+
+  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
+  using TileShapeH = cute::tuple_element_t<0, TileShape>;
+  using TileShapeD = cute::tuple_element_t<2, TileShape>;
+
+  // H K (D_latent D_rope) B
+  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
+
+  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
+  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
+  using StrideO = StrideK;                            // H D B
+  using StrideLSE = cute::tuple<_1, int>;             // H B
+
+  using TileScheduler = std::conditional_t<PersistenceOption::value,
+                                           Sm100MlaPersistentTileScheduler,
+                                           Sm100MlaIndividualTileScheduler>;
+
+  using FmhaKernel =
+      cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
+          TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
+          /*kIsCpAsync=*/true>;
+  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
+};
+
+template <typename T>
+typename T::Fmha::Arguments args_from_options(
+    at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
+    at::Tensor const& page_table) {
+  cutlass::KernelHardwareInfo hw_info;
+  hw_info.device_id = q_nope_and_q_pe.device().index();
+  hw_info.sm_count =
+      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+          hw_info.device_id);
+
+  int batches = q_nope_and_q_pe.sizes()[0];
+  int page_count_per_seq = page_table.sizes()[1];
+  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
+  int page_size = kv_c_and_k_pe_cache.sizes()[1];
+  int max_seq_len = page_size * page_count_per_seq;
+  using TileShapeH = typename T::TileShapeH;
+  using TileShapeD = typename T::TileShapeD;
+  auto problem_shape =
+      cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
+
+  auto [H, K, D, B] = problem_shape;
+  auto [D_latent, D_rope] = D;
+
+  // the scale is based on the non-absorbed sizes, change as appropriate
+  // we can't determine this parameter from the info we have, it's an input
+  int D_non_latent = 128;
+  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));
+
+  using StrideQ = typename T::StrideQ;
+  using StrideK = typename T::StrideK;
+  using StrideO = typename T::StrideO;
+  using StrideLSE = typename T::StrideLSE;
+
+  StrideQ stride_Q =
+      cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
+                       static_cast<int64_t>(H * (D_latent + D_rope)));
+  StrideK stride_C =
+      cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
+                       static_cast<int64_t>(page_size * (D_latent + D_rope)));
+  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
+  StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast<int>(H));
+  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(D_latent), _1{},
+                                      static_cast<int64_t>(H * D_latent));
+
+  using Element = typename T::Element;
+  using ElementOut = typename T::ElementOut;
+  using ElementAcc = typename T::ElementAcc;
+  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
+  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
+  typename T::Fmha::Arguments arguments{
+      problem_shape,
+      {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C,
+       C_ptr + D_latent, stride_C, static_cast<int*>(seq_lens.data_ptr()),
+       static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
+       page_size},
+      {static_cast<ElementOut*>(out.data_ptr()), stride_O,
+       static_cast<ElementAcc*>(nullptr), stride_LSE},
+      hw_info,
+      -1,       // split_kv
+      nullptr,  // is_var_split_kv
+  };
+  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
+  // split_kv automatically based on batch size and sequence length to balance
+  // workload across available SMs. Consider using var_split_kv for manual
+  // control if needed.
+  T::Fmha::set_split_kv(arguments);
+  return arguments;
+}
+
+template <typename Element>
+void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
+            at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
+            at::Tensor const& page_table, cudaStream_t stream) {
+  using MlaSm100Type = MlaSm100<Element>;
+  typename MlaSm100Type::Fmha fmha;
+  auto arguments = args_from_options<MlaSm100Type>(
+      out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
+  size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
+  auto const workspace_options = torch::TensorOptions()
+                                     .dtype(torch::kUInt8)
+                                     .device(q_nope_and_q_pe.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  CUTLASS_CHECK(fmha.can_implement(arguments));
+
+  CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
+
+  CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
+}
+
+void cutlass_mla_decode_sm100a(torch::Tensor const& out,
+                               torch::Tensor const& q_nope_and_q_pe,
+                               torch::Tensor const& kv_c_and_k_pe_cache,
+                               torch::Tensor const& seq_lens,
+                               torch::Tensor const& page_table) {
+  TORCH_CHECK(q_nope_and_q_pe.device().is_cuda(),
+              "q_nope_and_q_pe must be on CUDA");
+  TORCH_CHECK(q_nope_and_q_pe.dim() == 3,
+              "q_nope_and_q_pe must be a 3D tensor");
+  TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3,
+              "kv_c_and_k_pe_cache must be a 3D tensor");
+  TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor");
+  TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor");
+
+  auto B_q = q_nope_and_q_pe.size(0);
+  auto H = q_nope_and_q_pe.size(1);
+  auto D_q = q_nope_and_q_pe.size(2);
+  auto B_pt = page_table.size(0);
+  auto PAGE_NUM = page_table.size(1);
+  auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1);
+  auto D_ckv = kv_c_and_k_pe_cache.size(2);
+
+  TORCH_CHECK(D_q == D_ckv && D_q == 576,
+              "D_q must be equal to D_ckv and D_q must be equal to 576");
+  TORCH_CHECK(H == 128, "H must be 128");
+  TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0,
+              "PAGE_SIZE must be a power of 2");
+  TORCH_CHECK(B_q == B_pt,
+              "Batch dims must be same for page_table and q_nope_and_q_pe");
+  TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0,
+              "PAGE_NUM must be divisible by 128 / PAGE_SIZE");
+
+  TORCH_CHECK(
+      q_nope_and_q_pe.dtype() == at::ScalarType::Half ||
+          q_nope_and_q_pe.dtype() == at::ScalarType::BFloat16 ||
+          q_nope_and_q_pe.dtype() == at::ScalarType::Float8_e4m3fn,
+      "q_nope_and_q_pe must be a half, bfloat16, or float8_e4m3fn tensor");
+  TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope_and_q_pe.dtype(),
+              "kv_c_and_k_pe_cache must be the same type as q_nope_and_q_pe");
+  TORCH_CHECK(seq_lens.dtype() == torch::kInt32,
+              "seq_lens must be a 32-bit integer tensor");
+  TORCH_CHECK(page_table.dtype() == torch::kInt32,
+              "page_table must be a 32-bit integer tensor");
+
+  auto in_dtype = q_nope_and_q_pe.dtype();
+  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
+  const cudaStream_t stream =
+      at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
+  if (in_dtype == at::ScalarType::Half) {
+    runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens,
+                            page_table, stream);
+  } else if (in_dtype == at::ScalarType::BFloat16) {
+    runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
+                                seq_lens, page_table, stream);
+  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
+    runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
+                                  seq_lens, page_table, stream);
+  } else {
+    TORCH_CHECK(false, "Unsupported input data type of MLA");
+  }
+}
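
For reference, a minimal C++ sketch (not part of this commit) of how the new op could be driven. The shape constraints follow the TORCH_CHECKs above: head dim 576 = 512 latent + 64 rope, 128 query heads, a power-of-two page size, and int32 seq_lens/page_table. The concrete batch size, page size, and sequence length below are illustrative assumptions, as is the output's last dimension of 512 (inferred from the D_latent-based output strides); it assumes an ENABLE_CUTLASS_MLA build on an SM100 GPU.

// Hedged usage sketch; shapes are illustrative, constraints from the checks above.
#include <torch/torch.h>

// Declaration added to csrc/ops.h in this commit.
void cutlass_mla_decode(torch::Tensor const& out,
                        torch::Tensor const& q_nope_and_q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
                        torch::Tensor const& seq_lens,
                        torch::Tensor const& page_table);

void run_example() {
  const int64_t B = 4;              // batch size (illustrative)
  const int64_t H = 128;            // query heads, required to be 128
  const int64_t D = 576;            // 512 latent + 64 rope dims
  const int64_t page_size = 128;    // must be a power of 2
  const int64_t pages_per_seq = 8;  // PAGE_NUM % (128 / page_size) == 0
  const int64_t total_pages = B * pages_per_seq;

  auto f_opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
  auto i_opts = torch::dtype(torch::kInt32).device(torch::kCUDA);

  auto q   = torch::randn({B, H, D}, f_opts);                    // q_nope_and_q_pe
  auto kv  = torch::randn({total_pages, page_size, D}, f_opts);  // kv_c_and_k_pe_cache
  auto out = torch::empty({B, H, 512}, f_opts);                  // latent-dim output (assumed)
  auto seq_lens = torch::full({B}, /*value=*/512, i_opts);       // <= page_size * pages_per_seq
  auto page_table =
      torch::arange(total_pages, i_opts).reshape({B, pages_per_seq});

  cutlass_mla_decode(out, q, kv, seq_lens, page_table);
}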

csrc/ops.h (+6)

@@ -128,6 +128,12 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+void cutlass_mla_decode(torch::Tensor const& out,
+                        torch::Tensor const& q_nope_and_q_pe,
+                        torch::Tensor const& kv_c_and_k_pe_cache,
+                        torch::Tensor const& seq_lens,
+                        torch::Tensor const& page_table);
+
 torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
 
 #ifndef USE_ROCM

csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu (+1 -1)

@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
   using StrideB = typename T::StrideB;
   using StrideD = typename T::StrideD;
   using Sm100BlkScaledConfig =
-      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
+      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
 
   int m = static_cast<int>(M);
   int n = static_cast<int>(N);

csrc/torch_bindings.cpp (+7)

@@ -130,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
+  // Compute MLA decode using cutlass.
+  ops.def(
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe,"
+      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
+      " Tensor page_table) -> ()");
+  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
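
In the schema string, `Tensor! out` marks the output as mutated in place while the remaining tensors are read-only inputs. Once registered this way, the op is also reachable through the PyTorch dispatcher; below is a hedged C++ sketch, not part of the commit. The "_C" library name is an assumption, since TORCH_EXTENSION_NAME is substituted at build time and may differ.

// Hedged sketch: invoking the registered op via the PyTorch dispatcher.
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>

void call_via_dispatcher(at::Tensor& out, const at::Tensor& q,
                         const at::Tensor& kv, const at::Tensor& seq_lens,
                         const at::Tensor& page_table) {
  // "_C" is an assumed namespace for the vLLM extension library.
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("_C::cutlass_mla_decode", "")
                .typed<void(const at::Tensor&, const at::Tensor&,
                            const at::Tensor&, const at::Tensor&,
                            const at::Tensor&)>();
  // The schema declares `Tensor! out`, so `out` is written in place.
  op.call(out, q, kv, seq_lens, page_table);
}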
