diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5ca916b7..6317db6c 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -17219,6 +17219,93 @@ inline bool flash_attn_is_supported(ggml_type type) {
 #endif
     return false;
 }
+
+// Flash-attention dispatch for the Dk = 576, Dv = 512 head-size pair (the
+// DeepSeek MLA case, where k and v are views of a single cache).
+//
+// Selects the cache helper from type_k (Q8_0, Q6_0, Q8_KV, F16 and, when
+// __AVX512BF16__ is defined, BF16), wraps k in a <576, step_k> helper and v
+// in a <512, step_k> helper, and runs FlashAttn (FlashAttnBF16 for BF16)
+// instantiated with 8 when nq1 is a multiple of 8 and with 1 otherwise.
+// NOTE(review): that argument is presumably the number of q rows handled
+// per step -- confirm against FlashAttn.
+//
+// The same type_k selects both helpers, so k and v must share one type (the
+// caller asserts type_k == type_v). stride_q is forwarded to compute() as
+// is; the caller has already divided it by sizeof(float).
+//
+// Returns true when type_k was handled; returns false without computing
+// anything for any other type.
+template <int step_k>
+inline bool iqk_deepseek_helper(ggml_type type_k,
+        int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+        const float * q, const char * k, const char * v, const char * mask,
+        float scale, float softcap, float * qkv) {
+    if (type_k == GGML_TYPE_Q8_0) {
+        HelperQ80<576, step_k> kh((const char *)k, stride_k);
+        HelperQ80<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q6_0) {
+        HelperQ60<576, step_k> kh((const char *)k, stride_k);
+        HelperQ60<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q8_KV) {
+        HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
+        HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+    if (type_k == GGML_TYPE_F16) {
+        HelperF16<576, step_k> kh((const char *)k, stride_k);
+        HelperF16<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+#ifdef __AVX512BF16__
+    if (type_k == GGML_TYPE_BF16) {
+        HelperBF16<576, step_k> kh((const char *)k, stride_k);
+        HelperBF16<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+#endif
+    return false;
+}
+
 }
 
 bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
@@ -17246,80 +17333,10 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
     auto type_v = ggml_type(int_type_v);
 
     if (Dk == 576 && Dv == 512) {
-        //if (type_k == GGML_TYPE_Q8_0 && type_v == GGML_TYPE_Q8_0 && Dk == 576 && Dv == 512) {
-        // This is a DeepSeek model with MLA. In that case we only have one cache (and k and v are different views of the cache),
-        // so type_k must be the same as type_v
         GGML_ASSERT(type_k == type_v);
         stride_q /= sizeof(float); // q stride as float
-        switch (type_k) {
-            case GGML_TYPE_Q8_0: {
-                //printf("%s: nk1 = %d, nq1 = %d, k = %p, v = %p, stride_k = %d stride_v = %d, stride_m = %d\n", __func__, nk1, nq1, k, v, stride_k, stride_v, stride_m);
-                HelperQ80<576, 32> kh((const char *)k, stride_k);
-                HelperQ80<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-            case GGML_TYPE_Q6_0: {
-                //printf("%s: nk1 = %d, nq1 = %d, k = %p, v = %p, stride_k = %d stride_v = %d, stride_m = %d\n", __func__, nk1, nq1, k, v, stride_k, stride_v, stride_m);
-                HelperQ60<576, 32> kh((const char *)k, stride_k);
-                HelperQ60<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-            // Something is wrong with Q8_KV in this case.
-            case GGML_TYPE_Q8_KV: {
-                HelperQ8KV<576, 32> kh((const char *)k, stride_k);
-                HelperQ8KV<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-            case GGML_TYPE_F16: {
-                HelperF16<576, 32> kh((const char *)k, stride_k);
-                HelperF16<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttn<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttn<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-#ifdef __AVX512BF16__
-            case GGML_TYPE_BF16: {
-                HelperBF16<576, 32> kh((const char *)k, stride_k);
-                HelperBF16<512, 32> vh((const char *)v, stride_v);
-                if (nq1 % 8 == 0) {
-                    FlashAttnBF16<576, 512, 8, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                } else {
-                    FlashAttnBF16<576, 512, 1, 32> fa(scale, softcap);
-                    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                }
-                return true;
-            } break;
-#endif
-            default: break;
-        }
-        return false;
+        return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv);
     }
 
     if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;