diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 2de13259..4615605c 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -1241,7 +1241,13 @@ llm_expert_gating_func_type gating_op,
     GGML_ASSERT(split_gate_inp && split_gate_inp->n_device == split_up_exps->n_device);
     auto split_exp_probs_b = exp_probs_b ? (ggml_split_tensor_t *)exp_probs_b->extra : nullptr;
     GGML_ASSERT(!split_exp_probs_b || split_exp_probs_b->n_device == split_up_exps->n_device);
+
+    auto split_gate_inp_b  = gate_inp_b  ? (ggml_split_tensor_t *)gate_inp_b->extra  : nullptr;
+    auto split_exps_down_b = down_exps_b ? (ggml_split_tensor_t *)down_exps_b->extra : nullptr;
+    auto split_exps_gate_b = gate_exps_b ? (ggml_split_tensor_t *)gate_exps_b->extra : nullptr;
+    auto split_exps_up_b   = up_exps_b   ? (ggml_split_tensor_t *)up_exps_b->extra   : nullptr;
     int last_id = -1;
+    bool down_bias_added = false;
     for (int id = 0; id < split_up_exps->n_device; ++id) {
         GGML_ASSERT((split_up_exps->splits[id] && split_gate_exps->splits[id] && split_down_exps->splits[id]) ||
                     (!split_up_exps->splits[id] && !split_gate_exps->splits[id] && !split_down_exps->splits[id]));
@@ -1257,11 +1263,15 @@ llm_expert_gating_func_type gating_op,
             if (cur->type != GGML_TYPE_F32) {
                 cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
             }
+            GGML_ASSERT(!split_gate_inp_b  || split_gate_inp_b->splits[id]);
+            GGML_ASSERT(!split_exps_down_b || split_exps_down_b->splits[id]);
+            GGML_ASSERT(!split_exps_gate_b || split_exps_gate_b->splits[id]);
+            GGML_ASSERT(!split_exps_up_b   || split_exps_up_b->splits[id]);
             auto routed_out = llm_build_moe_ffn(ctx, lctx, cur,
-                    split_gate_inp->splits[id], gate_inp_b,
-                    split_up_exps->splits[id], up_exps_b,
-                    split_gate_exps->splits[id], gate_exps_b,
-                    split_down_exps->splits[id], down_exps_b,
+                    split_gate_inp->splits[id], split_gate_inp_b ? split_gate_inp_b->splits[id] : nullptr,
+                    split_up_exps->splits[id], split_exps_up_b ? split_exps_up_b->splits[id] : nullptr,
+                    split_gate_exps->splits[id], split_exps_gate_b ? split_exps_gate_b->splits[id] : nullptr,
+                    split_down_exps->splits[id], !down_bias_added && split_exps_down_b ? split_exps_down_b->splits[id] : nullptr,
                     split_exp_probs_b ? split_exp_probs_b->splits[id] : nullptr,
                     n_expert, n_expert_used, type_op, norm_w, scale_w, w_scale,
@@ -1275,7 +1285,7 @@ llm_expert_gating_func_type gating_op,
                 auto shared_out = llm_build_ffn(ctx, lctx, nullptr, cur,
                         split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr,
                         split_gate_shexp->splits[id], split_gate_b_shexp ? split_gate_b_shexp->splits[id] : nullptr, nullptr,
-                        split_down_shexp->splits[id], split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
+                        split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
                         nullptr,
                         type_op_shexp, LLM_FFN_PAR, cb, il);
                 cb(shared_out, "ffn_shexp_out", il_cb);
@@ -1291,6 +1301,7 @@ llm_expert_gating_func_type gating_op,
             ggml_build_forward_expand(graph, cur);
             results[id] = cur;
             last_id = id;
+            down_bias_added = true;
         }
         GGML_ASSERT(last_id >= 0);
         if (add_input) {
@@ -8449,95 +8460,43 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
     inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

-    // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();

     struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
     struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();

-    //const int64_t n_embd_head = hparams.n_embd_head_v;
-    const float kq_scale = 1.0f / sqrtf(float(n_rot)); //float(n_embd_head));
-
-    //auto * inp_attn = build_attn_inp_kv_unified_iswa();
+    const float kq_scale = 1.0f / sqrtf(float(n_rot));

     const int sliding_window_pattern = 2;

-    auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
-        ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
-
     for (int il = 0; il < n_layer; ++il) {
         const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);

-        ggml_tensor * inpSA = inpL;
-
         struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;

-        // norm
-        cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
-        cb(cur, "attn_norm", il);
-
-        // self-attention
-        {
-            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
-                    model.layers[il].wqkv, model.layers[il].bqkv,
-                    model.layers[il].wqk, model.layers[il].bqk,
-                    model.layers[il].wq, model.layers[il].bq,
-                    model.layers[il].wk, model.layers[il].bk,
-                    model.layers[il].wv, model.layers[il].bv,
-                    nullptr, nullptr, 0.0f, il);
-
-            if (rope_cache) {
-                Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
-                Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
-            } else {
-                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-            }
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-
-            cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
-                    is_sliding ? hparams.n_swa : 0);
-
-            cb(cur, "attn_out", il);
-        }
+        cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l,
+                model.layers[il].attn_sinks, nullptr, kq_scale, 0.0f, is_sliding ? hparams.n_swa : 0, il, true, false, true);

         if (il == n_layer - 1) {
             // skip computing output for unused tokens
             ggml_tensor * inp_out_ids = build_inp_out_ids();
             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
-            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }

-        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(ffn_inp, "ffn_inp", il);
-
-        cur = ffn_inp;
-        cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
-        ggml_build_forward_expand(gf, cur);
-        cb(cur, "attn_post_norm", il);
-
         bool use_dup_bias = cur->ne[1] < 32 && model.layers[il].ffn_up_exps_b_dup && model.layers[il].ffn_gate_exps_b_dup && model.layers[il].ffn_down_exps_b_dup;

         // MoE branch
-        cur = llm_build_moe_ffn(ctx0, lctx, cur,
+        cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
                 model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
                 model.layers[il].ffn_up_exps, use_dup_bias ? model.layers[il].ffn_up_exps_b_dup : model.layers[il].ffn_up_exps_b,
                 model.layers[il].ffn_gate_exps, use_dup_bias ? model.layers[il].ffn_gate_exps_b_dup : model.layers[il].ffn_gate_exps_b,
                 model.layers[il].ffn_down_exps, use_dup_bias ? model.layers[il].ffn_down_exps_b_dup : model.layers[il].ffn_down_exps_b,
                 nullptr,
+                nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, // no shared experts
                 n_expert, n_expert_used,
-                LLM_FFN_SWIGLU_OAI_MOE, false,
-                false, 0.0,
+                LLM_FFN_SWIGLU_OAI_MOE, false, false, 0.0f,
                 LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
-                cb, il, gf);
-        cb(cur, "ffn_moe_out", il);
-
-        cur = ggml_add(ctx0, cur, ffn_inp);
+                LLM_FFN_SWIGLU_OAI_MOE, cb, il, gf, true);

         cur = lctx.cvec.apply_to(ctx0, cur, il);
         cb(cur, "l_out", il);
@@ -8546,18 +8505,13 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
         inpL = cur;
     }

-    cur = inpL;
-
-    cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
-    cb(cur, "result_norm", -1);
-
-    cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
+    cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb);
     cb(cur, "result_output", -1);

     ggml_build_forward_expand(gf, cur);

     return gf;
+
 }

 ggml_cgraph * llm_build_context::build_bailingmoe2() {
@@ -9323,6 +9277,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
     }
     std::vector<ggml_tensor *> attn(wq->n_device, nullptr);
     int id_last = -1;
+    bool output_bias_added = false;
     for (int id = 0; id < wq->n_device; ++id) {
         int il_cb = 1000*(id+1) + il;
         auto split_wq = wq->splits[id];
@@ -9477,9 +9432,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                 ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
             }
             cb(cur, "kqv_wo", il_cb);
-            if (bo) {
+            if (!output_bias_added && bo) {
                 cur = ggml_add(ctx0, cur, bo->splits[id]);
                 cb(cur, "kqv_wo_biased", il_cb);
+                output_bias_added = true;
             }
             if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
                 cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
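The `down_bias_added` and `output_bias_added` flags introduced above both enforce the same invariant: when a projection weight is split across devices and the per-device partial results are later reduced by summation, the bias must enter the reduction exactly once, or the final tensor ends up with `n_device` copies of it. A minimal standalone sketch of that rule, using hypothetical scaffolding (`reduce_partials`, plain float vectors) in place of the real ggml tensors and split descriptors:

#include <cstddef>
#include <vector>

// One partial result per device plus the shared bias; plain float vectors
// stand in for ggml tensors here.
std::vector<float> reduce_partials(const std::vector<std::vector<float>> & partials,
                                   const std::vector<float> & bias) {
    std::vector<float> out(bias.size(), 0.0f);
    bool bias_added = false;                    // mirrors down_bias_added above
    for (const auto & p : partials) {
        for (size_t i = 0; i < out.size(); ++i) {
            out[i] += p[i];                     // sum the per-device partials
        }
        if (!bias_added) {
            for (size_t i = 0; i < out.size(); ++i) {
                out[i] += bias[i];              // bias enters the sum exactly once
            }
            bias_added = true;
        }
    }
    return out;
}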
diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index 344976da..bc7e70e7 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -2574,32 +2574,31 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
     // output
     model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-    model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+    model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);

     for (int i = 0; i < n_layer; ++i) {
-        ggml_context * ctx_layer = ctx_for_layer(i);
         ggml_context * ctx_split = ctx_for_layer_split(i);

         auto & layer = model.layers[i];

-        layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-        layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+        layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+        layer.ffn_norm  = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

         use_mmap_buffer &= !merge_qkv(tn, i, 2);

         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
-        layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+        layer.bo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);

-        layer.attn_sinks = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
+        layer.attn_sinks = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);

         ggml_context *ctx_ffn_gate, *ctx_ffn_up, *ctx_ffn_down;
-        layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {  n_embd, n_expert}, 0);
+        layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {  n_embd, n_expert}, 0);
         layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_gate);
         layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0, &ctx_ffn_down);
         layer.ffn_up_exps   = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_up);

         // bias
         ggml_context *ctx_ffn_gate_b, *ctx_ffn_up_b, *ctx_ffn_down_b;
-        layer.ffn_gate_inp_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
+        layer.ffn_gate_inp_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
         layer.ffn_gate_exps_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_gate_b);
         layer.ffn_down_exps_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), {  n_embd, n_expert}, 0, &ctx_ffn_down_b);
         layer.ffn_up_exps_b   = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_up_b);
@@ -3149,6 +3148,18 @@ bool create_tensors_helper::create_tensors() {
                 prepare_split_tensors(0, ctx_split, layer.ffn_down_exps, layer.split_ffn_down_exps, split, mem_used);
                 prepare_split_tensors(1, ctx_split, layer.ffn_up_exps, layer.split_ffn_up_exps, split, mem_used);
                 prepare_split_tensors(1, ctx_split, layer.ffn_gate_exps, layer.split_ffn_gate_exps, split, mem_used);
+                if (layer.ffn_down_exps_b) {
+                    prepare_split_tensors(-1, ctx_split, layer.ffn_down_exps_b, layer.split_ffn_down_exps_b, split, mem_used);
+                }
+                if (layer.ffn_up_exps_b) {
+                    prepare_split_tensors( 0, ctx_split, layer.ffn_up_exps_b, layer.split_ffn_up_exps_b, split, mem_used);
+                }
+                if (layer.ffn_gate_exps_b) {
+                    prepare_split_tensors( 0, ctx_split, layer.ffn_gate_exps_b, layer.split_ffn_gate_exps_b, split, mem_used);
+                }
+                if (layer.ffn_gate_inp_b) {
+                    prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp_b, layer.split_ffn_gate_inp_b, split, mem_used);
+                }
             }
         }
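The first argument of `prepare_split_tensors()` appears to select the sharding mode: 0 or 1 slices the tensor along that dimension, while -1 replicates the whole tensor on every device. That reading is consistent with the calls above: the up/gate expert biases follow the row split of their weights, while the down-projection bias and the router bias stay whole, because the matching matmuls are split along their contraction dimension and each device contributes only a partial sum of the full-sized output (which is also why those biases may be applied only once, per the flags in llama-build-context.cpp). A hedged sketch of that convention, with a hypothetical `make_shard` helper rather than the fork's actual API:

#include <algorithm>
#include <cstdint>

struct Shard {
    int64_t first;      // first element along the split dimension
    int64_t count;      // number of elements owned by this device
    bool    replicated; // true => full copy on every device (dim == -1)
};

// Hypothetical helper mirroring the dim argument of prepare_split_tensors().
Shard make_shard(int split_dim, int64_t ne, int id, int n_device) {
    if (split_dim < 0) {
        // e.g. ffn_down_exps_b, ffn_gate_inp_b: the matching matmul is split
        // along its contraction dimension, so the bias keeps its full size.
        return {0, ne, true};
    }
    // e.g. ffn_up_exps_b, ffn_gate_exps_b: the bias follows the row split of
    // its weight, so each device keeps only a contiguous slice.
    const int64_t per_dev = (ne + n_device - 1) / n_device;
    const int64_t first   = std::min(per_dev * id, ne);
    return {first, std::min(per_dev, ne - first), false};
}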
diff --git a/src/llama.cpp b/src/llama.cpp
index d637f712..ce619ee2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1734,6 +1734,7 @@ static bool is_model_split_supported(const llama_model & model) {
         LLM_ARCH_QWEN3,
         LLM_ARCH_QWEN3VL,
         LLM_ARCH_HUNYUAN_MOE,
+        LLM_ARCH_OPENAI_MOE,
     };
     auto it = k_supported.find(model.arch);
     return it != k_supported.end();
@@ -4432,6 +4433,15 @@ struct llama_context * llama_new_context_with_model(
         //LLAMA_LOG_WARN("=====================================================================\n");
         cparams.mla_attn = 0;
     }
+    if (model->arch == LLM_ARCH_OPENAI_MOE && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
+        if (cparams.split_mode_f16) {
+            LLAMA_LOG_WARN("=====================================================================\n");
+            LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
+            LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
+            LLAMA_LOG_WARN("=====================================================================\n");
+            cparams.split_mode_f16 = false;
+        }
+    }

     LLAMA_LOG_INFO("%s: n_ctx      = %u\n", __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_batch    = %u\n", __func__, cparams.n_batch);
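For completeness, a hedged usage sketch that exercises the newly whitelisted path: loading GPT-OSS with graph split mode and letting the new guard clamp `split_mode_f16`. The model path is hypothetical, and the calls follow this fork's `llama.h` as I read it (`LLAMA_SPLIT_MODE_GRAPH` is fork-specific):

#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 999;                    // offload all layers
    mparams.split_mode   = LLAMA_SPLIT_MODE_GRAPH; // fork-specific split mode

    // hypothetical model path
    llama_model * model = llama_load_model_from_file("gpt-oss-20b.gguf", mparams);
    if (!model) {
        return 1;
    }

    // For LLM_ARCH_OPENAI_MOE, llama_new_context_with_model() now resets
    // split_mode_f16 to false and prints the warning added above.
    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... run inference ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}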