Add softcap to flash attention

Just CPU and CUDA for now (but, as we know, flash attention
on the CPU is useless in llama.cpp).

On CUDA this improves PP performance quite a bit, especially for
long contexts. E.g., for PP-16384, I now get 3777 t/s.
Without this change, one cannot use FA, and one gets 2300 t/s
(after fusing softcap and softmax), or 2000 t/s without the
fused softcap+softmax.

In comparison, mainline llama.cpp has PP-16384 = 1549 t/s before
PR-8542 (where Johannes Gaessler has also added softcap to FA),
and PP-16384 = 3097 t/s after this PR.
This commit is contained in:
Iwan Kawrakow
2024-08-26 18:22:29 +03:00
parent 7168adfe71
commit 46862d725b
12 changed files with 257 additions and 105 deletions

View File

@@ -8290,7 +8290,8 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -13222,47 +13223,31 @@ struct llm_build_context {
0);
cb(k, "k", il);
if (cparams.flash_attn) {
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il);
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx0, kv_self.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
0);
cb(v, "v", il);
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
GGML_ASSERT(kv_self.size == n_ctx);
cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il);
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx0, kv_self.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv_self.v_l[il])*n_ctx,
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
0);
cb(v, "v", il);
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cb(kqv, "kqv", il);
GGML_ASSERT(kv_self.size == n_ctx);
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx0, kv_self.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv_self.v_l[il])*n_ctx,
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
0);
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur_attn, "kqv_merged_cont", il);
}
cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur_attn, "kqv_merged_cont", il);
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
model.layers[il].attn_sub_norm, NULL,
@@ -16813,12 +16798,6 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}
if (params.flash_attn && model->hparams.attn_soft_cap) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
params.flash_attn = false;
}
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;