Use standard attention for Ministral3 (#1032)

This required adding the "temperature scaling" step to the standard attention
implementation.

In exchange, split mode "graph" is now automatically supported for this architecture.
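
The temperature scaling itself is just an optional element-wise multiplication of the
RoPE'd query by a per-position scale tensor (inp_attn_scale) before the KQ product, as
the added lines in build_std_attention below show. A minimal sketch of that step, for
context only; the helper name apply_attn_temperature is hypothetical, ggml_mul is the
actual call used in the diff:

    #include "ggml.h"

    // Sketch of the new optional inp_attn_scale argument of build_std_attention:
    // when a per-token attention-temperature tensor is provided, the (already
    // RoPE'd) query is scaled element-wise before the KQ product; when it is
    // nullptr, the query passes through unchanged.
    static ggml_tensor * apply_attn_temperature(
            ggml_context * ctx,
            ggml_tensor  * Qcur,            // query tensor, after RoPE
            ggml_tensor  * inp_attn_scale)  // per-position scale, or nullptr
    {
        if (inp_attn_scale) {
            Qcur = ggml_mul(ctx, Qcur, inp_attn_scale); // broadcast multiply
        }
        return Qcur;
    }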

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author:       Kawrakow
Date:         2025-12-03 13:43:31 +01:00
Committed by: GitHub
Parent:       7fbe8d3ac2
Commit:       90f36eb517
3 changed files with 18 additions and 50 deletions

@@ -1721,7 +1721,7 @@ ggml_cgraph * llm_build_context::build_llama() {
         // self-attention
         if (use_rope) {
-            cur = build_std_attention(gf, inpL, inp_pos, nullptr, this_KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il);
+            cur = build_std_attention(gf, inpL, inp_pos, nullptr, this_KQ_mask, nullptr, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il);
         }
         else {
@@ -1905,57 +1905,15 @@ ggml_cgraph * llm_build_context::build_mistral3() {
     ggml_tensor * inp_out_ids = build_inp_out_ids();
-    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-    //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f;
-    // ====================================
-    //auto * inp_attn = build_attn_inp_kv();
+    //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f;
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
         auto rope_factors = build_rope_factors(il);
         // self-attention
-        if (!inp_attn_scale) {
-            cur = build_std_attention(gf, inpL, inp_pos, rope_factors, KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, 0, il);
-        }
-        else {
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
-                    model.layers[il].wqkv, model.layers[il].bqkv,
-                    model.layers[il].wqk, model.layers[il].bqk,
-                    model.layers[il].wq, model.layers[il].bq,
-                    model.layers[il].wk, model.layers[il].bk,
-                    model.layers[il].wv, model.layers[il].bv,
-                    nullptr, nullptr, hparams.f_attention_scale, il);
-            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
-            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
-            if (inp_attn_scale) {
-                Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
-                cb(Qcur, "Qcur_temp_scaled", il);
-            }
-            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr, 0);
-        }
+        cur = build_std_attention(gf, inpL, inp_pos, rope_factors, KQ_mask, nullptr, inp_attn_scale, kq_scale, hparams.f_attention_scale, 0, il);
         if (il == n_layer - 1 && inp_out_ids) {
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -3947,7 +3905,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
         //cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
         //cb(cur, "attn_norm", il);
-        cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il);
+        cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il);
         if (il == n_layer - 1) {
             // skip computing output for unused tokens
@@ -6826,7 +6784,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
         // self-attention
         if (rope_cache == nullptr) {
-            cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, kq_scale, 0.0f, 0, il);
+            cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il);
         } else {
             // Pre-attention norm
             cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
@@ -9329,7 +9287,7 @@ ggml_cgraph * llm_build_context::llama_build_graph(
 }
 ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
-        ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il) {
+        ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il) {
     if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
         model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
         if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
@@ -9400,6 +9358,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                 ext_factor, attn_factor, beta_fast, beta_slow);
         cb(Qcur, "Qcur", il_cb);
         cb(Kcur, "Kcur", il_cb);
+        if (inp_attn_scale) {
+            Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
+            cb(Qcur, "Qcur_temp_scaled", il_cb);
+        }
         ggml_build_forward_expand(gf, Qcur);
         ggml_build_forward_expand(gf, Kcur);
         ggml_build_forward_expand(gf, Vcur);
@@ -9528,6 +9490,11 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
     cb(Qcur, "Qcur", il);
     cb(Kcur, "Kcur", il);
+    if (inp_attn_scale) {
+        Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
+        cb(Qcur, "Qcur_temp_scaled", il);
+    }
     cur = llm_build_kv(ctx0, lctx, kv_self, gf,
             model.layers[il].wo, model.layers[il].bo,
             Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);

@@ -407,6 +407,6 @@ llm_expert_gating_func_type gating_op,
     static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case);
     ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
-            ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il);
+            ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il);
 };

@@ -1728,6 +1728,7 @@ static bool is_model_split_supported(const llama_model & model) {
         LLM_ARCH_LLAMA,
         LLM_ARCH_QWEN3MOE,
         LLM_ARCH_GLM4_MOE,
+        LLM_ARCH_MISTRAL3,
     };
     auto it = k_supported.find(model.arch);
     return it != k_supported.end();