diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index f3cb2875..a1586849 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -1337,6 +1337,40 @@ llm_expert_gating_func_type gating_op,
     return cur;
 }
 
+static ggml_tensor * build_glm45_fa(ggml_context * ctx, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
+        ggml_tensor * kq_mask, float kq_scale, bool should_use_f32_precision) {
+
+    auto ne1 = 8*v->ne[0];
+    auto ne2 = 4*v->ne[0];
+
+    ggml_tensor *q1, *q2;
+    if (q->ne[1] == 1 && k->ne[2] == 1) {
+        q1 = ggml_view_3d(ctx, q, q->ne[0], 1, 8, q->nb[1], q->nb[2], 0);
+        q2 = ggml_view_3d(ctx, q, q->ne[0], 1, 4, q->nb[1], q->nb[2], 8*q->ne[0]*ggml_element_size(q));
+    } else {
+        q1 = ggml_view_3d(ctx, q, q->ne[0], 8, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 0);
+        q2 = ggml_view_3d(ctx, q, q->ne[0], 4, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 8*q->ne[0]*ggml_element_size(q));
+        q1 = ggml_reshape_3d(ctx, ggml_cont(ctx, q1), q->ne[0], 8*k->ne[2], q->ne[1]);
+        q2 = ggml_reshape_3d(ctx, ggml_cont(ctx, q2), q->ne[0], 4*k->ne[2], q->ne[1]);
+        q1 = ggml_permute(ctx, q1, 0, 2, 1, 3);
+        q2 = ggml_permute(ctx, q2, 0, 2, 1, 3);
+    }
+
+    auto fa1 = ggml_flash_attn_ext(ctx, q1, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+    if (should_use_f32_precision) {
+        ggml_flash_attn_ext_set_prec(fa1, GGML_PREC_F32);
+    }
+    fa1 = ggml_reshape_2d(ctx, fa1, ne1, ggml_nelements(fa1)/ne1);
+
+    auto fa2 = ggml_flash_attn_ext(ctx, q2, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+    if (should_use_f32_precision) {
+        ggml_flash_attn_ext_set_prec(fa2, GGML_PREC_F32);
+    }
+    fa2 = ggml_reshape_2d(ctx, fa2, ne2, ggml_nelements(fa2)/ne2);
+
+    return ggml_concat(ctx, fa1, fa2, 0);
+}
+
 static ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
         struct llama_context & lctx,
@@ -1381,6 +1415,17 @@ static ggml_tensor * llm_build_kqv(
     constexpr bool use_f32_precision = false;
 #endif
 
+    bool should_use_f32_precision = use_f32_precision
+        || model.arch == LLM_ARCH_PHI2
+        || model.arch == LLM_ARCH_PHI3
+        || model.arch == LLM_ARCH_GPTNEOX
+        || model.arch == LLM_ARCH_QWEN2
+        || model.arch == LLM_ARCH_COHERE2
+        || model.arch == LLM_ARCH_GLM4
+        || model.arch == LLM_ARCH_GLM4_MOE
+        || model.arch == LLM_ARCH_MIMO2;
+        // || (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8);
+
     struct ggml_tensor * cur;
 
     if (cparams.flash_attn) {
@@ -1396,21 +1441,27 @@ static ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-        ggml_flash_attn_ext_add_sinks(cur, sinks);
-        if (n_swa > 0) {
-            ((int32_t *)cur->op_params)[4] = n_swa;
-        }
+        if (q->ne[1] == 1 && k->ne[1] >= 8192 && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
+            k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer) &&
+            k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16) {
+            cur = build_glm45_fa(ctx, q, k, v, kq_mask, kq_scale, should_use_f32_precision);
+        } else {
 
-        // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
-        // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
-        // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
-        if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
-            (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) {
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+            cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                      hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+            ggml_flash_attn_ext_add_sinks(cur, sinks);
+            if (n_swa > 0) {
+                ((int32_t *)cur->op_params)[4] = n_swa;
+            }
+
+            // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
+            // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
+            // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
+            if (should_use_f32_precision) {
+                ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+            }
+            //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
-        //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
@@ -1431,8 +1482,7 @@ static ggml_tensor * llm_build_kqv(
 
         //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
-        if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
-            model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE || model.arch == LLM_ARCH_MIMO2) {
+        if (should_use_f32_precision) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -9182,6 +9232,22 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
     float freq_base_l = n_swa > 0 ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
     float freq_scale_l = n_swa > 0 ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
 
+#ifdef GGML_USE_VULKAN
+    constexpr bool use_f32_precision = true;
+#else
+    constexpr bool use_f32_precision = false;
+#endif
+
+    bool should_use_f32_precision = use_f32_precision
+        || model.arch == LLM_ARCH_PHI2
+        || model.arch == LLM_ARCH_PHI3
+        || model.arch == LLM_ARCH_GPTNEOX
+        || model.arch == LLM_ARCH_QWEN2
+        || model.arch == LLM_ARCH_COHERE2
+        || model.arch == LLM_ARCH_GLM4
+        // || model.arch == LLM_ARCH_GLM4_MOE
+        || model.arch == LLM_ARCH_MIMO2;
+        // || (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8);
 
     if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
         model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
@@ -9324,30 +9390,29 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                     ggml_row_size(split_vl->type, n_embd_head_v), 0);
             cb(v, "v", il_cb);
 
-#ifdef GGML_USE_VULKAN
-            constexpr bool use_f32_precision = true;
-#else
-            constexpr bool use_f32_precision = false;
-#endif
-            cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
-                    hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-            cb(cur, "flash_attn", il_cb);
-            if (model.layers[il].attn_sinks && model.layers[il].attn_sinks->extra) {
-                auto split = (ggml_split_tensor_t *)model.layers[il].attn_sinks->extra;
-                GGML_ASSERT(split->n_device == wq->n_device);
-                GGML_ASSERT(split->splits[id]);
-                ggml_flash_attn_ext_add_sinks(cur, split->splits[id]);
+            if (q->ne[1] == 1 && k->ne[1] >= 65536/k->ne[2] && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
+                k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer) &&
+                k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16) {
+                cur = build_glm45_fa(ctx0, q, k, v, KQ_mask, KQ_scale, should_use_f32_precision);
             } else {
-                ggml_flash_attn_ext_add_sinks(cur, sinks);
-            }
-            if (n_swa > 0) {
-                ((int32_t *)cur->op_params)[4] = n_swa;
-            }
-            // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
-            if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
-                (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
-                model.arch == LLM_ARCH_GLM4_MOE) {
-                ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+                cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
+                        hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+                cb(cur, "flash_attn", il_cb);
+                if (model.layers[il].attn_sinks && model.layers[il].attn_sinks->extra) {
+                    auto split = (ggml_split_tensor_t *)model.layers[il].attn_sinks->extra;
+                    GGML_ASSERT(split->n_device == wq->n_device);
+                    GGML_ASSERT(split->splits[id]);
+                    ggml_flash_attn_ext_add_sinks(cur, split->splits[id]);
+                } else {
+                    ggml_flash_attn_ext_add_sinks(cur, sinks);
+                }
+                if (n_swa > 0) {
+                    ((int32_t *)cur->op_params)[4] = n_swa;
+                }
+                // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
+                if (should_use_f32_precision) {
+                    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+                }
             }
 
             cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
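
Reviewer note, not part of the patch: the new path is gated on q->ne[2] / k->ne[2] == 12, i.e. 12 query heads per KV head (GLM-4.5-style GQA), and build_glm45_fa() splits those 12 heads into an 8-head view and a 4-head view, runs ggml_flash_attn_ext() on each, and re-joins the two outputs with ggml_concat along dimension 0, presumably so each call sees a group size the flash-attention kernels handle well. The standalone sketch below only illustrates the view geometry for the single-token case (q->ne[1] == 1, k->ne[2] == 1); head_dim = 128 and the host-allocated F32 tensor are illustrative assumptions, and no attention is computed.

    // Minimal sketch of the 8+4 head split used by build_glm45_fa() for single-token decode.
    // Assumptions (not from the patch): head_dim = 128, 12 query heads sharing one KV head.
    #include <cstdio>
    #include "ggml.h"

    int main() {
        ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
        ggml_context * ctx = ggml_init(params);

        const int64_t head_dim = 128;
        const int64_t n_head   = 12;

        // q laid out as [head_dim, n_tokens = 1, n_head], the layout the patch operates on
        ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, 1, n_head);

        // First 8 heads as one view; the remaining 4 heads start 8 heads further in (byte offset)
        ggml_tensor * q1 = ggml_view_3d(ctx, q, q->ne[0], 1, 8, q->nb[1], q->nb[2], 0);
        ggml_tensor * q2 = ggml_view_3d(ctx, q, q->ne[0], 1, 4, q->nb[1], q->nb[2], 8*q->ne[0]*ggml_element_size(q));

        // In the patch, q1 and q2 each go through ggml_flash_attn_ext() and the two
        // results are concatenated with ggml_concat(ctx, fa1, fa2, 0).
        printf("q : %lld x %lld x %lld\n", (long long) q->ne[0],  (long long) q->ne[1],  (long long) q->ne[2]);
        printf("q1: %lld x %lld x %lld\n", (long long) q1->ne[0], (long long) q1->ne[1], (long long) q1->ne[2]);
        printf("q2: %lld x %lld x %lld\n", (long long) q2->ne[0], (long long) q2->ne[1], (long long) q2->ne[2]);

        ggml_free(ctx);
        return 0;
    }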