diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 14771e65..a1586849 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -1442,7 +1442,8 @@ static ggml_tensor * llm_build_kqv(
     cb(v, "v", il);

     if (q->ne[1] == 1 && k->ne[1] >= 8192 && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
-        k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer)) {
+        k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer) &&
+        k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16) {
         cur = build_glm45_fa(ctx, q, k, v, kq_mask, kq_scale, should_use_f32_precision);
     } else {
@@ -9390,7 +9391,8 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
     cb(v, "v", il_cb);

     if (q->ne[1] == 1 && k->ne[1] >= 65536/k->ne[2] && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
-        k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer)) {
+        k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer) &&
+        k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16) {
         cur = build_glm45_fa(ctx0, q, k, v, KQ_mask, KQ_scale, should_use_f32_precision);
     } else {
         cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,