From 996e77047a8eed8ddd81778a04a59b971d7000a6 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 20 Jan 2026 15:38:21 +0200 Subject: [PATCH] Avoid ggml_get_rows if not necessary (#1160) * Copy reduce result to other GPUs if necessary * Avoid ggml_get_rows for TG * For the output ops use the result of the split that ran on the main GPU * More models --- ggml/src/ggml-backend.cpp | 4 +- src/llama-build-context.cpp | 124 +++++++++++++++++++++--------------- src/llama-build-context.h | 3 +- src/llama-load-tensors.cpp | 104 +++++++++++++++--------------- src/llama.cpp | 6 +- 5 files changed, 132 insertions(+), 109 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index c05bf566..d99d7022 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -2244,7 +2244,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } - if (split->graph.nodes[0]->op == GGML_OP_REDUCE) { + if (split->graph.nodes[0]->op == GGML_OP_REDUCE && i < sched->n_splits - 1) { last_reduce = split_backend_id; if (ith == split_backend_id) { auto node = split->graph.nodes[0]; @@ -2318,7 +2318,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } - if (split->graph.nodes[0]->op == GGML_OP_REDUCE) { + if (split->graph.nodes[0]->op == GGML_OP_REDUCE && i < sched->n_splits - 1) { last_reduce = split_backend_id; barrier.arrive_and_wait(); if (ith == split_backend_id) { diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 1aad8c39..f3cb2875 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1759,7 +1759,8 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml return cur; } -static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) { +static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, + ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) { // lm_head if (output->extra) { auto split_output = (ggml_split_tensor_t *)output->extra; @@ -1790,6 +1791,10 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml } } } else { + if (cur->op == GGML_OP_REDUCE && cur->src[lctx.model.main_gpu]) { + // avoid copy to main GPU + cur->view_src = cur->src[lctx.model.main_gpu]; + } if (output_norm) { cur = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, output_norm, NULL, LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); @@ -1830,6 +1835,8 @@ ggml_cgraph * llm_build_context::build_llama() { KQ_mask_swa = build_inp_KQ_mask_swa(); } + auto inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; + //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f; for (int il = 0; il < n_layer; ++il) { @@ -1845,7 +1852,7 @@ ggml_cgraph * llm_build_context::build_llama() { // self-attention if (use_rope) { - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr, this_KQ_mask, nullptr, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il, true, false, true); } else { @@ -1895,16 +1902,14 @@ ggml_cgraph * llm_build_context::build_llama() { } //printf("%s: attn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op)); - if (il == n_layer - 1) { + if (il == n_layer - 1 && !use_rope && inp_out_ids) { // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto inp_out_ids = build_inp_out_ids(); n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "last_attn", il); - if (!use_rope) { - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - cb(inpSA, "last_ffn_inp", il); - } + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + cb(inpSA, "last_ffn_inp", il); } // For Granite architecture @@ -2047,7 +2052,7 @@ ggml_cgraph * llm_build_context::build_mistral3() { auto rope_factors = build_rope_factors(il); - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, rope_factors, KQ_mask, + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, rope_factors, KQ_mask, nullptr, inp_attn_scale, kq_scale, hparams.f_attention_scale, 0, il); if (il == n_layer - 1 && inp_out_ids) { @@ -3943,12 +3948,14 @@ ggml_cgraph * llm_build_context::build_qwen3() { ext_factor, attn_factor, beta_fast, beta_slow); } + auto inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; if (!rope_cache) { - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, - 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il, true, false, true); + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer-1 ? inp_out_ids : nullptr, nullptr, + KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il, true, false, true); } else { // norm @@ -3986,7 +3993,7 @@ ggml_cgraph * llm_build_context::build_qwen3() { } } - if (il == n_layer - 1) { + if (il == n_layer - 1 && rope_cache && inp_out_ids) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -4034,27 +4041,18 @@ ggml_cgraph * llm_build_context::build_qwen3moe() { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + ggml_tensor * inp_out_ids = nullptr; //build_inp_out_ids(); + for (int il = 0; il < n_layer; ++il) { - //struct ggml_tensor * inpSA = inpL; - // norm - //cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - //cb(cur, "attn_norm", il); - - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, - il, true, false, true); - //printf("%s: attn = %s(%s)\n", __func__, cur->name, ggml_op_name(cur->op)); - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - //inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + if (il == n_layer - 1 && n_tokens > 1) { + inp_out_ids = build_inp_out_ids(); } + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, inp_out_ids, nullptr, + KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il, true, false, true); + auto ffn_inp = cur; - //struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - //cb(ffn_inp, "ffn_inp", il); cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_gate_inp, nullptr, @@ -4071,9 +4069,6 @@ ggml_cgraph * llm_build_context::build_qwen3moe() { LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps); - //printf("%s: ffn = %s(%s)\n", __func__, cur->name, ggml_op_name(cur->op)); - - //cur = ggml_add(ctx0, cur, ffn_inp); cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); @@ -4130,10 +4125,10 @@ ggml_cgraph * llm_build_context::build_qwen3vl() { for (int il = 0; il < n_layer; ++il) { - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il, true, false, true, false, true); - if (il == n_layer - 1) { + if (il == n_layer - 1 && n_tokens > 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -6851,7 +6846,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); // output token IDs (for last layer cropping) - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; auto rope_cache = model.split_mode != LLAMA_SPLIT_MODE_GRAPH && cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ? ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -6867,7 +6862,8 @@ ggml_cgraph * llm_build_context::build_glm4_moe() { // self-attention if (rope_cache == nullptr) { - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true); + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, nullptr, + KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true); } else { // Pre-attention norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -6907,7 +6903,9 @@ ggml_cgraph * llm_build_context::build_glm4_moe() { if (il == n_transformer_layers - 1 && inp_out_ids) { // skip computing output for unused tokens cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + if (rope_cache) { + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } } // residual connection for attention output @@ -7256,11 +7254,12 @@ ggml_cgraph * llm_build_context::build_cohere2() { struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; // self-attention - auto attn_out = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), 0.f, + auto attn_out = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, nullptr, + KQ_mask_l, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), 0.f, is_sliding ? hparams.n_swa : 0, il, is_sliding, false, true, true); cb(attn_out, "attn_out", il); - if (il == n_layer - 1) { + if (il == n_layer - 1 && n_tokens > 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); attn_out = ggml_get_rows(ctx0, attn_out, inp_out_ids); @@ -8196,12 +8195,12 @@ ggml_cgraph * llm_build_context::build_ernie4_5_moe() { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); // output token IDs (for last layer cropping) - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0"); for (int il = 0; il < n_layer; ++il) { - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il, true, false, true); if (il == n_layer - 1 && inp_out_ids) { @@ -8269,11 +8268,11 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() { const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); - ggml_tensor * inp_out_ids = build_inp_out_ids(); + ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; for (int il = 0; il < n_layer; ++il) { - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, nullptr, KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true); if (il == n_layer - 1 && inp_out_ids) { @@ -8324,7 +8323,7 @@ ggml_cgraph * llm_build_context::build_mimo2() { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -8334,10 +8333,11 @@ ggml_cgraph * llm_build_context::build_mimo2() { const bool is_sliding = model.hparams.swa_layers[il]; auto KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l, model.layers[il].attn_sinks, + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, nullptr, + KQ_mask_l, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), 0.0f, is_sliding ? hparams.n_swa : 0, il, true, false, true); - if (il == n_layer - 1) { + if (il == n_layer - 1 && inp_out_ids) { // skip computing output for unused tokens cur = ggml_get_rows(ctx0, cur, inp_out_ids); } @@ -8397,6 +8397,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); ggml_tensor * inp_pos = build_inp_pos(); + auto inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); @@ -8409,14 +8410,13 @@ ggml_cgraph * llm_build_context::build_openai_moe() { struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l, - model.layers[il].attn_sinks, nullptr, kq_scale, 0.0f, is_sliding ? hparams.n_swa : 0, il, true, false, true); + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr, + KQ_mask_l, model.layers[il].attn_sinks, nullptr, kq_scale, 0.0f, is_sliding ? hparams.n_swa : 0, il, true, false, true); - if (il == n_layer - 1) { - // skip computing output for unused tokens - ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - } + //if (il == n_layer - 1 && inp_out_ids) { + // // skip computing output for unused tokens + // cur = ggml_get_rows(ctx0, cur, inp_out_ids); + //} bool use_dup_bias = cur->ne[1] < 32 && model.layers[il].ffn_up_exps_b_dup && model.layers[il].ffn_gate_exps_b_dup && @@ -9176,7 +9176,7 @@ ggml_cgraph * llm_build_context::llama_build_graph( } ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * the_attn_norm, - ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in, + ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * inp_out_ids, ggml_tensor * rope_factors_in, ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il, bool do_rope, bool add_graph_split, bool add_input, bool is_norm, bool is_multi) { @@ -9353,6 +9353,11 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens); cb(cur, "flash_attn_reshaped", il_cb); + if (inp_out_ids) { // && ggml_nrows(inp_out_ids) > 1) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cb(cur, "fa_get_rows", il_cb); + } + cur = llm_build_lora_mm(lctx, ctx0, split_wo, cur); if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators @@ -9373,6 +9378,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens } GGML_ASSERT(id_last >= 0); if (add_input) { + if (inp_out_ids) { // && ggml_nrows(inp_out_ids) > 1) { + input = ggml_get_rows(ctx0, input, inp_out_ids); + cb(input, "sainp_get_rows", il); + } attn[id_last] = ggml_add(ctx0, attn[id_last], input); cb(attn[id_last], "attn_out_with_input", il); } @@ -9424,6 +9433,15 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa); + if (inp_out_ids) { // && ggml_nrows(inp_out_ids) > 1) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cb(cur, "sa_get_rows", il); + if (add_input) { + input = ggml_get_rows(ctx0, input, inp_out_ids); + cb(input, "sainp_get_rows", il); + } + } + if (add_input) { cb(cur, "attn_out", il); cur = ggml_add(ctx0, cur, input); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 2cf36ece..65b8e9f8 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -414,7 +414,8 @@ llm_expert_gating_func_type gating_op, static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case); - ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors, + ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur, + ggml_tensor * inp_pos, ggml_tensor * inp_out_ids, ggml_tensor * rope_factors, ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il, bool do_rope = true, bool add_graph_split = false, bool add_input = false, bool is_norm = false, bool is_multi = false); diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 9e691634..ba33948b 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -380,7 +380,7 @@ void create_tensors_helper::create_std_ffn(int i, const LLM_TN & tn, llama_layer bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) { LOADING_PRELUDE - create_embd_output(tn, n_embd, n_vocab, true, false); //true); + create_embd_output(tn, n_embd, n_vocab, true); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -678,9 +678,9 @@ bool create_tensors_helper::create_falcon_tensors(const LLM_TN & tn) { model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -712,12 +712,12 @@ bool create_tensors_helper::create_starcoder_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { // needs to be on GPU - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -860,9 +860,9 @@ bool create_tensors_helper::create_bloom_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -902,9 +902,9 @@ bool create_tensors_helper::create_mpt_tensors(const LLM_TN & tn) { model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -1010,9 +1010,9 @@ bool create_tensors_helper::create_qwen2_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); - model.output_b = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { @@ -1133,11 +1133,11 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -1177,7 +1177,7 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) { bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) { LOADING_PRELUDE - create_embd_output(tn, n_embd, n_vocab, true, false); //true); + create_embd_output(tn, n_embd, n_vocab, true); for (int i = 0; i < n_layer; ++i) { uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); @@ -1224,10 +1224,10 @@ bool create_tensors_helper::create_phi2_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); - model.output_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -1300,9 +1300,9 @@ bool create_tensors_helper::create_gpt2_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -1366,11 +1366,11 @@ bool create_tensors_helper::create_codeshell_tensors(const LLM_TN & tn) { void create_tensors_helper::create_default_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool norm_bias) { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); if (norm_bias) { - model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); } - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } bool create_tensors_helper::create_orion_tensors(const LLM_TN & tn) { @@ -1473,7 +1473,7 @@ bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) { model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); @@ -1531,10 +1531,10 @@ bool create_tensors_helper::create_mamba_tensors(const LLM_TN & tn) { { model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (model.output == NULL) { - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -1684,9 +1684,9 @@ bool create_tensors_helper::create_gptneox_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -1762,8 +1762,8 @@ bool create_tensors_helper::create_deepseek2_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -1862,7 +1862,7 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) { GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); - create_embd_output(tn, n_embd, n_vocab, true, false); //true); + create_embd_output(tn, n_embd, n_vocab, true); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -2019,8 +2019,8 @@ bool create_tensors_helper::create_bitnet2_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { @@ -2110,7 +2110,7 @@ bool create_tensors_helper::create_t5_tensors(const LLM_TN & tn) { model.output_norm_enc = create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); @@ -2170,7 +2170,7 @@ bool create_tensors_helper::create_tsencoder_tensors(const LLM_TN & tn) { // output { model.output_norm_enc = create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); @@ -2205,9 +2205,9 @@ bool create_tensors_helper::create_jais_tensors(const LLM_TN & tn) { // Output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -2246,8 +2246,8 @@ bool create_tensors_helper::create_chatglm_tensors(const LLM_TN & tn) { // output { - model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -2275,7 +2275,7 @@ bool create_tensors_helper::create_chatglm_tensors(const LLM_TN & tn) { bool create_tensors_helper::create_cohere2_tensors(const LLM_TN & tn) { LOADING_PRELUDE - create_embd_output(tn, n_embd, n_vocab, true, false); //true); + create_embd_output(tn, n_embd, n_vocab, true); for (int i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; @@ -2295,10 +2295,10 @@ bool create_tensors_helper::create_glm4_tensors(const LLM_TN & tn) { // output model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } for (int i = 0; i < n_layer; ++i) { @@ -2341,7 +2341,7 @@ bool create_tensors_helper::create_dots1_tensors(const LLM_TN & tn) { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); for (int i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; ggml_context * ctx_layer = ctx_for_layer(i); @@ -2391,7 +2391,7 @@ bool create_tensors_helper::create_bailingmoe2_tensors(const LLM_TN & tn) { // output model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); diff --git a/src/llama.cpp b/src/llama.cpp index 488dcadc..c6b87eb0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2316,9 +2316,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { #if IK_PRINT_TIMING == 2 auto tim1 = ggml_time_us(); #endif - GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; + if (n_tokens > 1) { + GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + } + if (lctx.inp_out_ids && lctx.inp_out_ids->buffer) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); int32_t * data = (int32_t *) lctx.inp_out_ids->data; @@ -2341,6 +2344,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } else { GGML_ASSERT(lctx.n_outputs == 0); } + } #if IK_PRINT_TIMING == 2 auto tim2 = ggml_time_us(); printf("set_inputs(outputs): %d us\n", int(tim2-tim1));