Merge Q and K into a single tensor (#892)

* Merge Q and K into a single tensor

* Make V mul mat follow QK mul mat

so they can be fused, which gives a slightly bbetter TG performance.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-05 10:54:36 +02:00
committed by GitHub
parent abb966eba1
commit 1a3aaa33c1
4 changed files with 81 additions and 1 deletions

View File

@@ -1270,6 +1270,7 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
ggml_tensor * wqkv, ggml_tensor * bqkv,
ggml_tensor * wqk, ggml_tensor * bqk,
ggml_tensor * wq, ggml_tensor * bq,
ggml_tensor * wk, ggml_tensor * bk,
ggml_tensor * wv, ggml_tensor * bv,
@@ -1307,6 +1308,40 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
//ggml_build_forward_expand(gf, Vcur);
}
if (wqk) {
auto qk = llm_build_lora_mm(lctx, ctx0, wqk, cur);
cb(qk, "qkv", il);
if (bqk) {
qk = ggml_add(ctx0, qk, bqk);
cb(qk, "qkv_b", il);
}
auto Vcur = llm_build_lora_mm(lctx, ctx0, wv, cur);
cb(Vcur, "Vcur", il);
if (bv) {
Vcur = ggml_add(ctx0, Vcur, bv);
cb(Vcur, "Vcur", il);
}
ggml_build_forward_expand(gf, qk);
ggml_build_forward_expand(gf, Vcur);
auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd));
auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
if (q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
ggml_build_forward_expand(gf, Qcur);
}
if (k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
ggml_build_forward_expand(gf, Kcur);
}
return {Qcur, Kcur, Vcur};
}
auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il);
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, n_head, n_tokens);
if (q_norm) {
@@ -1374,6 +1409,7 @@ ggml_cgraph * llm_build_context::build_llama() {
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
@@ -3400,6 +3436,7 @@ ggml_cgraph * llm_build_context::build_qwen3() {
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, nullptr,
model.layers[il].wqk, nullptr,
model.layers[il].wq, nullptr,
model.layers[il].wk, nullptr,
model.layers[il].wv, nullptr,
@@ -3502,6 +3539,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, nullptr,
model.layers[il].wqk, nullptr,
model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
@@ -6403,6 +6441,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
@@ -6814,6 +6853,7 @@ ggml_cgraph * llm_build_context::build_cohere2() {
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv, nullptr, nullptr, 0.f, il);
@@ -8116,6 +8156,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
@@ -8234,7 +8275,7 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
// self_attention
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wqkv, model.layers[il].bqkv,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il);
if (rope_cache) {