mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Much faster long context TG for Minimax-M2 (#1194)
This commit is contained in:
@@ -2148,11 +2148,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
|||||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||||
|
|
||||||
if (K->ne[0] == 128 && gqa_ratio == 12) {
|
if (K->ne[0] == 128 && (gqa_ratio == 12 || gqa_ratio == 6)) {
|
||||||
GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128);
|
GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128);
|
||||||
//GGML_ASSERT(Q->ne[1] <= 4);
|
//GGML_ASSERT(Q->ne[1] <= 4);
|
||||||
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst);
|
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst);
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
|
if (gqa_ratio == 12) {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 8>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//if (K->ne[0] == 64 && V->ne[0] == 64) {
|
//if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 &&
|
if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 &&
|
||||||
Q->ne[2] / K->ne[2] == 12) {
|
(Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6)) {
|
||||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user