Better CUDA TG for GQA = 10 (#1221)

* Better CUDA TG for GQA = 10

* Cleanup
Author:       Kawrakow
Committed by: GitHub
Date:         2026-02-03 09:18:46 +02:00
Parent:       7e8d444033
Commit:       f8acfc2bf0

2 changed files with 7 additions and 9 deletions
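For context on the title: TG is token generation, i.e. decoding with a single query token (`Q->ne[1] == 1` in the guard changed below), and the GQA ratio is the number of query heads per KV head. A minimal sketch of that arithmetic; the head counts in the comment are invented purely for illustration:

```cpp
#include <cassert>

// In ggml's tensor layout, ne[2] holds the head count, so the dispatch
// below computes gqa_ratio = Q->ne[2] / K->ne[2] (query heads per KV head).
static int compute_gqa_ratio(int n_q_heads, int n_kv_heads) {
    // mirrors the GGML_ASSERT(Q->ne[2] % K->ne[2] == 0) in the diff
    assert(n_kv_heads > 0 && n_q_heads % n_kv_heads == 0);
    return n_q_heads / n_kv_heads;  // e.g. 40 query heads / 4 KV heads -> 10
}
```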


@@ -2136,21 +2136,19 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
-    if (K->ne[0] == 128 && (gqa_ratio == 12 || gqa_ratio == 6)) {
+    if (K->ne[0] == 128) {
         GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128);
         //GGML_ASSERT(Q->ne[1] <= 4);
         //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst);
         if (gqa_ratio == 12) {
             ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
-        } else {
+        } else if (gqa_ratio == 6) {
             ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 8>(ctx, dst);
+        } else if (gqa_ratio == 10) {
+            ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
+        } else {
+            GGML_ABORT("Not implemented");
         }
         return;
     }
-    //if (K->ne[0] == 64 && V->ne[0] == 64) {
-    //    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
-    //    return;
-    //}
     if (K->ne[0] == 192 && V->ne[0] == 128) {
         GGML_ASSERT(Q->ne[0] == 192);
         //GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1
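The trailing template argument (apparently ncols2, per the `switch_ncols2` naming in the removed commented-out block) follows a recognizable pattern across the three cases: ratio 6 maps to 8, while 10 and 12 both map to 16. A sketch of that mapping, assuming it is round-up-to-power-of-two capped at 16; the helper name is invented:

```cpp
// Assumed mapping, inferred from the three cases in the diff above:
// 6 -> 8, 10 -> 16, 12 -> 16, i.e. next power of two, capped at 16.
static int pick_ncols2(int gqa_ratio) {
    int ncols2 = 1;
    while (ncols2 < gqa_ratio) ncols2 *= 2;
    return ncols2 > 16 ? 16 : ncols2;
}
```

Because ncols2 is a template parameter of `ggml_cuda_flash_attn_ext_mma_f16_case`, it must be a compile-time constant, so the dispatch enumerates one branch per supported ratio instead of computing the value at runtime; unsupported ratios fall through to `GGML_ABORT`.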


@@ -90,7 +90,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
     if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 &&
-        (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6)) {
+        (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6 || Q->ne[2] / K->ne[2] == 10)) {
        ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
        return;
     }
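For reference, the updated guard restated as a standalone predicate; this is a sketch only, with an invented function name. In ggml's layout, ne[0] is the head size, ne[1] the number of query tokens, and ne[2] the head count:

```cpp
#include <cstdint>
#include "ggml.h"

// Sketch of the routing condition after this commit: the new MMA kernel is
// taken only for token generation (one query token), head size 128 on all of
// Q/K/V, and a GQA ratio of 12, 6, or now 10.
static bool use_new_mma_path(const ggml_tensor * Q, const ggml_tensor * K,
                             const ggml_tensor * V, bool new_mma_available) {
    if (!new_mma_available) return false;
    if (K->ne[0] != 128 || V->ne[0] != 128 || Q->ne[0] != 128) return false;
    if (Q->ne[1] != 1) return false;  // token generation only
    const int64_t gqa_ratio = Q->ne[2] / K->ne[2];
    return gqa_ratio == 12 || gqa_ratio == 6 || gqa_ratio == 10;
}
```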