WIP - factor out split attention

Author: Kawrakow
Date: 2025-12-06 09:44:52 +00:00
parent 2f645f2579
commit 22ac19958f
2 changed files with 201 additions and 181 deletions

@@ -9307,192 +9307,209 @@ ggml_cgraph * llm_build_context::llama_build_graph(
return result;
}
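// Split (tensor-parallel) variant of build_std_attention. `input` holds one activation
// tensor per device (nullptr for devices that own no slice of this layer); on return each
// non-null entry is replaced by that device's attention output. Reducing the per-device
// partial results is left to the caller.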
void llm_build_context::build_std_attention(ggml_cgraph * gf, std::vector<ggml_tensor *> & input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il) {
GGML_ASSERT(cparams.flash_attn);
GGML_ASSERT(!model.layers[il].wqkv);
GGML_ASSERT(!model.layers[il].wqk);
GGML_ASSERT(model.layers[il].wq && model.layers[il].wq->extra);
GGML_ASSERT(model.layers[il].wk && model.layers[il].wk->extra);
GGML_ASSERT(model.layers[il].wv && model.layers[il].wv->extra);
GGML_ASSERT(kv_self.k_l[il]->extra && kv_self.v_l[il]->extra);
auto attn_norm = model.layers[il].attn_norm ? (ggml_split_tensor_t *)model.layers[il].attn_norm->extra : nullptr;
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
GGML_ASSERT(wq->n_device == int(input.size()));
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
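// The Q/K/V/output biases are optional; when present they must be split across the
// same device set as the attention weights.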
ggml_split_tensor_t *bq = nullptr, *bo = nullptr, *bk = nullptr, *bv = nullptr;
if (model.layers[il].bq && model.layers[il].bq->extra) {
bq = (ggml_split_tensor_t *)model.layers[il].bq->extra;
GGML_ASSERT(bq->n_device == wq->n_device);
}
if (model.layers[il].bo && model.layers[il].bo->extra) {
bo = (ggml_split_tensor_t *)model.layers[il].bo->extra;
GGML_ASSERT(bo->n_device == wq->n_device);
}
if (model.layers[il].bk && model.layers[il].bk->extra) {
bk = (ggml_split_tensor_t *)model.layers[il].bk->extra;
GGML_ASSERT(bk->n_device == wq->n_device);
}
if (model.layers[il].bv && model.layers[il].bv->extra) {
bv = (ggml_split_tensor_t *)model.layers[il].bv->extra;
GGML_ASSERT(bv->n_device == wq->n_device);
}
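// Build the attention sub-graph once per device slice. il_cb folds the device index
// into the callback layer id so per-device tensors get distinct debug names.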
for (int id = 0; id < wq->n_device; ++id) {
int il_cb = 1000*(id+1) + il;
auto split_wq = wq->splits[id];
auto split_wk = wk->splits[id];
auto split_wv = wv->splits[id];
auto split_wo = wo->splits[id];
auto split_kl = kl->splits[id];
auto split_vl = vl->splits[id];
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
if (!split_wq) {
GGML_ASSERT(input[id] == nullptr);
continue;
}
auto cur = input[id];
if (attn_norm) {
cur = llm_build_norm(ctx0, cur, hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il_cb);
}
else if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
}
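// Q/K norms may be per-device splits or shared (unsplit) tensors; pick accordingly.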
auto the_q_norm = model.layers[il].attn_q_norm == nullptr ? nullptr :
model.layers[il].attn_q_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] :
model.layers[il].attn_q_norm;
auto the_k_norm = model.layers[il].attn_k_norm == nullptr ? nullptr :
model.layers[il].attn_k_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra)->splits[id] :
model.layers[il].attn_k_norm;
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
split_wq, bq ? bq->splits[id] : nullptr,
split_wk, bk ? bk->splits[id] : nullptr,
split_wv, bv ? bv->splits[id] : nullptr,
the_q_norm, the_k_norm, f_attn_scale, il_cb);
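// If no RoPE factors were passed in, fall back to this device's split of rope_freqs
// (when the model provides one).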
auto rope_factors = rope_factors_in;
if (!rope_factors && model.layers[il].rope_freqs && model.layers[il].rope_freqs->extra) {
auto extra = (ggml_split_tensor_t *)model.layers[il].rope_freqs->extra;
rope_factors = extra->splits[id];
}
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il_cb);
cb(Kcur, "Kcur", il_cb);
if (inp_attn_scale) {
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
cb(Qcur, "Qcur_temp_scaled", il_cb);
}
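// Optional Hadamard rotation of Q and K; presumably this spreads activation outliers
// so a quantized K cache loses less precision.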
if (cparams.k_cache_hadamard) {
Qcur = ggml_hadamard(ctx0, Qcur, hparams.n_embd_head_k);
Kcur = ggml_hadamard(ctx0, Kcur, hparams.n_embd_head_k);
cb(Qcur, "Qcur_hadamard", il_cb);
cb(Kcur, "Kcur_hadamard", il_cb);
}
ggml_build_forward_expand(gf, Qcur);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);
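// Copy this device's K and V slices into its split of the KV cache. cache_copies keeps
// the copy op and its row stride, two entries (K, V) per device per layer, hence the
// index arithmetic below.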
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_head_kv = split_wk->ne[1] / n_embd_head_k;
GGML_ASSERT(kv_self.size == cparams.n_ctx);
auto idx = 2*wq->n_device*il + 2*id;
GGML_ASSERT(idx+1 < (int)lctx.cache_copies.size());
auto k_row_size = ggml_row_size(split_kl->type, n_embd_head_k);
ggml_tensor * k_cache_view = ggml_view_2d(ctx0, split_kl, n_embd_head_k, n_tokens*n_head_kv,
k_row_size, k_row_size*n_head_kv*kv_head);
lctx.cache_copies[idx+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
lctx.cache_copies[idx+0].step = k_row_size*n_head_kv;
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, lctx.cache_copies[idx+0].cpy);
auto v_cache_view = ggml_view_1d(ctx0, split_vl, n_tokens*split_wv->ne[1],
kv_head*ggml_row_size(split_vl->type, split_wv->ne[1]));
cb(v_cache_view, "v_cache_view", il_cb);
lctx.cache_copies[idx+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
lctx.cache_copies[idx+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
ggml_build_forward_expand(gf, lctx.cache_copies[idx+1].cpy);
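// Arrange Q/K/V for flash attention: Q permuted into head-major order, K and V taken
// as 3D views over this device's cache split.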
auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
cb(q, "q", il_cb);
auto k = ggml_view_3d(ctx0, split_kl, n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(split_kl->type, n_embd_head_k)*n_head_kv, // row stride = n_embd_k_gqa for this split
ggml_row_size(split_kl->type, n_embd_head_k), 0);
cb(k, "k", il_cb);
auto v = ggml_view_3d(ctx0, split_vl, n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(split_vl->type, split_wv->ne[1]),
ggml_row_size(split_vl->type, n_embd_head_v), 0);
cb(v, "v", il_cb);
#ifdef GGML_USE_VULKAN
constexpr bool use_f32_precision = true;
#else
constexpr bool use_f32_precision = false;
#endif
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, "flash_attn", il_cb);
ggml_flash_attn_ext_add_sinks(cur, sinks);
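// A positive n_swa enables sliding-window attention by patching parameter slot 4 of
// the flash-attention op directly.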
if (n_swa > 0) {
((int32_t *)cur->op_params)[4] = n_swa;
}
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
model.arch == LLM_ARCH_GLM4_MOE) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
cb(cur, "flash_attn_reshaped", il_cb);
cur = llm_build_lora_mm(lctx, ctx0, split_wo, cur);
if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
cb(cur, "kqv_wo", il_cb);
if (bo) {
cur = ggml_add(ctx0, cur, bo->splits[id]);
cb(cur, "kqv_wo_biased", il_cb);
}
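// For larger batches, cast the partial result to F16, presumably to reduce the traffic
// when the per-device outputs are gathered and combined.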
if (cur->ne[1] >= 32) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
}
ggml_build_forward_expand(gf, cur);
input[id] = cur;
}
}
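// Single-tensor overload. When flash attention is enabled and every attention weight of
// this layer (plus its KV cache) carries a split-tensor extra, fan the input out to the
// owning devices, run the split implementation above, and reduce the per-device outputs
// with a chain of ggml_add ops; the 0xff written into the first add's op_params looks
// like a backend-specific marker for this reduction, but its meaning is not documented
// in this diff. Otherwise fall through to the regular single-device path below.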
ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il) {
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra &&
kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
std::vector<ggml_tensor *> attn(wq->n_device, nullptr);
std::vector<int> ids; ids.reserve(wq->n_device);
for (int id = 0; id < wq->n_device; ++id) {
if (wq->splits[id]) {
attn[id] = input;
ids.push_back(id);
}
}
GGML_ASSERT(!ids.empty());
build_std_attention(gf, attn, inp_pos, rope_factors_in, KQ_mask, sinks, inp_attn_scale, KQ_scale, f_attn_scale, n_swa, il);
if (ids.size() == 1) return attn[ids.front()];
auto cur = ggml_add(ctx0, attn[ids[0]], attn[ids[1]]);
cb(cur, "combine_attn", il);
cur->op_params[0] = 0xff;
for (int id = 2; id < (int)ids.size(); ++id) {
cur = ggml_add(ctx0, cur, attn[ids[id]]);
cb(cur, "combine_attn", il);
}
return cur;
}
auto cur = input;

@@ -410,4 +410,7 @@ llm_expert_gating_func_type gating_op,
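// single-tensor variant: returns the attention output for this layer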
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il);
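// split (tensor-parallel) variant: per-device inputs are replaced in place with per-device outputs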
void build_std_attention(ggml_cgraph * gf, std::vector<ggml_tensor *> & cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il);
};