Try using fp32 for FlashMLA

This commit is contained in:
Iwan Kawrakow
2025-03-10 19:07:53 +02:00
parent a48e163247
commit e0eebfd8ad

View File

@@ -13677,9 +13677,9 @@ struct llm_build_context {
ggml_build_forward_expand(gf, q);
kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
if (q->ne[1] <= 8) {
//if (q->ne[1] <= 8) {
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
}
//}
cb(kqv, "kqv", il);
cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);