Merge Q, K, V (#878)

* POC: merge Q, K, V into a single, contiguous tensor

Done just for Qwen3-MoE, where I see a 4% uplift in TG.
PP performance gain is sub-percent, if any.
Still, it seems it makes sense to do it in general given
the TG performance gain.

* WIP

* merge_qkv: it works for gpt-oss

...but we see a smaller TG gain (~1.5%)

* WIP

* Don't ignore the return value of create_tensors()

else, when q, k, v get merged and we are running on the CPU,
we get a crash because the backend is trying to use mmap,
but that no longer works.

* merge_qkv: bias can be required, optional, or mandatory

* merge_qkv: glm4.5moe

* merge_qkv: add command loine argument to enable

* merge_qkv: fix tensor dimensions

* merge_qkv: llama-4

* merge_qkv: qwen3 (dense)

* merge_qkv: simplify build_qwen3moe

* cohere2 - simplify graph building

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-10-30 10:49:48 +02:00
committed by GitHub
parent 92517e74ad
commit 56fc5454ff
10 changed files with 260 additions and 119 deletions

View File

@@ -1272,6 +1272,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.validate_quants = true;
return true;
}
if (arg == "-mqkv" || arg == "--merge-qkv") {
params.merge_qkv = true;
return true;
}
if (arg == "--numa") {
CHECK_ARG
std::string value(argv[i]);
@@ -1911,6 +1915,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2778,7 +2783,7 @@ void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lor
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
auto mparams = llama_model_default_params();
mparams.devices = params.devices.c_str();
mparams.devices = params.devices.c_str();
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
@@ -2794,6 +2799,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.repack_tensors = params.repack_tensors;
mparams.use_thp = params.use_thp;
mparams.validate_quants = params.validate_quants;
mparams.merge_qkv = params.merge_qkv;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
@@ -3965,6 +3971,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false");
fprintf(stream, "use_thp: %s # default: false\n", params.use_thp ? "true" : "false");
fprintf(stream, "validate_quants: %s # default: false\n", params.validate_quants ? "true" : "false");
fprintf(stream, "merge_qkv: %s # default: false\n", params.merge_qkv ? "true" : "false");
fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);

View File

