From 5dacb5355afb0c28967a201f4e2790938e32bc1a Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Mon, 23 Feb 2026 07:58:00 +0100
Subject: [PATCH] Graph parallel for Qwen3-Next (#1292)

* WIP

* This works, but is slower than split mode layer
---
 src/llama-build-context.cpp | 212 ++++++++++++++++++++++--------------
 src/llama-build-context.h   |   3 +
 src/llama-delta-net.cpp     |   4 +
 src/llama-load-tensors.cpp  |  37 ++++---
 src/llama-model.h           |   2 +
 src/llama.cpp               |  60 +++++-----
 6 files changed, 199 insertions(+), 119 deletions(-)

diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 532143d5..c9209370 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -714,9 +714,6 @@ static inline ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur,
     if (the_norm && the_norm->extra) {
         auto norm = (ggml_split_tensor_t *)the_norm->extra;
         GGML_ASSERT(norm->splits[id]);
-        //if (cur->type != GGML_TYPE_F16 && cur->type != GGML_TYPE_F32) {
-        //    cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
-        //}
         if (is_norm) {
             cur = ggml_fused_norm(ctx, cur, norm->splits[id], hparams.f_norm_eps);
         } else {
@@ -771,6 +768,9 @@ ggml_tensor * llm_build_context::llm_build_ffn(
         if (!split_u) continue;
         auto cur = get_input_tensor_sm_graph(ctx, input, id);
         cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
+        if (input->op != GGML_OP_REDUCE) {
+            cur->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+        }
         cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
         cb(cur, "ffn_up_gate", il_cb);
         if (lctx.model.arch == LLM_ARCH_STEP35) {
@@ -1327,13 +1327,29 @@ llm_expert_gating_func_type gating_op,
                    (!split_up_shexp->splits[id] && !split_gate_shexp->splits[id] && !split_down_shexp->splits[id]));
         if (!split_up_shexp->splits[id]) continue;
         auto the_ffn_norm = ffn_norm ? ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[id] : ffn_norm : nullptr;
-        auto shared_out = llm_build_ffn(ctx, lctx, the_ffn_norm, input,
+        auto this_input = input;
+        if (the_ffn_norm) {
+            this_input = llm_build_norm(ctx, input, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
+        }
+        auto shared_out = llm_build_ffn(ctx, lctx, nullptr, this_input,
                 split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr,
                 split_gate_shexp->splits[id], split_gate_b_shexp ? split_gate_b_shexp->splits[id] : nullptr, nullptr,
                 split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
                 nullptr, type_op_shexp, LLM_FFN_PAR, cb, il, graph, false, false, id == id_add_routed ? routed_out : nullptr);
         cb(shared_out, "ffn_shexp_out", il_cb);
+        if (shexp_gate) {
+            auto split_shexp_gate = (ggml_split_tensor_t *)shexp_gate->extra;
+            GGML_ASSERT(split_shexp_gate && split_shexp_gate->splits[id]);
+            auto gate = llm_build_lora_mm(lctx, ctx, split_shexp_gate->splits[id], this_input);
+            if (gate->ne[1] == 1) {
+                shared_out = ggml_fused_mul_unary(ctx, gate, shared_out, GGML_UNARY_OP_SIGMOID);
+            } else {
+                gate = ggml_sigmoid(ctx, gate);
+                shared_out = ggml_mul(ctx, shared_out, gate);
+            }
+            cb(shared_out, "ffn_shexp_gated", il_cb);
+        }
         if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
             shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type);
         }
@@ -1396,6 +1412,9 @@ llm_expert_gating_func_type gating_op,
         int il_cb = 1000*(id + 1) + il;
         auto cur = get_input_tensor_sm_graph(ctx, input, id);
         cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, false);
+        if (cur->op != GGML_OP_REDUCE) {
+            cur->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+        }
         GGML_ASSERT(!split_gate_inp_b || split_gate_inp_b->splits[id]);
         GGML_ASSERT(!split_exps_down_b || split_exps_down_b->splits[id]);
         GGML_ASSERT(!split_exps_gate_b || split_exps_gate_b->splits[id]);
@@ -1421,6 +1440,18 @@ llm_expert_gating_func_type gating_op,
                 split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
                 nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
         cb(shared_out, "ffn_shexp_out", il_cb);
+        if (shexp_gate) {
+            auto split_shexp_gate = (ggml_split_tensor_t *)shexp_gate->extra;
+            GGML_ASSERT(split_shexp_gate && split_shexp_gate->splits[id]);
+            auto gate = llm_build_lora_mm(lctx, ctx, split_shexp_gate->splits[id], cur);
+            if (gate->ne[1] == 1) {
+                shared_out = ggml_fused_mul_unary(ctx, gate, shared_out, GGML_UNARY_OP_SIGMOID);
+            } else {
+                gate = ggml_sigmoid(ctx, gate);
+                shared_out = ggml_mul(ctx, shared_out, gate);
+            }
+            cb(shared_out, "ffn_shexp_gated", il_cb);
+        }
         cur = ggml_add(ctx, routed_out, shared_out);
         cb(cur, "ffn_out", il_cb);

@@ -1770,6 +1801,38 @@ std::tuple<ggml_tensor *, ggml_tensor *, ggml_tensor *> llm_build_context::llm_buil
     return {Qcur, Kcur, Vcur};
 }

+std::tuple<ggml_tensor *, ggml_tensor *, ggml_tensor *, ggml_tensor *> llm_build_context::llm_build_mul_mat_qkv_gated(ggml_cgraph * gf, ggml_tensor * cur,
+        ggml_tensor * wq, ggml_tensor * wk, ggml_tensor * wv, ggml_tensor * q_norm, ggml_tensor * k_norm, int il) const {
+    auto Qaux = llm_build_lora_mm(lctx, ctx0, wq, cur);
+    cb(Qaux, "Qaux", il);
+    auto Kcur = llm_build_lora_mm(lctx, ctx0, wk, cur);
+    cb(Kcur, "Kcur", il);
+    auto Vcur = llm_build_lora_mm(lctx, ctx0, wv, cur);
+    cb(Vcur, "Vcur", il);
+    ggml_build_forward_expand(gf, Qaux);
+    ggml_build_forward_expand(gf, Kcur);
+    ggml_build_forward_expand(gf, Vcur);
+    auto row_size = ggml_row_size(Qaux->type, n_embd_head_k);
+    // TODO: check why CUDA performance suffers so much if we don't make these two tensors contiguous
+    auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head_k, Qaux->ne[0]/(2*n_embd_head_k), n_tokens, 2*row_size, Qaux->nb[1], 0));
+    auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head_k, Qaux->ne[0]/(2*n_embd_head_k), n_tokens, 2*row_size, Qaux->nb[1], row_size), Qaux->ne[0]/2, n_tokens);
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, Kcur->ne[0]/n_embd_head_k, n_tokens);
+    if (q_norm) {
+        Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
+        cb(Qcur, "Qcur_normed", il);
+        ggml_build_forward_expand(gf, Qcur);
+    }
+    if (k_norm) {
+        Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il);
+        cb(Kcur, "Kcur_normed", il);
cb(Kcur, "Kcur_normed", il); + ggml_build_forward_expand(gf, Kcur); + } + gate = ggml_sigmoid(ctx0, gate); + //gate = ggml_reshape_2d(ctx0, gate, gate->ne[0]*gate->ne[1], gate->ne[2]); + cb(gate, "gate", il); + return {Qcur, Kcur, Vcur, gate}; +} + std::tuple llm_build_context::llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * wqkv, ggml_tensor * bqkv, ggml_tensor * wqk, ggml_tensor * bqk, @@ -4351,59 +4414,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * { - - auto Qaux = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - auto Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - auto Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - cb(Qaux, "Qaux", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - ggml_build_forward_expand(gf, Qaux); - ggml_build_forward_expand(gf, Kcur); - ggml_build_forward_expand(gf, Vcur); - - Qaux = ggml_reshape_3d(ctx0, Qaux, n_embd_head * 2, n_head, n_tokens); - auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], 0)); - auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], n_embd_head*ggml_element_size(Qaux)), n_embd_head*n_head, n_tokens); - cb(Qcur, "Qcur", il); - cb(gate, "gate", il); - - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); - cb(Qcur, "Qcur_normed", il); - - Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); - cb(Kcur, "Kcur_normed", il); - - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - - cb(Qcur, "Qcur_roped", il); - cb(Kcur, "Kcur_roped", il); - - ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr, - Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, - hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale, cb, il); - cb(attn, "attn_pregate", il); - - gate = ggml_sigmoid(ctx0, gate); - cb(gate, "gate_sigmoid", il); - attn = ggml_mul(ctx0, attn, gate); - cb(attn, "attn_gated", il); - - attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn); - cb(attn, "attn_output", il); - - return attn; - - }; - ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; @@ -4425,6 +4435,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() { ggml_build_forward_expand(gf, identity); ggml_build_forward_expand(gf, diag_mask); + float KQ_scale = hparams.f_attention_scale == 0.0f ? 
+
     ggml_tensor * cur = nullptr;

     for (int il = 0; il < n_layer; ++il) {
@@ -4444,23 +4456,29 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
             GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr);
         }

-        cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
-        cb(cur, "attn_norm", il);
         if (hparams.is_recurrent(il)) {
+            if (inpL->op == GGML_OP_REDUCE && inpL->src[model.default_layer_device[il]]) {
+                inpL->view_src = inpL->src[model.default_layer_device[il]];
+                //printf("Using reduce result on device %d\n", model.default_layer_device[il]);
+                //inpL = inpL->src[model.default_layer_device[il]];
+            }
+            auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[model.default_layer_device[il]] : model.layers[il].attn_norm;
+            cur = llm_build_norm(ctx0, inpL, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
             cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+            cur = ggml_add(ctx0, cur, inpSA);
+            cb(cur, "attn_residual", il);
         } else {
-            cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
+            //cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
+            cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr,
+                    KQ_mask, nullptr, nullptr, KQ_scale, 0.0f, 0, il, true, false, true, false, false);
         }

-        if (il == n_layer - 1 && inp_out_ids) {
-            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-        }
-
-        cur = ggml_add(ctx0, cur, inpSA);
-        cb(cur, "attn_residual", il);
-
         if (!model.layers[il].ffn_gate_inp) {
             // dense FFN
             cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
@@ -10093,11 +10111,19 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                 ((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
         auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
                 ((ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra)->splits[id] : model.layers[il].attn_k_norm : nullptr;
-        auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
-                split_wq, bq ? bq->splits[id] : nullptr,
-                split_wk, bk ? bk->splits[id] : nullptr,
-                split_wv, bv ? bv->splits[id] : nullptr,
-                the_q_norm, the_k_norm, f_attn_scale, il, add_graph_split);
+        ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
+        if (model.arch == LLM_ARCH_QWEN3NEXT) {
+            auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, split_wq, split_wk, split_wv,
+                    the_q_norm, the_k_norm, il);
+            Qcur = Q; Kcur = K; Vcur = V; gate = G;
+        } else {
+            auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
+                    split_wq, bq ? bq->splits[id] : nullptr,
+                    split_wk, bk ? bk->splits[id] : nullptr,
+                    split_wv, bv ? bv->splits[id] : nullptr,
+                    the_q_norm, the_k_norm, f_attn_scale, il, add_graph_split);
+            Qcur = Q; Kcur = K; Vcur = V;
+        }
         auto rope_factors = rope_factors_in;
         if (rope_factors) {
             GGML_ASSERT(rope_factors->extra);
@@ -10224,6 +10250,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
             cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
             cb(cur, "flash_attn_reshaped", il_cb);

+            if (gate) {
+                cur = ggml_mul(ctx0, cur, gate);
+                cb(cur, "qkv_gated", il_cb);
+            }
             if (inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -10272,11 +10302,19 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
         }
         auto input_normed = cur;
-        auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
-                model.layers[il].wqkv, model.layers[il].bqkv,
-                model.layers[il].wqk, model.layers[il].bqk,
-                model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv,
-                model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
+        ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
+        if (model.arch == LLM_ARCH_QWEN3NEXT) {
+            auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, model.layers[il].wq, model.layers[il].wk, model.layers[il].wv,
+                    model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, il);
+            Qcur = Q; Kcur = K; Vcur = V; gate = G;
+        } else {
+            auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur,
+                    model.layers[il].wqkv, model.layers[il].bqkv,
+                    model.layers[il].wqk, model.layers[il].bqk,
+                    model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv,
+                    model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
+            Qcur = Q; Kcur = K; Vcur = V;
+        }

         if (do_rope) {
             if (is_multi) {
@@ -10323,9 +10361,21 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
             }
             cb(cur, "attn_out", il);
         } else {
-            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
+            if (gate) {
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
+                cur = ggml_mul(ctx0, cur, gate);
+                cb(cur, "qkv_gated", il);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                if (model.layers[il].bo) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bo);
+                }
+                cb(cur, "attn_out", il);
+            } else {
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
+            }
         }

         if (inp_out_ids) {
diff --git a/src/llama-build-context.h b/src/llama-build-context.h
index 0810605b..847dd733 100644
--- a/src/llama-build-context.h
+++ b/src/llama-build-context.h
@@ -162,6 +162,9 @@ struct llm_build_context {
             ggml_tensor * wv, ggml_tensor * bv,
             ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il, bool add_graph_split = false) const;

+    std::tuple<ggml_tensor *, ggml_tensor *, ggml_tensor *, ggml_tensor *> llm_build_mul_mat_qkv_gated(ggml_cgraph * gf, ggml_tensor * cur,
+            ggml_tensor * wq, ggml_tensor * wk, ggml_tensor * wv, ggml_tensor * q_norm, ggml_tensor * k_norm, int il) const;
+
     ggml_cgraph * build_llama();
     ggml_cgraph * build_mistral3();
diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp
index f2fb9499..8dc52429 100644
--- a/src/llama-delta-net.cpp
+++ b/src/llama-delta-net.cpp
@@ -504,6 +504,8 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
     }
     cb(beta, "beta", il);
     cb(alpha, "alpha", il);
+    ggml_build_forward_expand(gf, beta);
+    ggml_build_forward_expand(gf, alpha);

     ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
     ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
@@ -529,6 +531,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
     }
     if (reset_state_local) {
         state_f32 = ggml_scale(ctx0, state_f32, 0.0f);
+        cb(state_f32, "state_reset", il);
     }
     ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0);
@@ -539,6 +542,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
     ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1);
     cb(conv_states, "conv_states", il);
     cb(state, "state_predelta", il);
+    ggml_build_forward_expand(gf, state);

     ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext);
     cb(conv_output_raw, "conv_output_raw", il);
diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index 7e60f156..659343b9 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -1328,8 +1328,8 @@ bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
         auto & layer = model.layers[i];

-        layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-        layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
+        layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+        layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
         layer.ffn_norm = layer.attn_post_norm;

         if (!hparams.is_recurrent(i)) {
@@ -1339,22 +1339,22 @@ bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
             layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
             layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});

-            layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
-            layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
+            layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
+            layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
         } else {
             // Recurrent linear-attention layer
-            layer.ssm_in = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, qkvz_dim},
+            layer.ssm_in = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, qkvz_dim},
                     llama_model_loader::TENSOR_NOT_REQUIRED);
-            layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, key_dim * 2 + value_dim},
+            layer.wqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, key_dim * 2 + value_dim},
                     llama_model_loader::TENSOR_NOT_REQUIRED);
-            layer.wqkv_gate = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, value_dim},
+            layer.wqkv_gate = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, value_dim},
                     llama_model_loader::TENSOR_NOT_REQUIRED);
-            layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {hparams.ssm_d_conv, conv_dim});
+            layer.ssm_conv1d = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {hparams.ssm_d_conv, conv_dim});
             layer.ssm_dt = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {hparams.ssm_dt_rank});
             layer.ssm_a = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A_NOSCAN, i), {hparams.ssm_dt_rank});
-            layer.ssm_beta_alpha = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), {n_embd, ba_dim});
+            layer.ssm_beta_alpha = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), {n_embd, ba_dim});
             layer.ssm_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {head_v_dim});
-            layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {value_dim, n_embd});
+            layer.ssm_out = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_OUT, "weight", i), {value_dim, n_embd});
         }

         auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
@@ -1378,7 +1378,7 @@ bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
         }

         // Shared expert path (optional per-layer)
-        layer.ffn_gate_inp_shexp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+        layer.ffn_gate_inp_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
         if (layer.ffn_gate_inp_shexp != nullptr) {
             layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
             layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -3540,6 +3540,10 @@ bool create_tensors_helper::create_tensors() {
             }
             if (layer.wo && layer.wq && layer.wk && layer.wv) {
                 auto granularity_kq = hparams.n_embd_head_k * gqa_ratio;
+                int wq_ne1 = layer.wq->ne[1];
+                if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
+                    granularity_kq *= 2; wq_ne1 /= 2;
+                }
                 auto granularity_vo = hparams.n_embd_head_v * gqa_ratio;
                 if (ggml_is_quantized(layer.wo->type)) {
                     auto tt = ggml_internal_get_type_traits(layer.wo->type);
@@ -3553,7 +3557,7 @@ bool create_tensors_helper::create_tensors() {
                 LLAMA_LOG_DEBUG("    split_kq:"); for ([[maybe_unused]] auto s : split_kq) LLAMA_LOG_DEBUG(" %d", s);
                 LLAMA_LOG_DEBUG("\n");
-                if (layer.attn_q_norm && layer.attn_q_norm->ne[0] == layer.wq->ne[1]) {
+                if (layer.attn_q_norm && layer.attn_q_norm->ne[0] == wq_ne1) {
                     // If RMS norm is not applied per attention head, as it usually is, but is applied to the
                     // entire Q tensor (e.g., MiniMax-2), we need to have a copy of the entire wq and attn_q_norm tensors
                     // on each participating GPU.
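Note on the granularity change above: llm_build_mul_mat_qkv_gated() in llama-build-context.cpp views the fused Q projection as interleaved blocks of n_embd_head_k query rows followed by n_embd_head_k gate rows per head, so a tensor split must never separate a head's query rows from its gate rows. That is why granularity_kq is doubled and the effective Q width compared against attn_q_norm is wq->ne[1]/2. A minimal standalone sketch of that invariant (kq_split_granularity is a hypothetical name, not part of the patch):

#include <cassert>
#include <cstdint>

// Hypothetical illustration of the split unit used above: with a gated Q
// projection the smallest row block one GPU may own is a whole (query, gate)
// pair per KQ head group, i.e. twice the usual n_embd_head_k*gqa_ratio rows.
static int64_t kq_split_granularity(int64_t n_embd_head_k, int64_t gqa_ratio, bool gated_q) {
    int64_t unit = n_embd_head_k * gqa_ratio;
    return gated_q ? 2*unit : unit;
}

int main() {
    // Example: head size 128, 4 Q heads per KV head, gated Q projection.
    assert(kq_split_granularity(128, 4, true) == 1024); // split offsets land on 1024-row boundaries
    const int64_t wq_ne1 = 16 * 2 * 128;                // 16 heads, query + gate rows each
    assert(wq_ne1 / 2 == 16 * 128);                     // effective width checked against attn_q_norm->ne[0]
    return 0;
}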
@@ -3593,7 +3597,11 @@ bool create_tensors_helper::create_tensors() {
                 LLAMA_LOG_DEBUG("\n");
                 prepare_split_tensors(1, ctx_split, layer.wqkv_gate, layer.split_wqkv_gate, wqkv_gate_split, mem_used);
             }
-            for (auto & s : split_kq) s /= gqa_ratio;
+            if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
+                for (auto & s : split_kq) s /= 2*gqa_ratio;
+            } else {
+                for (auto & s : split_kq) s /= gqa_ratio;
+            }
             for (auto & s : split_vo) s /= gqa_ratio;
             if (layer.attn_k_norm && layer.attn_k_norm->ne[0] == layer.wk->ne[1]) {
                 // If RMS norm is not applied per attention head, as it usually is, but is applied to the
@@ -3717,6 +3725,9 @@ bool create_tensors_helper::create_tensors() {
             prepare_split_tensors(0, ctx_split, layer.ffn_down_shexp, layer.split_ffn_down_shexp, split, mem_used);
             prepare_split_tensors(1, ctx_split, layer.ffn_up_shexp, layer.split_ffn_up_shexp, split, mem_used);
             prepare_split_tensors(1, ctx_split, layer.ffn_gate_shexp, layer.split_ffn_gate_shexp, split, mem_used);
+            if (layer.ffn_gate_inp_shexp) {
+                prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp_shexp, layer.split_ffn_gate_inp_shexp, split, mem_used);
+            }
         }
     }
diff --git a/src/llama-model.h b/src/llama-model.h
index 0d27d937..e256a8a7 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -267,6 +267,7 @@ struct llama_layer {
     llama_split_tensor split_ffn_up_shexp;
     llama_split_tensor split_ffn_gate_shexp;
     llama_split_tensor split_ffn_down_shexp;
+    llama_split_tensor split_ffn_gate_inp_shexp;

     llama_split_tensor split_ffn_gate_inp_b;
     llama_split_tensor split_ffn_gate_exps_b;
@@ -378,6 +379,7 @@ struct llama_model {
     std::vector<std::string> rpc_servers;
     std::vector<int> devices;
+    std::vector<int> default_layer_device;

     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
diff --git a/src/llama.cpp b/src/llama.cpp
index 477a505f..d9a5a709 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -572,7 +572,7 @@ bool llama_context::update_cache_copies() {
     const int n_layer = model.mtp ? model.hparams.n_layer : model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
     auto layer_has_attention_kv = [&](int il) {
-        return !((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && model.hparams.is_recurrent(il));
+        return !model.hparams.is_recurrent(il);
     };
     if ((int)kv_self.k_l.size() != n_layer) return false;
     if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
@@ -661,11 +661,8 @@ llama_context::~llama_context() {
 // kv cache helpers
 //

-static inline bool llama_qwen3next_is_recurrent_layer(
-    const llama_model & model,
-    const llama_hparams & hparams,
-    uint32_t il) {
-    return (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il);
+static inline bool llama_qwen3next_is_recurrent_layer(const llama_hparams & hparams, uint32_t il) {
+    return hparams.is_recurrent(il);
 }

 static inline uint32_t llama_kv_v_row_embd(
@@ -767,7 +764,7 @@ static bool llama_kv_cache_init(
     std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
     if (offload) {
         for (int64_t i = 0; i < n_layer; ++i) {
-            if (split_cache) {
+            if (split_cache && !hparams.is_recurrent(i)) {
                 buft_layer_count[model.buft_layer[i].buft_matrix]++;
             } else {
                 buft_layer_count[model.buft_layer[i].buft]++;
             }
@@ -827,7 +824,7 @@ static bool llama_kv_cache_init(
         needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
     }
     if (needs_v_cache) cache.v_l.reserve(n_layer);
-    cache.s_l.reserve(n_layer);
+    cache.s_l.resize(n_layer, nullptr);

     std::vector mem_split(model.splits.size(), 0);
@@ -839,12 +836,12 @@ static bool llama_kv_cache_init(
     int n_mla = 0;
     for (int i = 0; i < (int) n_layer; i++) {
-        const bool qnext_recurrent = llama_qwen3next_is_recurrent_layer(model, hparams, i);
+        const bool qnext_recurrent = llama_qwen3next_is_recurrent_layer(hparams, i);
         const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
         const uint32_t n_head_kv = hparams.n_head_kv(i);
         const uint32_t n_embd_head_k= hparams.n_embd_head_k;
-        struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
+        struct ggml_context * ctx = split_cache && !qnext_recurrent ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
         ggml_tensor * k = nullptr;
         ggml_tensor * v = nullptr;
         ggml_tensor * s = nullptr;
@@ -871,16 +868,21 @@ static bool llama_kv_cache_init(
             n_mla++;
         } else {
+            if (qnext_recurrent) {
+                s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd_v_s(), qnext_state_slots);
+                auto s_name = std::string{"cache_s_l"} + std::to_string(i);
+                ggml_set_name(s, s_name.c_str());
+                cache.s_l[i] = s;
+                cache.k_l.push_back(nullptr);
+                cache.v_l.push_back(nullptr);
+                continue;
+            }
             bool split_cache_i = split_cache;
             auto K = model.layers[i].wk;
             auto V = model.layers[i].wv;
             if (split_cache && (!K || !V || !K->extra || !V->extra)) {
                 ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
                 split_cache_i = false;
-            }
-            if (qnext_recurrent) {
-                s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd_v_s(), qnext_state_slots);
-                split_cache_i = false;
             } else {
                 int n_embd_head_v = hparams.n_embd_head_v;
                 k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
@@ -934,14 +936,9 @@ static bool llama_kv_cache_init(
                     v->extra = (void *)&split_v_l.ggml;
                 }
             }
-            if (s) {
-                auto s_name = std::string{"cache_s_l"} + std::to_string(i);
-                ggml_set_name(s, s_name.c_str());
-            }
             cache.k_l.push_back(k);
             cache.v_l.push_back(v);
         }
-        cache.s_l.push_back(s);
     }
     if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
         LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
@@ -1926,6 +1923,7 @@ static bool is_model_split_supported(const llama_model & model) {
         LLM_ARCH_MINIMAX_M2,
         LLM_ARCH_SEED_OSS,
         LLM_ARCH_STEP35,
+        LLM_ARCH_QWEN3NEXT,
     };
     auto it = k_supported.find(model.arch);
     return it != k_supported.end();
@@ -2015,18 +2013,30 @@ static bool llm_load_tensors(
     }
     int device_count = model.splits.size();
-    // assign the repeating layers to the devices according to the splits
+    model.default_layer_device = std::vector<int>(hparams.n_layer+1, device_count-1);
     int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
-    if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
-
+    if (device_count > 1) {
         for (int i = i_gpu_start; i < n_layer; ++i) {
             int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
-            model.buft_layer[i] = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
+            model.default_layer_device[i] = model.devices[layer_gpu];
+        }
+        if (n_gpu_layers > n_layer) {
+            int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - model.splits.begin();
+            model.default_layer_device[n_layer] = model.devices[layer_gpu];
+        }
+    }
+    // assign the repeating layers to the devices according to the splits
+    if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
+        for (int i = i_gpu_start; i < n_layer; ++i) {
+            model.buft_layer[i] = llama_default_buffer_type_offload(model, model.default_layer_device[i]);
+            //int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
+            //model.buft_layer[i] = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
         }
         // assign the output layer
         if (n_gpu_layers > n_layer) {
-            int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - model.splits.begin();
-            model.buft_output = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
+            //int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - model.splits.begin();
+            //model.buft_output = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
+            model.buft_output = llama_default_buffer_type_offload(model, model.default_layer_device[n_layer]);
         } else {
             model.buft_output = llama_default_buffer_type_cpu(true);
         }
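For reference, the layer-to-device assignment introduced in llm_load_tensors() reduces to an upper_bound search over the cumulative per-device fractions stored in model.splits. A minimal standalone sketch under that assumption (layer_device is a hypothetical name, not part of the patch):

#include <algorithm>
#include <vector>

// Hypothetical illustration of default_layer_device: each repeating layer
// maps to the first device whose cumulative split fraction exceeds the
// layer's normalized position among the offloaded layers.
static int layer_device(const std::vector<float> & splits, int il, int i_gpu_start, int act_gpu_layers) {
    const float pos = float(il - i_gpu_start) / act_gpu_layers;
    return int(std::upper_bound(splits.begin(), splits.end(), pos) - splits.begin());
}

int main() {
    const std::vector<float> splits = {0.5f, 1.0f}; // two devices, even split
    // With 48 offloaded layers starting at layer 0, layers 0..23 land on
    // device 0 and layers 24..47 on device 1.
    bool ok = layer_device(splits, 23, 0, 48) == 0 && layer_device(splits, 24, 0, 48) == 1;
    return ok ? 0 : 1;
}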