@@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
130
130
" ) -> ()" );
131
131
ops.impl (" advance_step_flashinfer" , torch::kCUDA , &advance_step_flashinfer);
132
132
133
- // Compute MLA decode using cutlass.
134
- ops.def (
135
- " cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
136
- " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
137
- " Tensor page_table, float scale) -> ()" );
138
- ops.impl (" cutlass_mla_decode" , torch::kCUDA , &cutlass_mla_decode);
139
-
140
133
// Layernorm
141
134
// Apply Root Mean Square (RMS) Normalization to the input tensor.
142
135
ops.def (
@@ -450,6 +443,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
450
443
ops.def (" cutlass_sparse_compress(Tensor a) -> Tensor[]" );
451
444
ops.impl (" cutlass_sparse_compress" , &cutlass_sparse_compress);
452
445
446
+ // CUTLASS MLA decode
447
+ ops.def (
448
+ " cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
449
+ " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
450
+ " Tensor page_table, float scale) -> ()" );
451
+ ops.impl (" cutlass_mla_decode" , torch::kCUDA , &cutlass_mla_decode);
452
+
453
453
// Mamba selective scan kernel
454
454
ops.def (
455
455
" selective_scan_fwd(Tensor! u, Tensor! delta,"
0 commit comments