diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 07761628..c5168568 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17237,25 +17237,73 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv) { // v*softmax(scale*(k*q)) + if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 + auto type_k = ggml_type(int_type_k); auto type_v = ggml_type(int_type_v); if (type_k == GGML_TYPE_Q8_0 && type_v == GGML_TYPE_Q8_0 && Dk == 576 && Dv == 512) { - //printf("Using DeepSeek FA with nq1 = %d, nk1 = %d\n", (int)nq1, (int)nk1); - HelperQ80<576, 32> kh((const char *)k, stride_k); - HelperQ80<512, 32> vh((const char *)v, stride_v); - if (nq1 % 8 == 0) { - FlashAttn<576, 512, 8, 32> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - } else { - FlashAttn<576, 512, 1, 32> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + // This is a DeepSeek model with MLA. In that case we only have one cache (and k and v are different views of the cache), + // so type_k must be the same as type_v + GGML_ASSERT(type_k == type_v); + switch (type_k) { + case GGML_TYPE_Q8_0: { + HelperQ80<576, 32> kh((const char *)k, stride_k); + HelperQ80<512, 32> vh((const char *)v, stride_v); + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { + FlashAttn<576, 512, 1, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + return true; + } break; + // Something is wrong with Q8_KV in this case. + //case GGML_TYPE_Q8_KV: { + // HelperQ8KV<576, 32> kh((const char *)k, stride_k); + // HelperQ8KV<512, 32> vh((const char *)v, stride_v); + // if (nq1 % 8 == 0) { + // FlashAttn<576, 512, 8, 32> fa(scale, softcap); + // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + // } else { + // FlashAttn<576, 512, 1, 32> fa(scale, softcap); + // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + // } + // return true; + //} break; + case GGML_TYPE_F16: { + HelperF16<576, 32> kh((const char *)k, stride_k); + HelperF16<512, 32> vh((const char *)v, stride_v); + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { + FlashAttn<576, 512, 1, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + return true; + } break; +#ifdef __AVX512BF16__ + case GGML_TYPE_BF16: { + HelperBF16<576, 32> kh((const char *)k, stride_k); + HelperBF16<512, 32> vh((const char *)v, stride_v); + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { + FlashAttn<576, 512, 1, 32> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + return true; + } break; +#endif + default: break; } - return true; + return false; } if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; - if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 if (Dk != Dv && Dk != 192 && Dv != 128) return false; if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false; if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;