Reviewer note: the line structure of this patch was restored from a whitespace-collapsed copy.
Indentation, blank context lines and the template parameter lists on context lines are
reconstructed; check them against the tree or apply with --ignore-whitespace.
The blob id on the new file's "index" line predates the edits made to that file here.

diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index f12f87fc..eb39d04c 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -270,6 +270,7 @@ if (GGML_IQK_MUL_MAT)
                             iqk/fa/iqk_fa_128_128.cpp
                             iqk/fa/iqk_fa_96_96.cpp
                             iqk/fa/iqk_fa_64_64.cpp
+                            iqk/fa/iqk_fa_1088_1024.cpp
                             iqk/iqk_gemm_floats.cpp
                             iqk/iqk_gemm_kquants.cpp
                             iqk/iqk_gemm_ktquants.cpp
diff --git a/ggml/src/iqk/fa/iqk_fa_1088_1024.cpp b/ggml/src/iqk/fa/iqk_fa_1088_1024.cpp
new file mode 100644
index 00000000..df3a3be6
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_1088_1024.cpp
@@ -0,0 +1,132 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+namespace {
+
+// Flash attention for Dk = 1088, Dv = 1024 over nq1 query rows with the given K/V helpers.
+// Rows are consumed in batches that are multiples of 16, 8, 4 and 2 (in that order), each
+// with its own FlashAttn instantiation; whatever is left goes through the q_step = 1 kernel.
+template <int step_k, typename KHelper, typename VHelper>
+inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
+                        int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+                        const float * q, const char * mask, float scale, float softcap, float * qkv,
+                        const float * sinkf, float * M, float * S) {
+    // Accounts for n processed rows: returns true when none are left, otherwise advances
+    // q, mask and qkv (and M, S when both are non-null) past them and returns false.
+    auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
+        nq1 -= n;
+        if (nq1 == 0) return true;
+        q += n*stride_q;
+        mask += n*stride_m;
+        qkv += n*stride_qkv;
+        if (M && S) { M += n; S += n; }
+        return false;
+    };
+    if (nq1 >= 16) {
+        int n_step = nq1/16;
+        FlashAttn<1088, 1024, 16, step_k> fa(scale, softcap, sinkf);
+        fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        if (update(16*n_step)) return;
+    }
+    if (nq1 >= 8) {
+        int n_step = nq1/8;
+        FlashAttn<1088, 1024, 8, step_k> fa(scale, softcap, sinkf);
+        fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        if (update(8*n_step)) return;
+    }
+    if (nq1 >= 4) {
+        int n_step = nq1/4;
+        FlashAttn<1088, 1024, 4, step_k> fa(scale, softcap, sinkf);
+        fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        if (update(4*n_step)) return;
+    }
+    if (nq1 >= 2) {
+        int n_step = nq1/2;
+        FlashAttn<1088, 1024, 2, step_k> fa(scale, softcap, sinkf);
+        fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        if (update(2*n_step)) return;
+    }
+    FlashAttn<1088, 1024, 1, step_k> fa(scale, softcap, sinkf);
+    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+}
+
+// Picks the K/V helper types that match type_k and runs the kernel above.
+// Returns false, without computing anything, when type_k is not handled here.
+template <int step_k>
+inline bool iqk_deepseek_helper(ggml_type type_k,
+                        int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+                        const float * q, const char * k, const char * v, const char * mask,
+                        float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
+    if (type_k == GGML_TYPE_Q8_0) {
+        HelperQ80 kh((const char *)k, stride_k);
+        HelperQ80 vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q8_0_R8) {
+        HelperQ80R8<1088> kh((const char *)k, stride_k);
+        HelperQ80 vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q6_0) {
+        HelperQ60 kh((const char *)k, stride_k);
+        HelperQ60 vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
+        return true;
+    }
+#if GGML_IQK_FA_ALL_QUANTS
+    if (type_k == GGML_TYPE_Q8_KV) {
+        HelperQ8KV<1088> kh((const char *)k, stride_k);
+        // V rows hold Dv = 1024 values, so the V helper is sized 1024 like the BF16 path below.
+        HelperQ8KV<1024> vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
+        return true;
+    }
+#endif
+    if (type_k == GGML_TYPE_F16) {
+        HelperF16 kh((const char *)k, stride_k);
+        HelperF16 vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
+        return true;
+    }
+#ifdef __AVX512BF16__
+    if (type_k == GGML_TYPE_BF16) {
+        HelperBF16<1088, step_k> kh((const char *)k, stride_k);
+        HelperBF16<1024, step_k> vh((const char *)v, stride_v);
+        // BF16 uses its own kernel with only two batch sizes: 8 when nq1 is a multiple of 8, else 1.
+        if (nq1 % 8 == 0) {
+            FlashAttnBF16<1088, 1024, 8, step_k> fa(scale, softcap, sinkf);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        } else {
+            FlashAttnBF16<1088, 1024, 1, step_k> fa(scale, softcap, sinkf);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        }
+        return true;
+    }
+#endif
+    return false;
+}
+
+}
+
+// Entry point for Dk = 1088, Dv = 1024. K and V must have the same type, except that
+// Q8_0_R8 K may be paired with Q8_0 V; any other combination returns false.
+IQK_FA_CASE(iqk_fa_1088_1024) {
+
+    auto type_k = ggml_type(int_type_k);
+    auto type_v = ggml_type(int_type_v);
+
+    if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) {
+        return false;
+    }
+    stride_q /= sizeof(float); // q stride as float
+    return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                        q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S);
+
+}
+
+#endif
diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h
index 8e96844e..bca95c77 100644
--- a/ggml/src/iqk/fa/iqk_fa_templates.h
+++ b/ggml/src/iqk/fa/iqk_fa_templates.h
@@ -1231,7 +1231,7 @@ struct FlashQKV {
 
 template <int D, int q_step, int k_step>
 struct FlashQKfp32 {
-    static_assert(D%F16::block_size == 0 && D <= 576);
+    static_assert(D%F16::block_size == 0 && D <= 1088);
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
@@ -1523,8 +1523,8 @@ char * get_q_storage(size_t size) {
 // q_step-1 versions of these functions for us, which I though was too much with q_step = 8.
 template <int Dk, int Dv, int q_step, int k_step>
 struct FlashAttn {
-    static_assert(Dk%F16::block_size == 0 && Dk <= 576);
-    static_assert(Dv%F16::block_size == 0 && Dv <= 512);
+    static_assert(Dk%F16::block_size == 0 && Dk <= 1088);
+    static_assert(Dv%F16::block_size == 0 && Dv <= 1024);
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
@@ -1635,7 +1635,7 @@ struct HelperBF16 final : public BaseHelper {
 template <int D, int q_step, int k_step>
 struct FlashQKbf16 {
     //static_assert(D%32 == 0 && D <= 256);
-    static_assert(D%32 == 0 && D <= 576);
+    static_assert(D%32 == 0 && D <= 1088);
     static_assert(k_step%32 == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
@@ -1947,8 +1947,8 @@ template <int Dk, int Dv, int q_step, int k_step>
 struct FlashAttnBF16 {
     //static_assert(Dk%32 == 0 && Dk <= 256);
     //static_assert(Dv%32 == 0 && Dv <= 256);
-    static_assert(Dk%32 == 0 && Dk <= 576);
-    static_assert(Dv%32 == 0 && Dv <= 512);
+    static_assert(Dk%32 == 0 && Dk <= 1088);
+    static_assert(Dv%32 == 0 && Dv <= 1024);
     static_assert(k_step%32 == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
@@ -2240,6 +2240,7 @@ IQK_FA_CASE(iqk_fa_256_256);
 IQK_FA_CASE(iqk_fa_128_128);
 IQK_FA_CASE(iqk_fa_96_96);
 IQK_FA_CASE(iqk_fa_64_64);
+IQK_FA_CASE(iqk_fa_1088_1024);
 
 #endif
 
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 35573e49..8c75112d 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1343,6 +1343,10 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
         return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                               q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
     }
+    if (Dk == 1088 && Dv == 1024) {
+        return iqk_fa_1088_1024(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                                q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
+    }
     if (Dk == 192 && Dv == 128) {
         return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,