@@ -269,6 +269,7 @@ struct gpt_params {
bool use_thp = false; // use transparent huge pages (linux only)
bool validate_quants = false; // if true, check for NaNs while loading the model
bool only_active_exps = true; // if true, offload only active experts (relevant only for hybrid CPU/GPU)
bool merge_qkv = false; // if true, merge separate Q, K, V tensors into a single, contiguous tensor
std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

View File

@@ -382,6 +382,7 @@ extern "C" {
bool repack_tensors;// repack if available
bool use_thp; // use transparent huge pages (linux only)
bool validate_quants; // if true, check for NaNs while loading the model
bool merge_qkv; // if true, merge separate Q, K, V tensors into a single, contiguous tensor
};
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations

View File

@@ -1238,23 +1238,76 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
cb(Qcur, "Qcur", il);
}
if (bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
Qcur = ggml_add(ctx0, Qcur, bq);
cb(Qcur, "Qcur", il);
ggml_build_forward_expand(gf, Qcur);
}
if (bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
Kcur = ggml_add(ctx0, Kcur, bk);
cb(Kcur, "Kcur", il);
ggml_build_forward_expand(gf, Kcur);
}
if (bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
Vcur = ggml_add(ctx0, Vcur, bv);
cb(Vcur, "Vcur", il);
ggml_build_forward_expand(gf, Vcur);
}
return {Qcur, Kcur, Vcur};
}
std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
ggml_tensor * wqkv, ggml_tensor * bqkv,
ggml_tensor * wq, ggml_tensor * bq,
ggml_tensor * wk, ggml_tensor * bk,
ggml_tensor * wv, ggml_tensor * bv,
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
if (wqkv) {
auto qkv = llm_build_lora_mm(lctx, ctx0, wqkv, cur);
cb(qkv, "qkv", il);
if (bqkv) {
qkv = ggml_add(ctx0, qkv, bqkv);
cb(qkv, "qkv_b", il);
}
auto Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 0*sizeof(float)*(n_embd));
auto Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
auto Vcur = ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(Qcur->ne[0]*Qcur->ne[1] + Kcur->ne[0]*Kcur->ne[1]));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
if (q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
}
if (k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
}
return {Qcur, Kcur, Vcur};
//ggml_build_forward_expand(gf, Qcur);
//ggml_build_forward_expand(gf, Kcur);
//ggml_build_forward_expand(gf, Vcur);
}
auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il);
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, n_head, n_tokens);
if (q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
}
auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, n_head_kv, n_tokens);
if (k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
}
auto Vcur = V;
return {Qcur, Kcur, Vcur};
}
ggml_cgraph * llm_build_context::build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@@ -1304,21 +1357,23 @@ ggml_cgraph * llm_build_context::build_llama() {
// rope freq factors for llama3; may return nullptr for llama2 and other models
struct ggml_tensor * rope_factors = build_rope_factors(il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
hparams.f_attention_scale, il);
nullptr, nullptr, hparams.f_attention_scale, il);
if (use_rope) {
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
} else if (inp_attn_scale) {
Qcur = ggml_mul(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_attn_scale);
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
}
cb(Qcur, "Qcur", il);
@@ -3324,30 +3379,21 @@ ggml_cgraph * llm_build_context::build_qwen3() {
// self-attention
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, nullptr,
model.layers[il].wq, nullptr,
model.layers[il].wk, nullptr,
model.layers[il].wv, nullptr, 0, il);
model.layers[il].wv, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
@@ -3430,13 +3476,10 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
// self_attention
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
model.layers[il].wk, nullptr,
model.layers[il].wv, nullptr, 0, il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, nullptr,
model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
@@ -3445,10 +3488,6 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
);
cb(Qcur, "Qcur", il);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -6054,24 +6093,12 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
// self-attention
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv, 0.f, il);
// reshape for multi-head
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
// Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply Q/K norm if available (GLM-4.5 355B variant)
if (model.layers[il].attn_q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
}
if (model.layers[il].attn_k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
}
model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
// apply RoPE
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
@@ -6474,28 +6501,23 @@ ggml_cgraph * llm_build_context::build_cohere2() {
// rope freq factors for 128k context
struct ggml_tensor * rope_factors = build_rope_factors(il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv, 0.f, il);
model.layers[il].wv, model.layers[il].bv, nullptr, nullptr, 0.f, il);
if (is_sliding) {
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos,
rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
} else {
// For non-sliding layers, just reshape without applying RoPE
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
cb(Kcur, "Kcur", il);
}
};
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr,
@@ -6537,6 +6559,7 @@ ggml_cgraph * llm_build_context::build_cohere2() {
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "output", -1);
if (f_logit_scale) {
cur = ggml_scale(ctx0, cur, f_logit_scale);
@@ -7773,20 +7796,37 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
// self-attention
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv, 0.f, il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
nullptr, nullptr, 0.0f, il);
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
//auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
// model.layers[il].wk, model.layers[il].bk,
// model.layers[il].wv, model.layers[il].bv, 0.f, il);
//Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
// beta_fast, beta_slow);
//cb(Qcur, "Qcur", il);
//Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
// attn_factor, beta_fast, beta_slow);
//cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
is_sliding ? hparams.n_swa : 0);
@@ -7853,7 +7893,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
ggml_cgraph * llm_build_context::build_bailingmoe2() {
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
//const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7883,23 +7923,27 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
// self_attention
{
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wqkv, model.layers[il].bqkv,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
//ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
//cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
//cb(cur, "wqkv", il);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
//ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
//ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
////ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
//ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
//Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
//cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
//Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
//cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,

View File

@@ -149,6 +149,13 @@ struct llm_build_context {
ggml_tensor * wv, ggml_tensor * bv,
float attention_scale, int il);
std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
ggml_tensor * wqkv, ggml_tensor * bqkv,
ggml_tensor * wq, ggml_tensor * bq,
ggml_tensor * wk, ggml_tensor * bk,
ggml_tensor * wv, ggml_tensor * bv,
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il);
ggml_cgraph * build_llama();
ggml_cgraph * build_deci();

View File

@@ -28,6 +28,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
virtual size_t get_ctx_size() const override { return ctx_size; }
bool merge_qkv(const LLM_TN & tn, int i, int bias);
bool create_tensors() override;
bool create_llama_tensors(const LLM_TN & tn);
@@ -284,15 +286,11 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
use_mmap_buffer &= !merge_qkv(tn, i, 1);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
// optional bias tensors
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -418,9 +416,8 @@ bool create_tensors_helper::create_llama4_tensors(const LLM_TN & tn) {
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
use_mmap_buffer &= !merge_qkv(tn, i, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
@@ -1018,9 +1015,8 @@ bool create_tensors_helper::create_qwen3_tensors(const LLM_TN & tn) {
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
use_mmap_buffer &= !merge_qkv(tn, i, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
@@ -1044,9 +1040,8 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
use_mmap_buffer &= !merge_qkv(tn, i, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
@@ -1700,12 +1695,16 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
// GLM-style attention with bias terms
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
if (!flags) {
use_mmap_buffer &= !merge_qkv(tn, i, 2);
} else {
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
}
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
@@ -2380,10 +2379,10 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
use_mmap_buffer &= !merge_qkv(tn, i, 2);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.attn_sinks = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
@@ -2394,11 +2393,6 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_up);
// bias
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0);
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0);
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0);
layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
ggml_context *ctx_ffn_gate_b, *ctx_ffn_up_b, *ctx_ffn_down_b;
layer.ffn_gate_inp_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
layer.ffn_gate_exps_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_gate_b);
@@ -2421,6 +2415,88 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}
bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
auto& hparams = model.hparams;
const int64_t n_head = hparams.n_head();
const int64_t n_head_kv = hparams.n_head_kv();
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_gqa = n_embd_v_gqa;
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
auto wq_name = tn(LLM_TENSOR_ATTN_Q, "weight", i);
auto wk_name = tn(LLM_TENSOR_ATTN_K, "weight", i);
auto wv_name = tn(LLM_TENSOR_ATTN_V, "weight", i);
auto wq = ml.require_tensor_meta(wq_name.c_str());
auto wk = ml.require_tensor_meta(wk_name.c_str());
auto wv = ml.require_tensor_meta(wv_name.c_str());
GGML_ASSERT(wq && wk && wv);
bool fused_qkv = false;
if (ml.merge_qkv && wq->type == wk->type && wq->type == wv->type && hparams.f_attention_scale == 0.0f) {
GGML_ASSERT(wq->ne[0] == n_embd && wq->ne[1] == n_head * n_embd_head_k);
GGML_ASSERT(wk->ne[0] == n_embd && wk->ne[1] == n_embd_gqa);
GGML_ASSERT(wv->ne[0] == n_embd && wv->ne[1] == n_embd_gqa);
layer.wqkv = ggml_new_tensor_2d(ctx_split, wq->type, n_embd, n_embd_head_k * (n_head + n_head_kv + n_head_kv));
snprintf(layer.wqkv->name, GGML_MAX_NAME, "blk.%d.attn_qkv.weight", i);
// This does not work. If we are doing this merge manually, it basically means that the arch does not have
// an LLM_TENSOR_ATTN_QKV entry, so we will get __missing__ as the tensor name.
//ggml_set_name(layer.wqkv, tn(LLM_TENSOR_ATTN_QKV, "weight", i).c_str());
layer.wq = ml.create_tensor_as_view(ctx_split, layer.wqkv, wq_name.c_str(), { wq->ne[0], wq->ne[1] }, 0);
layer.wk = ml.create_tensor_as_view(ctx_split, layer.wqkv, wk_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]);
layer.wv = ml.create_tensor_as_view(ctx_split, layer.wqkv, wv_name.c_str(), { wv->ne[0], wv->ne[1] }, wq->ne[1]*wq->nb[1] + wk->ne[1]*wk->nb[1] );
fused_qkv = true;
printf("================================== Created merged qkv %s\n", layer.wqkv->name);
if (bias) {
auto bq_name = tn(LLM_TENSOR_ATTN_Q, "bias", i);
auto bk_name = tn(LLM_TENSOR_ATTN_K, "bias", i);
auto bv_name = tn(LLM_TENSOR_ATTN_V, "bias", i);
auto bq = ml.get_tensor_meta(bq_name.c_str());
auto bk = ml.get_tensor_meta(bk_name.c_str());
auto bv = ml.get_tensor_meta(bv_name.c_str());
if (bias == 2) {
GGML_ASSERT(bq && bk && bv);
} else {
GGML_ASSERT(!bq && !bk && !bv);
}
if (bq && bk && bv) {
GGML_ASSERT(bq->type == GGML_TYPE_F32 && bk->type == GGML_TYPE_F32 && bv->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_nrows(bq) == 1 && bq->ne[0] == wq->ne[1]);
GGML_ASSERT(ggml_nrows(bk) == 1 && bk->ne[0] == wk->ne[1]);
GGML_ASSERT(ggml_nrows(bv) == 1 && bv->ne[0] == wv->ne[1]);
layer.bqkv = ggml_new_tensor_1d(ctx_layer, bq->type, n_embd_head_k * (n_head + n_head_kv + n_head_kv));
snprintf(layer.bqkv->name, GGML_MAX_NAME, "blk.%d.attn_qkv.bias", i);
layer.bq = ml.create_tensor_as_view(ctx_layer, layer.bqkv, bq_name.c_str(), { bq->ne[0] }, 0);
layer.bk = ml.create_tensor_as_view(ctx_layer, layer.bqkv, bk_name.c_str(), { bk->ne[0] }, bq->ne[0]*bq->nb[0]);
layer.bv = ml.create_tensor_as_view(ctx_layer, layer.bqkv, bv_name.c_str(), { bv->ne[0] }, bq->ne[0]*bq->nb[0] + bk->ne[0]*bk->nb[0] );
}
}
}
if (!fused_qkv) {
if (ml.merge_qkv) {
printf("%s: did not merge Q, K, V in layer %d because %d, %d, %d\n", __func__, i,
wq->type == wk->type, wq->type == wv->type, hparams.f_attention_scale == 0.0f);
}
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
if (bias) {
auto flags = bias == 1 ? llama_model_loader::TENSOR_NOT_REQUIRED : 0;
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {layer.wq->ne[1]}, flags);
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {layer.wk->ne[1]}, flags);
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {layer.wv->ne[1]}, flags);
}
}
return fused_qkv;
}
bool create_tensors_helper::create_tensors() {
const auto tn = LLM_TN(model.arch);
bool use_mmap_buffer = true;

View File

@@ -203,9 +203,10 @@ namespace GGUFMeta {
};
}
llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
const llama_model_kv_override * param_overrides_p,
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors,
bool repack_tensors, bool use_thp, bool merge_qkv,
const llama_model_kv_override * param_overrides_p,
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
int trace = 0;
if (getenv("LLAMA_TRACE")) {
trace = atoi(getenv("LLAMA_TRACE"));
@@ -495,6 +496,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
this->check_tensors = check_tensors;
this->repack_tensors = repack_tensors;
this->use_thp = use_thp;
this->merge_qkv = merge_qkv;
}
llama_model_loader::~llama_model_loader() {

View File

@@ -44,6 +44,7 @@ struct llama_model_loader {
bool check_tensors;
bool repack_tensors = false;
bool use_thp = false;
bool merge_qkv = false;
llama_files files;
llama_ftype ftype;
@@ -78,7 +79,7 @@ struct llama_model_loader {
std::string arch_name;
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, bool merge_qkv,
const llama_model_kv_override * param_overrides_p,
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

View File

@@ -1007,7 +1007,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data();
}
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, /* use_thp */ false, kv_overrides, nullptr);
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false,
/* use_thp */ false, /* merge_qkv */ false, kv_overrides, nullptr);
ml.init_mappings(false); // no prefetching
llama_model model;

View File

@@ -1684,7 +1684,7 @@ static bool llm_load_tensors(
throw std::runtime_error("model has expert layers but no expert layers are used");
}
cth->create_tensors();
use_mmap_buffer = cth->create_tensors();
ml.done_getting_tensors();
@@ -1896,7 +1896,7 @@ static bool llm_load_tensors(
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
try {
llama_model_loader ml(fname, params.use_mmap, params.check_tensors,
params.repack_tensors, params.use_thp, params.kv_overrides, params.tensor_buft_overrides);
params.repack_tensors, params.use_thp, params.merge_qkv, params.kv_overrides, params.tensor_buft_overrides);
model.hparams.vocab_only = params.vocab_only;
@@ -3788,6 +3788,7 @@ struct llama_model_params llama_model_default_params() {
/*.repack_tensors =*/ false,
/*.use_thp =*/ false,
/*.validate_quants =*/ false,
/*.merge_qkv =*/ false,
};
#ifdef GGML_USE_METAL