Reduce size of compute buffers (#237)

* This reduces compute buffer size for MLA

* This should accomplish the same for standard attention

* Much better

* Better concat for contiguous tensors

If all the op does is concatenate the second tensor onto the first, why would we want a loop? A sketch of the contiguous case follows below.
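A minimal sketch of the point, assuming the concatenation runs along the outermost dimension of two contiguous tensors (illustration only, not the actual ggml_concat kernel): the result is simply the first buffer followed by the second, so two block copies replace the per-element loop.

    // Illustration only: concatenating two contiguous float buffers
    // reduces to two bulk copies into the destination.
    #include <cstddef>
    #include <cstring>

    void concat_contiguous(const float * a, size_t na,
                           const float * b, size_t nb,
                           float * dst) {
        std::memcpy(dst,      a, na * sizeof(float)); // first tensor, one block
        std::memcpy(dst + na, b, nb * sizeof(float)); // second tensor appended, one block
    }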

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Commit e787c00141 (parent 472b4c37c1), authored by Kawrakow, committed via GitHub on 2025-03-01 08:25:27 +02:00.
7 changed files with 236 additions and 79 deletions


@@ -2511,6 +2511,7 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
int mla_attn;
int attn_max_batch;
bool fused_moe_up_gate;
enum llama_pooling_type pooling_type;
@@ -8774,45 +8775,8 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
}
if (model.arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845
// and then :
// kq = 30 * tanh(kq / 30)
// before the softmax below
//try from phi2
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
//kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
//kq = ggml_scale(ctx, kq, 30);
kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
}
if (hparams.attn_soft_cap) {
//kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
} else {
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
}
cb(kq, "kq_soft_max_ext", il);
GGML_ASSERT(kv.size == n_ctx);
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
@@ -8821,14 +8785,98 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
}
if (model.arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845
// and then :
// kq = 30 * tanh(kq / 30)
// before the softmax below
//try from phi2
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
//kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
//kq = ggml_scale(ctx, kq, 30);
kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
}
if (hparams.attn_soft_cap) {
//kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
} else {
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
}
cb(kq, "kq_soft_max_ext", il);
GGML_ASSERT(kv.size == n_ctx);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
}
else {
// For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2];
GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]);
int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch;
n_step = std::min(n_step, int(k->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
auto r2k = q->ne[2] / k->ne[2];
auto r2v = q->ne[2] / v->ne[2];
n_step = q->ne[2];
n_per_step = 1;
ggml_tensor * kqv;
for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) {
int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12;
int i02 = i12/r2k;
auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02);
auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
}
if (model.arch == LLM_ARCH_GROK) {
kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f);
}
if (hparams.attn_soft_cap) {
kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
} else {
kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias);
}
i02 = i12 / r2v;
auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02);
auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i);
if (i12 == 0) {
kqv = kqv_i;
} else {
kqv = ggml_concat(ctx, kqv, kqv_i, 2);
}
}
ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
}
}
ggml_build_forward_expand(graph, cur);
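For reference, the splitting heuristic introduced above can be summarized as a standalone sketch (same arithmetic as in the MLA branch further down; the ggml plumbing is omitted, and note that the standard-attention path ultimately forces one head per step regardless of the computed value):

    #include <algorithm>
    #include <cstdint>

    // Number of slices the K*Q product is split into, given the per-layer
    // budget attn_max_batch in MiB. attn_max_batch <= 0 disables splitting.
    static int kq_split_steps(int64_t n_kv, int64_t n_tokens, int64_t n_head,
                              int attn_max_batch) {
        int64_t kq_mib = n_kv * n_tokens * n_head * (int64_t) sizeof(float) / (1024 * 1024);
        if (attn_max_batch <= 0 || attn_max_batch >= kq_mib) {
            return 1; // fits in the budget: single pass, as before this change
        }
        int n_step = (int) ((kq_mib + attn_max_batch - 1) / attn_max_batch); // ceil
        return (int) std::min<int64_t>(n_step, n_head); // never more steps than heads
    }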
@@ -8924,6 +8972,7 @@ struct llm_build_context {
const bool flash_attn;
const int mla_attn;
const int attn_max_batch;
const bool fused_moe_up_gate;
const enum llama_pooling_type pooling_type;
@@ -8976,6 +9025,7 @@ struct llm_build_context {
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
mla_attn (cparams.mla_attn),
attn_max_batch (cparams.attn_max_batch),
fused_moe_up_gate(cparams.fused_moe_up_gate),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
@@ -13572,25 +13622,6 @@ struct llm_build_context {
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
}
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
@@ -13602,12 +13633,60 @@ struct llm_build_context {
cb(kv_cache_trans, "kv_cache_trans", il);
}
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
ggml_tensor * kqv_compressed;
auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
}
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
} else {
int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
n_step = std::min(n_step, int(q->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
//printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
if (i_head == 0) {
kqv_compressed = kqv_i;
} else {
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
}
ggml_build_forward_expand(gf, kqv_compressed);
}
cb(kqv_compressed, "kqv_compressed", il);
}
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
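To get a feel for why the MLA path needs this, here is an illustrative calculation with assumed DeepSeek-like shapes (128 heads, 32k cached tokens, a 1024-token u-batch); the numbers are examples, not measurements from this PR:

    #include <cstdint>
    #include <cstdio>

    int main() {
        int64_t n_kv = 32768, n_tokens = 1024, n_head = 128;   // assumed example shapes
        int64_t kq_mib = n_kv * n_tokens * n_head * (int64_t) sizeof(float) / (1024 * 1024);
        int     attn_max_batch = 512;                          // example cap in MiB
        int     n_step     = (int) ((kq_mib + attn_max_batch - 1) / attn_max_batch);
        int     n_per_step = (int) ((n_head + n_step - 1) / n_step);
        // Prints: full KQ = 16384 MiB -> 32 steps of 4 head(s) each
        std::printf("full KQ = %lld MiB -> %d steps of %d head(s) each\n",
                    (long long) kq_mib, n_step, n_per_step);
        return 0;
    }

In other words, instead of materializing a single ~16 GiB K*Q intermediate, the graph processes roughly attn_max_batch MiB at a time and concatenates the per-step kqv slices along the head dimension.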
@@ -17644,6 +17723,7 @@ struct llama_context_params llama_context_default_params() {
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.mla_attn =*/ 0,
/*.attn_max_batch =*/ 0,
/*.fused_moe_up_gate =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -17844,6 +17924,7 @@ struct llama_context * llama_new_context_with_model(
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
cparams.pooling_type = params.pooling_type;
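A hedged usage sketch for the new setting, assuming llama.h in this fork exposes the attn_max_batch field shown in the default-params hunk above (the 64 MiB value is only an example):

    #include "llama.h"

    // Sketch only: create a context with the per-layer K*Q intermediate
    // capped at ~64 MiB; 0 (the default) keeps the single-pass behaviour.
    llama_context * make_context(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.attn_max_batch = 64;   // example budget in MiB, not a recommendation
        return llama_new_context_with_model(model, cparams);
    }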
@@ -17912,6 +17993,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);