diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index c9a120976b1..a1d07bb29fd 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

-  // Compute MLA decode using cutlass.
-  ops.def(
-      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
-      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
-      " Tensor page_table, float scale) -> ()");
-  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
-
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
@@ -361,6 +354,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       {stride_tag});

   // conditionally compiled so impl registration is in source file
+  // Compute MLA decode using cutlass.
+  ops.def(
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
+      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
+      " Tensor page_table, float scale) -> ()");
+  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
+
   // CUTLASS nvfp4 block scaled GEMM
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
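
The op schema itself is unchanged; the diff only moves the `cutlass_mla_decode` registration next to the other conditionally compiled CUTLASS ops. For reference, a registered op like this is reached from Python through the PyTorch dispatcher. Below is a minimal sketch, assuming the extension registers under the `_C` namespace (i.e. that is what `TORCH_EXTENSION_NAME` expands to in this build) and that the compiled extension module has already been imported; the argument order simply mirrors the schema string in the diff.

```python
# Minimal consumer-side sketch (not part of this diff). Assumes the compiled
# extension has been loaded and registered its ops under the "_C" namespace;
# adjust the namespace if TORCH_EXTENSION_NAME expands to something else.
import torch


def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
                       q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
                       seq_lens: torch.Tensor, page_table: torch.Tensor,
                       scale: float) -> None:
    # `out` carries the `Tensor!` annotation in the schema, so the kernel
    # writes it in place; the remaining tensors are read-only inputs and
    # `scale` maps to the schema's `float` argument.
    torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
                                    seq_lens, page_table, scale)
```

Because only the registration's position in `csrc/torch_bindings.cpp` changes, callers of the op do not need any corresponding update.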