mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 14:14:32 +00:00
Much faster long context TG for Minimax-M2
This commit is contained in:
@@ -2148,11 +2148,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];

-    if (K->ne[0] == 128 && gqa_ratio == 12) {
+    if (K->ne[0] == 128 && (gqa_ratio == 12 || gqa_ratio == 6)) {
         GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128);
         //GGML_ASSERT(Q->ne[1] <= 4);
         //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst);
-        ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
+        if (gqa_ratio == 12) {
+            ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 8>(ctx, dst);
+        }
         return;
     }
     //if (K->ne[0] == 64 && V->ne[0] == 64) {
||||
@@ -90,7 +90,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }

     if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 &&
-        Q->ne[2] / K->ne[2] == 12) {
+        (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6)) {
         ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
         return;
     }
Reference in New Issue
Block a user