Better GLM-4.7-Flash long context TG performance (#1182)

* Better GLM-4.7-Flash long context TG performance

* Handle quantized cache
Kawrakow
2026-01-24 07:05:48 +02:00
committed by GitHub
parent 2a7cc09149
commit f0fb76da64
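For orientation before the diff: the change issues the 20-head GLM-4.7-Flash decode-time flash-attention call as two calls (16 heads, then the remaining 4), and, when the K cache is quantized, dequantizes it to F16 once so both calls share the same copy. A minimal sketch of the head-range pattern only, under the assumption of an illustrative helper `fa_run_head_range` and a generic `kernel` callback (neither is from the commit; the quantized-cache handling is omitted here):

#include "ggml.h"

// Sketch only: run heads [first, first+count) of a single-token flash-attention call
// by taking shallow copies of Q and dst, shrinking the head dimension and offsetting
// the data pointers by whole-head strides. The helper and the kernel callback are
// illustrative; the commit itself calls ggml_cuda_flash_attn_ext_mma_new(ctx, &local_dst).
static void fa_run_head_range(const ggml_tensor * Q, const ggml_tensor * dst,
                              int64_t first, int64_t count,
                              void (*kernel)(ggml_tensor * fa_dst)) {
    ggml_tensor q = *Q;                        // Q layout:   {head_dim,   n_tokens, n_head, n_seq}
    ggml_tensor d = *dst;                      // dst layout: {head_dim_v, n_head, n_tokens, n_seq}
    q.ne[2] = count;
    q.data  = (char *)q.data + q.nb[2]*first;  // skip the heads handled by the other call
    d.ne[1] = count;
    d.data  = (char *)d.data + d.nb[1]*first;
    d.src[0] = &q;
    kernel(&d);
}
// Usage in the spirit of the commit: fa_run_head_range(Q, dst, 0, 16, k); fa_run_head_range(Q, dst, 16, 4, k);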


@@ -13,6 +13,7 @@
#include "fattn-mma-f16-interface.cuh"
#include "fattn-new-mma.cuh"
#include "fattn.cuh"
#include "convert.cuh"
#include <cstdint>
@@ -106,6 +107,51 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
    // Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
    // so no other implementation works.
    //
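    // GLM-4.7-Flash uses the same MLA head sizes (576/512), but with 20 Q heads per KV head.
    // For single-token generation (Q->ne[1] == 1) against a long KV cache (> 8192 entries),
    // the 20 heads are processed as a 16-head call followed by a 4-head call, which is
    // faster at long context.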
    if (new_mma_available(cc) && K->ne[0] == 576 && V->ne[0] == 512 && Q->ne[1] == 1 &&
        Q->ne[2]/K->ne[2] == 20 && K->ne[1] > 8192) {
        // GLM-4.7-Flash TG hack: split 20 heads into 16+4 heads
        auto local_Q   = *Q;
        auto local_dst = *dst;
        ggml_tensor local_K, local_V;
        ggml_cuda_pool_alloc<half> K_f16(ctx.pool());
        if (ggml_is_quantized(K->type)) {
            // We need to dequantize here, else we will dequantize the same cache twice
            K_f16.alloc(ggml_nelements(K));
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
            to_fp16(K->data, K_f16.ptr, 1, ggml_nelements(K), ctx.stream());
            auto bs = ggml_blck_size(K->type);
            auto ts = ggml_type_size(K->type);
            local_K = *K;
            local_K.data = K_f16.get();
            local_K.type = GGML_TYPE_F16;
            local_K.nb[0] = sizeof(half);
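            // Rescale the byte strides from the quantized layout to F16: a block of bs
            // elements takes ts bytes quantized and bs*sizeof(half) bytes as F16, so every
            // stride scales by bs*sizeof(half)/ts. For example (assuming Q8_0: bs = 32,
            // ts = 34), a 576-element K row goes from 576/32*34 = 612 bytes quantized to
            // 576*2 = 1152 bytes as F16, and indeed 612*32*2/34 = 1152.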
            local_K.nb[1] = local_K.nb[1]*bs*sizeof(half)/ts;
            local_K.nb[2] = local_K.nb[2]*bs*sizeof(half)/ts;
            local_K.nb[3] = local_K.nb[3]*bs*sizeof(half)/ts;
            local_dst.src[1] = &local_K;
            local_V = local_K;
            local_V.ne[0] = V->ne[0];
            local_dst.src[2] = &local_V;
        }
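        // First chunk: heads 0..15. The FA result is laid out as {head_dim_v, n_head, n_tokens, n_seq},
        // so dst->ne[1] is the head count and dst->nb[1] is the per-head stride.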
        local_Q.ne[2] = 16;
        local_dst.ne[1] = 16;
        local_dst.src[0] = &local_Q;
        ggml_cuda_flash_attn_ext_mma_new(ctx, &local_dst);
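        // Second chunk: the remaining 4 heads, advancing Q and dst by 16 whole-head strides.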
        local_Q.ne[2] = 4;
        local_Q.data = (char *)local_Q.data + local_Q.nb[2]*16;
        local_dst.ne[1] = 4;
        local_dst.data = (char *)local_dst.data + local_dst.nb[1]*16;
        ggml_cuda_flash_attn_ext_mma_new(ctx, &local_dst);
        return;
    }
    if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
        //printf("Using ggml_cuda_flash_attn_ext_mma_new\n");
        ggml_cuda_flash_attn_ext_mma_new(ctx, dst);