mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-07 20:40:02 +00:00
Better CUDA TG for GQA = 10 (#1221)
* Better CUDA TG for GQA = 10 * Cleanup
This commit is contained in:
@@ -2136,21 +2136,19 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (K->ne[0] == 128 && (gqa_ratio == 12 || gqa_ratio == 6)) {
|
||||
if (K->ne[0] == 128) {
|
||||
GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128);
|
||||
//GGML_ASSERT(Q->ne[1] <= 4);
|
||||
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst);
|
||||
if (gqa_ratio == 12) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
|
||||
} else {
|
||||
} else if (gqa_ratio == 6) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 8>(ctx, dst);
|
||||
} else if (gqa_ratio == 10) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
|
||||
} else {
|
||||
GGML_ABORT("Not implemented");
|
||||
}
|
||||
return;
|
||||
}
|
||||
//if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||
// ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
|
||||
// return;
|
||||
//}
|
||||
if (K->ne[0] == 192 && V->ne[0] == 128) {
|
||||
GGML_ASSERT(Q->ne[0] == 192);
|
||||
//GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1
|
||||
|
||||
@@ -90,7 +90,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
|
||||
if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 &&
|
||||
(Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6)) {
|
||||
(Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6 || Q->ne[2] / K->ne[2] == 10)) {
|
||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user