From 478b56871f7bcf89eb13f10edca326f6f60da5f9 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Mon, 26 Jan 2026 07:21:47 +0200
Subject: [PATCH] Faster long context TG on CUDA for GLM-4.5/4.6/4.7/AIR (part 2) (#1190)

* This works

* Make quantized KV cache work

* Remove the glm45 graph building changes

* Add condition
---
 ggml/src/ggml-cuda/fattn-mma-f16.cu | 154 ++++++++++++++++++++++++++++
 src/llama-build-context.cpp         | 105 ++++++-------------
 2 files changed, 183 insertions(+), 76 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cu b/ggml/src/ggml-cuda/fattn-mma-f16.cu
index 539d7728..07bdabd3 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cu
@@ -1,6 +1,154 @@
 #include "fattn-mma-f16.cuh"
 #include "fattn-mma-f16-interface.cuh"
 
+static __global__ void k_repack_q(int nelements, int ne0, int ne0_1, const float * src, float * dst1, float * dst2) {
+    int i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= nelements) {
+        return;
+    }
+    int row = i / ne0;
+    int i0 = i % ne0;
+    if (i0 < ne0_1) {
+        dst1[row*ne0_1 + i0] = src[i];
+    } else {
+        dst2[row*(ne0 - ne0_1) + i0 - ne0_1] = src[i];
+    }
+}
+
+static __global__ void k_pack_fa(const float * x, const float * y, float * dst, int ne0, int ne00, int nelem) {
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= nelem) {
+        return;
+    }
+
+    int row = i / ne0;
+    int i0 = i % ne0;
+
+    if (i0 < ne00) {
+        dst[row*ne0 + i0] = x[row*ne00 + i0];
+    } else {
+        dst[row*ne0 + i0] = y[row*(ne0 - ne00) + i0 - ne00];
+    }
+}
+
+
+static void repack_q(const ggml_tensor * q, float * dst, int nhead1, int nhead2, int nek2, cudaStream_t stream) {
+    constexpr int kBlockSize = 256;
+    GGML_ASSERT((nhead1 + nhead2)*nek2 == q->ne[2]);
+    int ne0 = q->ne[0] * (nhead1 + nhead2); // we know that Q is contiguous along the second dimension
+    int ne0_1 = q->ne[0] * nhead1;
+    int nelements = ne0 * q->ne[1] * q->ne[3] * nek2;
+    int nblocks = (nelements + kBlockSize - 1)/kBlockSize;
+    auto dst1 = dst;
+    auto dst2 = dst + ne0_1 * q->ne[1] * q->ne[3] * nek2;
+    k_repack_q<<<nblocks, kBlockSize, 0, stream>>>(nelements, ne0, ne0_1, (const float *)q->data, dst1, dst2);
+}
+
+static void pack_glm45_result(const ggml_tensor * fa1, const ggml_tensor * fa2, ggml_tensor * dst, cudaStream_t stream) {
+    constexpr int kBlockSize = 256;
+    GGML_ASSERT(dst->ne[1] % 12 == 0);
+    GGML_ASSERT(fa1->ne[0] == fa2->ne[0] && fa1->ne[0] == dst->ne[0]);
+    GGML_ASSERT(fa1->ne[1] + fa2->ne[1] == dst->ne[1]);
+    GGML_ASSERT(fa1->ne[2] == fa2->ne[2] && fa1->ne[2] == dst->ne[2]);
+    GGML_ASSERT(fa1->ne[3] == fa2->ne[3] && fa1->ne[3] == dst->ne[3]);
+    GGML_ASSERT(fa1->type == GGML_TYPE_F32 && fa2->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    int ne0 = dst->ne[0] * 12;
+    int ne00 = dst->ne[0] * 8;
+    int nelem = ne0 * dst->ne[1]/12 * dst->ne[2] * dst->ne[3];
+    int nblocks = (nelem + kBlockSize - 1)/kBlockSize;
+    k_pack_fa<<<nblocks, kBlockSize, 0, stream>>>((const float *)fa1->data, (const float *)fa2->data, (float *)dst->data, ne0, ne00, nelem);
+}
+
+static inline ggml_tensor get_float_tensor(int ne0, int ne1, int ne2, int ne3) {
+    return {GGML_TYPE_F32, {}, nullptr, {ne0, ne1, ne2, ne3},
+        {sizeof(float), ne0*sizeof(float), ne0*ne1*sizeof(float), ne0*ne1*ne2*sizeof(float)},
+        GGML_OP_NONE, {}, 0, nullptr, {}, nullptr, 0, nullptr, {}, nullptr};
+}
+static inline void permute_21(ggml_tensor & t) {
+    auto tmp1 = t.ne[1]; t.ne[1] = t.ne[2]; t.ne[2] = tmp1;
+    auto tmp2 = t.nb[1]; t.nb[1] = t.nb[2]; t.nb[2] = tmp2;
+}
+
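+// Flash attention path for the GLM-4.5/4.6/4.7/AIR case of 12 query heads per KV head:
+// the query heads are repacked into two contiguous groups of 8 and 4 heads per KV head,
+// K/V are converted to f16 if the KV cache is quantized, each group is evaluated as its
+// own FLASH_ATTN_EXT op, and the two partial results are interleaved back into dst.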
+static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    auto Q = dst->src[0];
+    auto K = dst->src[1];
+    auto V = dst->src[2];
+    GGML_ASSERT(Q->ne[2] / K->ne[2] == 12);
+
+    ggml_cuda_pool_alloc<float> q_data(ctx.pool(), ggml_nelements(Q));
+    ggml_cuda_pool_alloc<float> dst_data(ctx.pool(), ggml_nelements(dst));
+    ggml_cuda_pool_alloc<half> k_data(ctx.pool());
+    ggml_cuda_pool_alloc<half> v_data(ctx.pool());
+
+    repack_q(Q, q_data.get(), 8, 4, K->ne[2], ctx.stream());
+
+    auto local_Q1 = get_float_tensor(Q->ne[0], 8*K->ne[2], Q->ne[1], Q->ne[3]);
+    permute_21(local_Q1);
+    local_Q1.data = q_data.get();
+
+    auto local_Q2 = get_float_tensor(Q->ne[0], 4*K->ne[2], Q->ne[1], Q->ne[3]);
+    permute_21(local_Q2);
+    local_Q2.data = q_data.get() + ggml_nelements(&local_Q1);
+
+    GGML_ASSERT(ggml_nelements(Q) == ggml_nelements(&local_Q1) + ggml_nelements(&local_Q2));
+
+    auto local_K = *K;
+    auto local_V = *V;
+
+    if (K->type != GGML_TYPE_F16) {
+        auto nelem = ggml_nelements(K);
+        k_data.alloc(nelem);
+        auto to_fp_16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp_16(K->data, k_data.get(), 1, nelem, ctx.stream());
+        local_K.type = GGML_TYPE_F16;
+        local_K.data = k_data.get();
+        auto ts = ggml_type_size(K->type);
+        auto bs = ggml_blck_size(K->type);
+        local_K.nb[0] = sizeof(half);
+        local_K.nb[1] = sizeof(half)*bs * local_K.nb[1]/ts;
+        local_K.nb[2] = sizeof(half)*bs * local_K.nb[2]/ts;
+        local_K.nb[3] = sizeof(half)*bs * local_K.nb[3]/ts;
+    }
+    if (V->type != GGML_TYPE_F16) {
+        auto nelem = ggml_nelements(V);
+        v_data.alloc(nelem);
+        auto to_fp_16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp_16(V->data, v_data.get(), 1, nelem, ctx.stream());
+        local_V.type = GGML_TYPE_F16;
+        local_V.data = v_data.get();
+        auto ts = ggml_type_size(V->type);
+        auto bs = ggml_blck_size(V->type);
+        local_V.nb[0] = sizeof(half);
+        local_V.nb[1] = sizeof(half)*bs * local_V.nb[1]/ts;
+        local_V.nb[2] = sizeof(half)*bs * local_V.nb[2]/ts;
+        local_V.nb[3] = sizeof(half)*bs * local_V.nb[3]/ts;
+    }
+
+    constexpr int n_op_params = GGML_MAX_OP_PARAMS / sizeof(int);
+
+    auto fa1 = get_float_tensor(V->ne[0], local_Q1.ne[2], local_Q1.ne[1], local_Q1.ne[3]);
+    fa1.data = dst_data.get();
+    fa1.op = GGML_OP_FLASH_ATTN_EXT;
+    fa1.src[0] = &local_Q1;
+    fa1.src[1] = &local_K;
+    fa1.src[2] = &local_V;
+    for (int i = 3; i < GGML_MAX_SRC; ++i) fa1.src[i] = dst->src[i];
+    for (int i = 0; i < n_op_params; ++i) fa1.op_params[i] = dst->op_params[i];
+
+    auto fa2 = get_float_tensor(V->ne[0], local_Q2.ne[2], local_Q2.ne[1], local_Q2.ne[3]);
+    fa2.data = dst_data.get() + ggml_nelements(&fa1);
+    fa2.op = GGML_OP_FLASH_ATTN_EXT;
+    fa2.src[0] = &local_Q2;
+    fa2.src[1] = &local_K;
+    fa2.src[2] = &local_V;
+    for (int i = 3; i < GGML_MAX_SRC; ++i) fa2.src[i] = dst->src[i];
+    for (int i = 0; i < n_op_params; ++i) fa2.op_params[i] = dst->op_params[i];
+
+    ggml_cuda_flash_attn_ext_mma_f16(ctx, &fa1);
+    ggml_cuda_flash_attn_ext_mma_f16(ctx, &fa2);
+    pack_glm45_result(&fa1, &fa2, dst, ctx.stream());
+}
+
 template <int DKQ, int DV, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
@@ -69,6 +217,12 @@ void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tens
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
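+    // GLM-4.5/4.6/4.7/AIR use 12 query heads per KV head. For single-token (TG) queries on a
+    // long KV cache (K->ne[1]*K->ne[2] >= 65536 cells) it is faster to split the 12-wide GQA
+    // group into 8 + 4 and run two flash attention passes (glm45_flash_attention above).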
+    if (gqa_ratio == 12 && Q->ne[1] == 1 && K->ne[1]*K->ne[2] >= 65536) {
+        // This is a hack to improve GLM-4.5/4.6/4.7/AIR TG performance
+        glm45_flash_attention(ctx, dst);
+        return;
+    }
+
     if (use_gqa_opt && gqa_ratio % 8 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
         return;
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index a1586849..19f4bc27 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -1337,40 +1337,6 @@ llm_expert_gating_func_type gating_op,
     return cur;
 }
 
-static ggml_tensor * build_glm45_fa(ggml_context * ctx, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
-                                    ggml_tensor * kq_mask, float kq_scale, bool should_use_f32_precision) {
-
-    auto ne1 = 8*v->ne[0];
-    auto ne2 = 4*v->ne[0];
-
-    ggml_tensor *q1, *q2;
-    if (q->ne[1] == 1 && k->ne[2] == 1) {
-        q1 = ggml_view_3d(ctx, q, q->ne[0], 1, 8, q->nb[1], q->nb[2], 0);
-        q2 = ggml_view_3d(ctx, q, q->ne[0], 1, 4, q->nb[1], q->nb[2], 8*q->ne[0]*ggml_element_size(q));
-    } else {
-        q1 = ggml_view_3d(ctx, q, q->ne[0], 8, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 0);
-        q2 = ggml_view_3d(ctx, q, q->ne[0], 4, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 8*q->ne[0]*ggml_element_size(q));
-        q1 = ggml_reshape_3d(ctx, ggml_cont(ctx, q1), q->ne[0], 8*k->ne[2], q->ne[1]);
-        q2 = ggml_reshape_3d(ctx, ggml_cont(ctx, q2), q->ne[0], 4*k->ne[2], q->ne[1]);
-        q1 = ggml_permute(ctx, q1, 0, 2, 1, 3);
-        q2 = ggml_permute(ctx, q2, 0, 2, 1, 3);
-    }
-
-    auto fa1 = ggml_flash_attn_ext(ctx, q1, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
-    if (should_use_f32_precision) {
-        ggml_flash_attn_ext_set_prec(fa1, GGML_PREC_F32);
-    }
-    fa1 = ggml_reshape_2d(ctx, fa1, ne1, ggml_nelements(fa1)/ne1);
-
-    auto fa2 = ggml_flash_attn_ext(ctx, q2, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
-    if (should_use_f32_precision) {
-        ggml_flash_attn_ext_set_prec(fa2, GGML_PREC_F32);
-    }
-    fa2 = ggml_reshape_2d(ctx, fa2, ne2, ggml_nelements(fa2)/ne2);
-
-    return ggml_concat(ctx, fa1, fa2, 0);
-}
-
 static ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
         struct llama_context & lctx,
@@ -1441,28 +1407,21 @@ static ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        if (q->ne[1] == 1 && k->ne[1] >= 8192 && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
-            k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer) &&
-            k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16) {
-            cur = build_glm45_fa(ctx, q, k, v, kq_mask, kq_scale, should_use_f32_precision);
-        } else {
-
-            cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                    hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-            ggml_flash_attn_ext_add_sinks(cur, sinks);
-            if (n_swa > 0) {
-                ((int32_t *)cur->op_params)[4] = n_swa;
-            }
-
-            // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
-            // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
-            // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
-            if (should_use_f32_precision) {
-                ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-            }
-            //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+        ggml_flash_attn_ext_add_sinks(cur, sinks);
+        if (n_swa > 0) {
+            ((int32_t *)cur->op_params)[4] = n_swa;
         }
 
+        // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
+        // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
+        // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
+        if (should_use_f32_precision) {
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+        }
+        //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
 
     } else {
@@ -9390,29 +9349,23 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                 ggml_row_size(split_vl->type, n_embd_head_v),
                 0);
         cb(v, "v", il_cb);
 
-        if (q->ne[1] == 1 && k->ne[1] >= 65536/k->ne[2] && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
-            k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer) &&
-            k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16) {
-            cur = build_glm45_fa(ctx0, q, k, v, KQ_mask, KQ_scale, should_use_f32_precision);
+        cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
+                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+        cb(cur, "flash_attn", il_cb);
+        if (model.layers[il].attn_sinks && model.layers[il].attn_sinks->extra) {
+            auto split = (ggml_split_tensor_t *)model.layers[il].attn_sinks->extra;
+            GGML_ASSERT(split->n_device == wq->n_device);
+            GGML_ASSERT(split->splits[id]);
+            ggml_flash_attn_ext_add_sinks(cur, split->splits[id]);
         } else {
-            cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
-                    hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-            cb(cur, "flash_attn", il_cb);
-            if (model.layers[il].attn_sinks && model.layers[il].attn_sinks->extra) {
-                auto split = (ggml_split_tensor_t *)model.layers[il].attn_sinks->extra;
-                GGML_ASSERT(split->n_device == wq->n_device);
-                GGML_ASSERT(split->splits[id]);
-                ggml_flash_attn_ext_add_sinks(cur, split->splits[id]);
-            } else {
-                ggml_flash_attn_ext_add_sinks(cur, sinks);
-            }
-            if (n_swa > 0) {
-                ((int32_t *)cur->op_params)[4] = n_swa;
-            }
-            // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
-            if (should_use_f32_precision) {
-                ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-            }
+            ggml_flash_attn_ext_add_sinks(cur, sinks);
+        }
+        if (n_swa > 0) {
+            ((int32_t *)cur->op_params)[4] = n_swa;
+        }
+        // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
+        if (should_use_f32_precision) {
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
         cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);