FA3 Decode Perf - Use single mma warp group for decode batches #63

Merged 4 commits on Apr 21, 2025
18 changes: 11 additions & 7 deletions hopper/flash_api.cpp
@@ -431,7 +431,7 @@ inline int get_num_splits(Flash_fwd_params const& params) {
// params.page_table must already be set
// This needs to match the kernel configs
bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
// Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
// has not been set here. It's OK though because we might just underestimate kBlockN a bit
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
@@ -585,9 +585,11 @@ mha_fwd_get_scheduler_metadata(
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);

params.pagedkv_tma = get_pagedkv_tma(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
// Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
// Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
// Always enable PackGQA for Split
params.pack_gqa = params.num_splits > 1;

bool is_varlen = true;

@@ -611,7 +613,7 @@ mha_fwd_get_scheduler_metadata(
}

if (params.num_splits_dynamic_ptr) {
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
@@ -725,7 +727,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
int const num_pages = !paged_KV ? 0 : k.size(0);
int const page_size = !paged_KV ? 1 : k.size(1);
int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
int const seqlen_k = !max_seqlen_k_.has_value() ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
Comment on lines -728 to +730

Collaborator:

what's up with this change?

Collaborator (Author):

When seqused_k is used (which is what's required for paged KV caches) instead of cu_seqlens_k, is_varlen_k is false, but we frequently still have max_seqlen_k, so using it here prevents us from overestimating the number of splits. max_seqlen_k is also what the AOT scheduler uses, so this resolves that mismatch, meaning we end up picking a more efficient combine kernel (tighter num_splits bound). I think this line is actually worth upstreaming, good catch!
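To make the effect concrete, here is a minimal sketch with made-up numbers (the split heuristic below is a simplified stand-in, not the real get_num_splits): the split count is derived from an upper bound on seqlen_k, so bounding by the paged-KV cache capacity instead of max_seqlen_k inflates num_splits and forces a heavier combine kernel.

#include <algorithm>
#include <cstdio>

// Simplified stand-in for the split heuristic: roughly one split per kBlockN-sized
// chunk of the seqlen_k bound, capped at some maximum. Illustrative only.
int estimate_num_splits(int seqlen_k_bound, int kBlockN = 176, int max_splits = 8) {
    return std::min(max_splits, (seqlen_k_bound + kBlockN - 1) / kBlockN);
}

int main() {
    int page_size = 256, max_num_pages_per_seq = 16;  // paged KV cache capacity
    int max_seqlen_k = 512;                           // longest sequence actually present

    int loose_bound = max_num_pages_per_seq * page_size;  // 4096: capacity, not usage
    int tight_bound = max_seqlen_k;                       // 512: what the AOT scheduler also uses

    printf("splits with capacity bound: %d\n", estimate_num_splits(loose_bound));  // 8
    printf("splits with max_seqlen_k:   %d\n", estimate_num_splits(tight_bound));  // 3
}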

int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
int const num_heads_k = k.size(-2);
int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
@@ -938,9 +940,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);

params.pagedkv_tma = get_pagedkv_tma(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
// Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
// Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
// Always enable PackGQA for Split
params.pack_gqa = params.num_splits > 1;

// This needs to be set after get_num_splits
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
32 changes: 19 additions & 13 deletions hopper/flash_fwd_launch_template.h
@@ -21,12 +21,13 @@
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "mainloop_fwd_sm80.hpp"
#include "epilogue_fwd.hpp"
#include "heuristics.h"

using namespace cute;

template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
bool PackGQA, bool Split, bool V_colmajor>
bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
@@ -36,7 +37,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;

// Can't use structured binding since it's not compatible with constexpr
static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap);
static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
@@ -203,17 +204,22 @@ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
// Only needed here to decide if we should use cluster
static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128;

static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
// Only use Cluster if number of tiles along seqlen_q is even and not varlen
CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor>(params, stream);
BOOL_SWITCH(use_one_mma_wg(params), Use_one_mma_wg_, [&] {
// Avoid over-compilation by making sure this only gets set if it is actually used, i.e. we currently only support one mma wg for head dim 128 on Hopper
static constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128;

// Only needed here to decide if we should use cluster
static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128;

static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
// Only use Cluster if number of tiles along seqlen_q is even and not varlen
CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor, Use_one_mma_wg>(params, stream);
});
});
});
});
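For readers skimming the switch nesting above, a rough sketch of the BOOL_SWITCH idea (the real macro lives in FlashAttention's switch headers; this version is only illustrative): it lifts a runtime bool into a compile-time constant so the kernel template can specialize on it, and the masking described in the comment above keeps unsupported shapes from doubling the number of compiled kernels.

// Rough sketch of the BOOL_SWITCH pattern (illustrative, not the actual macro):
// dispatch a runtime bool to a compile-time constant so templates can specialize on it.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)        \
    [&] {                                                \
        if (COND) {                                      \
            static constexpr bool CONST_NAME = true;     \
            return __VA_ARGS__();                        \
        } else {                                         \
            static constexpr bool CONST_NAME = false;    \
            return __VA_ARGS__();                        \
        }                                                \
    }()

// With Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128, both sides of
// the switch instantiate the exact same run_flash_fwd for unsupported head dims or
// architectures, so the new switch level does not add extra kernel instantiations there.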
6 changes: 6 additions & 0 deletions hopper/heuristics.h
@@ -5,6 +5,12 @@
#pragma once

#include <vector>
#include "flash.h"

inline bool use_one_mma_wg(Flash_fwd_params const& params) {
return params.arch >= 90 && params.d == 128 &&
params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64;
};

inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
// If varlen, we don't actually know seqlen_q but only max_seqlen_q.
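As a rough worked example of when the new heuristic fires (the values below are made up for illustration; the field names are the ones used in this diff): a decode step with one query token, 32 query heads, and 8 KV heads packs 4 query rows per KV head, far below the 64-row threshold, so head dim 128 on Hopper takes the single-MMA-warp-group path.

// Hypothetical decode-shaped parameters, just to exercise use_one_mma_wg():
Flash_fwd_params params{};
params.arch = 90;        // Hopper
params.d = 128;          // the only head dim wired up for one mma wg in this PR
params.seqlen_q = 1;     // decode: one new query token per sequence
params.h = 32;           // query heads
params.h_k = 8;          // KV heads (GQA)
params.pack_gqa = true;  // now decided before num_splits (see flash_api.cpp above)

// 1 * (32 / 8) = 4 <= 64, so the kernel uses a single MMA warp group (kBlockM = 64).
bool one_wg = use_one_mma_wg(params);  // true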
8 changes: 6 additions & 2 deletions hopper/tile_size.h
@@ -9,7 +9,7 @@
// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap}
constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,
bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) {
bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false, bool use_one_mma_wg=false) {
if (element_size == 2) {
if (headdim <= 64) {
// return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim};
@@ -29,7 +29,11 @@ constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
} else if (headdim <= 96) {
return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true};
} else if (headdim <= 128) {
return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true};
if (use_one_mma_wg) {
return {64, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true};
} else {
return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true};
}
// {128, 192, false, false} and {192, 128, false, true} are quite good too
// 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS
} else if (headdim <= 192) {
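A quick way to see the net effect of this change (a sketch that simply calls the function as modified above): for head dim 128 with 2-byte elements, the M tile halves from 128 to 64 when use_one_mma_wg is requested, i.e. one 64-row MMA warp group instead of two.

#include <tuple>
#include "tile_size.h"  // the header changed in this diff

// kBlockM drops from 128 to 64 for head dim 128 when use_one_mma_wg is set.
static_assert(std::get<0>(tile_size_fwd_sm90(/*headdim=*/128, /*headdim_v=*/128,
                                             /*is_causal=*/true, /*is_local=*/false,
                                             /*element_size=*/2, /*v_colmajor=*/false,
                                             /*paged_kv_non_TMA=*/false, /*softcap=*/false,
                                             /*use_one_mma_wg=*/true)) == 64, "");
static_assert(std::get<0>(tile_size_fwd_sm90(128, 128, true, false, 2, false, false, false,
                                             /*use_one_mma_wg=*/false)) == 128, "");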