Skip to content

Commit ad60e14

Browse files
tywuAMD authored and Mu Huai committed
[Misc][ROCm] Exclude cutlass_mla_decode for ROCm build (vllm-project#17289)
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent 13eef06 commit ad60e14

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

csrc/torch_bindings.cpp

+7-7
Original file line number | Diff line number | Diff line change
@@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
130130
") -> ()");
131131
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
132132

133-
// Compute MLA decode using cutlass.
134-
ops.def(
135-
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
136-
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
137-
" Tensor page_table, float scale) -> ()");
138-
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
139-
140133
// Layernorm
141134
// Apply Root Mean Square (RMS) Normalization to the input tensor.
142135
ops.def(
@@ -450,6 +443,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
450443
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
451444
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
452445

446+
// CUTLASS MLA decode
447+
ops.def(
448+
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
449+
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
450+
" Tensor page_table, float scale) -> ()");
451+
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
452+
453453
// Mamba selective scan kernel
454454
ops.def(
455455
"selective_scan_fwd(Tensor! u, Tensor! delta,"

0 commit comments

Comments
 (0)