Fix LLaMA-4 attention

Iwan Kawrakow
2025-04-24 13:59:19 +03:00
parent 9dac3edf2f
commit 6250937c49


@@ -9974,7 +9974,12 @@ struct llm_build_context {
         }

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        //bool is_swa = hparams.n_swa > 0 && h_params.n_swa_pattern > 0 ?
+        ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        ggml_tensor * KQ_mask_swa = nullptr;
+        if (hparams.n_swa > 0 && hparams.n_swa_pattern > 0) {
+            KQ_mask_swa = build_inp_KQ_mask_swa();
+        }

         //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f;
@@ -9982,6 +9987,8 @@ struct llm_build_context {
             struct ggml_tensor * inpSA = inpL;

             bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true;
+            auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ?
+                                KQ_mask_swa : KQ_mask;

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -10046,7 +10053,7 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                        Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }

             if (il == n_layer - 1) {
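
For context, the selection rule added above means that, within each group of n_swa_pattern consecutive layers, only the last layer of the group keeps the full-context KQ_mask, while the other layers use the sliding-window mask KQ_mask_swa. Below is a minimal standalone sketch of that mapping; it is not llama.cpp code, and the layer count and window size are made-up values chosen only to illustrate the rule:

    #include <cstdio>

    int main() {
        // Hypothetical values chosen only to illustrate the rule above.
        const int n_layer       = 8;    // number of transformer layers
        const int n_swa         = 8192; // sliding window size; > 0 means SWA is enabled
        const int n_swa_pattern = 4;    // every n_swa_pattern-th layer keeps full attention

        for (int il = 0; il < n_layer; ++il) {
            // Same condition as the this_KQ_mask selection in the diff above.
            const bool use_swa = n_swa > 0 && n_swa_pattern > 0 &&
                                 il % n_swa_pattern < (n_swa_pattern - 1);
            std::printf("layer %2d -> %s mask\n", il, use_swa ? "sliding-window" : "full-context");
        }
        return 0;
    }

With n_swa_pattern = 4 this prints a sliding-window mask for layers 0, 1, 2, 4, 5, 6 and a full-context mask for layers 3 and 7, which is the interleaving the commit wires into llm_build_kv via this_KQ_mask.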