diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 277da085..1783b1a7 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1721,7 +1721,7 @@ ggml_cgraph * llm_build_context::build_llama() { // self-attention if (use_rope) { - cur = build_std_attention(gf, inpL, inp_pos, nullptr, this_KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il); + cur = build_std_attention(gf, inpL, inp_pos, nullptr, this_KQ_mask, nullptr, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il); } else { @@ -1905,57 +1905,15 @@ ggml_cgraph * llm_build_context::build_mistral3() { ggml_tensor * inp_out_ids = build_inp_out_ids(); - const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f; - - // ==================================== - - //auto * inp_attn = build_attn_inp_kv(); + //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f; for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; auto rope_factors = build_rope_factors(il); - // self-attention - if (!inp_attn_scale) { - cur = build_std_attention(gf, inpL, inp_pos, rope_factors, KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, 0, il); - } - else { - - // norm - cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm", il); - - auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, - model.layers[il].wqkv, model.layers[il].bqkv, - model.layers[il].wqk, model.layers[il].bqk, - model.layers[il].wq, model.layers[il].bq, - model.layers[il].wk, model.layers[il].bk, - model.layers[il].wv, model.layers[il].bv, - nullptr, nullptr, hparams.f_attention_scale, il); - - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - if (inp_attn_scale) { - Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); - cb(Qcur, "Qcur_temp_scaled", il); - } - - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr, 0); - } + cur = build_std_attention(gf, inpL, inp_pos, rope_factors, KQ_mask, nullptr, inp_attn_scale, kq_scale, hparams.f_attention_scale, 0, il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -3947,7 +3905,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() { //cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); //cb(cur, "attn_norm", il); - cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il); + cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -6826,7 +6784,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() { // self-attention if (rope_cache == nullptr) { - cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, kq_scale, 0.0f, 0, il); + cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il); } else { // Pre-attention norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -9329,7 +9287,7 @@ ggml_cgraph * llm_build_context::llama_build_graph( } ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in, - ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il) { + ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il) { if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn && model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) { if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) { @@ -9400,6 +9358,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il_cb); cb(Kcur, "Kcur", il_cb); + if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_temp_scaled", il_cb); + } ggml_build_forward_expand(gf, Qcur); ggml_build_forward_expand(gf, Kcur); ggml_build_forward_expand(gf, Vcur); @@ -9528,6 +9490,11 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); + if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_temp_scaled", il); + } + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 41597248..cbf12817 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -407,6 +407,6 @@ llm_expert_gating_func_type gating_op, static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case); ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors, - ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il); + ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il); }; diff --git a/src/llama.cpp b/src/llama.cpp index 1a0480bb..2124b180 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1728,6 +1728,7 @@ static bool is_model_split_supported(const llama_model & model) { LLM_ARCH_LLAMA, LLM_ARCH_QWEN3MOE, LLM_ARCH_GLM4_MOE, + LLM_ARCH_MISTRAL3, }; auto it = k_supported.find(model.arch); return it != k_supported.end();