diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 30b545d3..ce868514 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -8139,7 +8139,8 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k stride_q /= sizeof(float); // q stride as float #ifdef __AVX512BF16__ - if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) { + if (type_k == GGML_TYPE_BF16 || type_v == GGML_TYPE_BF16) { + if (type_k != GGML_TYPE_BF16 || type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 with other types switch (D) { case 64: iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;