Make FA work for mla != 0

This commit is contained in:
Kawrakow
2026-01-20 07:58:31 +00:00
parent e6115f7241
commit 03c0629b3c
2 changed files with 17 additions and 5 deletions

View File

@@ -1999,9 +1999,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
constexpr int ncols = ncols1 * ncols2;
constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
constexpr int ntiles = ncols <= 8 && DKQ < 576 ? 1 : 2; // Number of tiles per warp.
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int nwarps_max_x = ncols / cols_per_warp;
constexpr int nwarps_max_x = (ncols + cols_per_warp - 1) / cols_per_warp;
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
@@ -2063,6 +2063,10 @@ template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
if constexpr (DKQ == 576 && ncols2 <= 4) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 4, ncols2>(ctx, dst);
} else {
if constexpr (ncols2 <= 8) {
if (Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
@@ -2081,6 +2085,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
}
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}
}
template <int DKQ, int DV>
@@ -2156,8 +2161,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else if (gqa_ratio % 4 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
} else {
GGML_ABORT("Unsupported GQA 576 x 512 case");
}
//GGML_ASSERT(gqa_ratio % 16 == 0);
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
//switch (Q->ne[0]) {
// case 64:

View File

@@ -180,7 +180,7 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
if (Q->ne[0] == 576) {
int gqa_ratio = Q->ne[2]/K->ne[2];
return (gqa_ratio % 16) == 0;
return (gqa_ratio % 4) == 0;
}
return true;
}