From 9541631a52f5447f0c10f2bb6d04ac7f5dba5f5a Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 19 May 2025 15:20:43 +0300 Subject: [PATCH] Most helpers don't need to be templates Also hide Q4_0 and Q8_KV behind IQK_FA_ALL_QUANTS. Compilation time drops to 14 second on the Ryzen-5975WX --- ggml/src/iqk/fa/iqk_fa_576_512.cpp | 22 +-- ggml/src/iqk/fa/iqk_fa_templates.h | 235 ++++++++++++++--------------- ggml/src/iqk/iqk_gemm_kquants.cpp | 2 +- 3 files changed, 129 insertions(+), 130 deletions(-) diff --git a/ggml/src/iqk/fa/iqk_fa_576_512.cpp b/ggml/src/iqk/fa/iqk_fa_576_512.cpp index e8acff18..5174be30 100644 --- a/ggml/src/iqk/fa/iqk_fa_576_512.cpp +++ b/ggml/src/iqk/fa/iqk_fa_576_512.cpp @@ -53,32 +53,34 @@ inline bool iqk_deepseek_helper(ggml_type type_k, const float * q, const char * k, const char * v, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { if (type_k == GGML_TYPE_Q8_0) { - HelperQ80<576, step_k> kh((const char *)k, stride_k); - HelperQ80<512, step_k> vh((const char *)v, stride_v); + HelperQ80 kh((const char *)k, stride_k); + HelperQ80 vh((const char *)v, stride_v); iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } if (type_k == GGML_TYPE_Q8_0_R8) { - HelperQ80R8<576, step_k> kh((const char *)k, stride_k); - HelperQ80<512, step_k> vh((const char *)v, stride_v); + HelperQ80R8<576> kh((const char *)k, stride_k); + HelperQ80 vh((const char *)v, stride_v); iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } if (type_k == GGML_TYPE_Q6_0) { - HelperQ60<576, step_k> kh((const char *)k, stride_k); - HelperQ60<512, step_k> vh((const char *)v, stride_v); + HelperQ60 kh((const char *)k, stride_k); + HelperQ60 vh((const char *)v, stride_v); iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } +#if GGML_IQK_FA_ALL_QUANTS if (type_k == GGML_TYPE_Q8_KV) { - HelperQ8KV<576, step_k> kh((const char *)k, stride_k); - HelperQ8KV<512, step_k> vh((const char *)v, stride_v); + HelperQ8KV<576> kh((const char *)k, stride_k); + HelperQ8KV<512> vh((const char *)v, stride_v); iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } +#endif if (type_k == GGML_TYPE_F16) { - HelperF16<576, step_k> kh((const char *)k, stride_k); - HelperF16<512, step_k> vh((const char *)v, stride_v); + HelperF16 kh((const char *)k, stride_k); + HelperF16 vh((const char *)v, stride_v); iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index a766d737..a8b19401 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -32,13 +32,12 @@ namespace { -template struct BaseHelper { BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {} //inline void set_block(int k1) { block = data + k1*k_step*stride; } inline void reset_block() { block = data; } - inline void next_block() { block += k_step*stride; } + inline void next_block(int step) { block += step*stride; } inline const char * lblock(int l1) const { return block + l1*stride; } const char * data; @@ -177,27 +176,16 @@ struct F16 { } }; -template -struct HelperF16 final : public BaseHelper { - using Base = BaseHelper; +struct HelperF16 final : public BaseHelper { + using Base = BaseHelper; HelperF16(const char * data, int stride) : Base(data, stride) {} - inline void load(int l1, F16::Data * vk) const { - auto dr = Base::lblock(l1); - for (int i = 0; i < D/F16::block_size; ++i) vk[i] = F16::load(dr, i); - } - inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { //auto dr = (const ggml_half *)Base::lblock(l1); auto dr = Base::lblock(l1); v1 = F16::load(dr, i + 0); v2 = F16::load(dr, i + 1); } - - inline void load_2(int l1, F16::Data* vk) const { - load(l1+0, vk+0); - load(l1+1, vk+D/16); - } }; template struct block_q8_KV { @@ -206,9 +194,9 @@ template struct block_q8_KV { int8_t qs[D]; }; -template -struct HelperQ8KV final : public BaseHelper { - using Base = BaseHelper; +template +struct HelperQ8KV final : public BaseHelper { + using Base = BaseHelper; using block_q8 = block_q8_KV; constexpr static ggml_type type = GGML_TYPE_Q8_KV; constexpr static int block_size_q = D; @@ -235,9 +223,8 @@ struct HelperQ8KV final : public BaseHelper { } }; -template -struct HelperQ80 final : public BaseHelper { - using Base = BaseHelper; +struct HelperQ80 final : public BaseHelper { + using Base = BaseHelper; constexpr static ggml_type type = GGML_TYPE_Q8_0; #ifdef HAVE_FANCY_SIMD using block_q8 = block_q8_2; @@ -271,8 +258,8 @@ struct HelperQ80 final : public BaseHelper { #endif } + template static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) { - //GGML_ASSERT(nq <= step); Why did I have this assert? for (int i = 0; i < nq; ++i) { quantize_row_q8_0_x4(q, y, D); q += stride_q; @@ -280,8 +267,8 @@ struct HelperQ80 final : public BaseHelper { } } + template static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) { - //GGML_ASSERT(nq <= step); Why did I have this assert? for (int i = 0; i < nq; ++i) { quantize_row_q8_1_x4(q, y, D); q += stride_q; @@ -289,8 +276,8 @@ struct HelperQ80 final : public BaseHelper { } } + template static inline void convert(int nq, int stride_q, const float * q, block_q8_2 * y) { - //GGML_ASSERT(nq <= step); Why did I have this assert? for (int i = 0; i < nq; ++i) { quantize_row_q8_2_x4(q, y, D); q += stride_q; @@ -298,6 +285,7 @@ struct HelperQ80 final : public BaseHelper { } } + template static inline void convert(int nq, int stride_q, const float * q, block_q8_KV * y) { for (int i = 0; i < nq; ++i) { quantize_row_q8_KV(q, y, D); @@ -307,9 +295,9 @@ struct HelperQ80 final : public BaseHelper { } }; -template -struct HelperQ80R8 : public BaseHelper { - using Base = BaseHelper; +template +struct HelperQ80R8 : public BaseHelper { + using Base = BaseHelper; constexpr static ggml_type type = GGML_TYPE_Q8_0_R8; #ifdef __AVX2__ constexpr static int block_size_q = QK8_2; @@ -319,7 +307,7 @@ struct HelperQ80R8 : public BaseHelper { using block_q8 = block_q8_0; #endif HelperQ80R8(const char * data, int stride) : Base(data, stride) {} - HelperQ80R8(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { + HelperQ80R8(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { r4 = repack(nk, q8); Base::data = (const char *)r4.data(); Base::stride = (D/QK8_0)*sizeof(block_q8_0); @@ -416,7 +404,7 @@ struct HelperQ80R8 : public BaseHelper { } } - static std::vector repack(int nk, const HelperQ80& q8) { + static std::vector repack(int nk, const HelperQ80& q8) { static_assert(D%QK8_0 == 0); GGML_ASSERT(nk%8 == 0); constexpr int nblock = D/QK8_0; @@ -430,9 +418,9 @@ struct HelperQ80R8 : public BaseHelper { }; // TODO: unite this with the above -template -struct HelperQ8KVR8 : public BaseHelper { - using Base = BaseHelper; +template +struct HelperQ8KVR8 : public BaseHelper { + using Base = BaseHelper; constexpr static ggml_type type = GGML_TYPE_Q8_KV_R8; constexpr static int block_size_q = D; using block_q8 = block_q8_KV; @@ -442,13 +430,13 @@ struct HelperQ8KVR8 : public BaseHelper { int8_t qs[8*D]; }; - HelperQ8KVR8(int nk, const HelperQ8KV& q8) : Base(q8.data, q8.stride) { + HelperQ8KVR8(int nk, const HelperQ8KV& q8) : Base(q8.data, q8.stride) { r4 = repack(nk, q8); Base::data = (const char *)r4.data(); Base::stride = sizeof(block_q8_KV_r8)/8; } - static std::vector repack(int nk, const HelperQ8KV& q8) { + static std::vector repack(int nk, const HelperQ8KV& q8) { static_assert(D%32 == 0); GGML_ASSERT(nk%8 == 0); std::vector result(nk/8); @@ -526,9 +514,8 @@ struct HelperQ8KVR8 : public BaseHelper { std::vector r4; }; -template -struct HelperQ40 final : public BaseHelper { - using Base = BaseHelper; +struct HelperQ40 final : public BaseHelper { + using Base = BaseHelper; constexpr static ggml_type type = GGML_TYPE_Q4_0; #if defined __AVX2__ using block_q8 = block_q8_2; @@ -576,9 +563,8 @@ struct HelperQ40 final : public BaseHelper { #endif }; -template -struct HelperQ41 final : public BaseHelper { - using Base = BaseHelper; +struct HelperQ41 final : public BaseHelper { + using Base = BaseHelper; using block_q8 = block_q8_2; constexpr static ggml_type type = GGML_TYPE_Q4_1; constexpr static int block_size_q = QK8_2; @@ -620,9 +606,8 @@ struct HelperQ41 final : public BaseHelper { #endif }; -template -struct HelperIQ4nl final : public BaseHelper { - using Base = BaseHelper; +struct HelperIQ4nl final : public BaseHelper { + using Base = BaseHelper; constexpr static ggml_type type = GGML_TYPE_IQ4_NL; #ifdef __aarch64__ using block_q8 = block_q8_0; @@ -676,8 +661,7 @@ struct HelperIQ4nl final : public BaseHelper { #endif }; -template -struct HelperQ60 final : public BaseHelper { +struct HelperQ60 final : public BaseHelper { constexpr static ggml_type type = GGML_TYPE_Q6_0; #ifdef __aarch64__ using block_q8 = block_q8_0; @@ -686,7 +670,7 @@ struct HelperQ60 final : public BaseHelper { using block_q8 = block_q8_2; constexpr static int block_size_q = QK8_2; #endif - using Base = BaseHelper; + using Base = BaseHelper; HelperQ60(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -739,8 +723,10 @@ struct HelperQ60 final : public BaseHelper { #endif }; -template +template struct FlashMS { + constexpr static int q_step = q_step_in; + constexpr static int k_step = k_step_in; // Something goes wrong when storing and manipulating K*Q as fp16. // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B). // As I wasn't able to find where we lose precision, let's comment this out @@ -979,8 +965,9 @@ struct FlashQKV { using qkv_cache_t = float; #endif - template - inline void accumulate_qkv_1(const VHelper& vh, const FlashMS& fms) { + template + inline void accumulate_qkv_1(const VHelper& vh, const FMS& fms) { + static_assert(q_step == FMS::q_step); F16::Data vq[D/F16::block_size]; if (fms.need_scaling[0] == 2) { for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero(); @@ -1017,8 +1004,9 @@ struct FlashQKV { // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 // Hence, for now, we will not handle head sizes of 80 and 112 - template - inline void accumulate_qkv(const VHelper& vh, const FlashMS& fms) { + template + inline void accumulate_qkv(const VHelper& vh, const FMS& fms) { + static_assert(q_step == FMS::q_step); if constexpr (q_step == 1) { accumulate_qkv_1(vh, fms); return; @@ -1106,8 +1094,9 @@ struct FlashQKV { } } - template - inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS& fms) { + template + inline void accumulate_qkv(int nq1, const VHelper& vh, const FMS& fms) { + static_assert(q_step == FMS::q_step); if (nq1 == 1) { accumulate_qkv_1(vh, fms); return; @@ -1151,7 +1140,9 @@ struct FlashQKV { } } - inline void normalize_and_store_1row(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { + template + inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const { + static_assert(q_step == FMS::q_step); GGML_ASSERT(fms.S[j] > 0); auto norm = F16::set1(1/fms.S[j]); //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); @@ -1161,7 +1152,9 @@ struct FlashQKV { } } - inline void normalize_and_store(const FlashMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const { + template + inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const { + static_assert(q_step == FMS::q_step); if (M && S) { std::memcpy(M, fms.M, nq1*sizeof(float)); std::memcpy(S, fms.S, nq1*sizeof(float)); @@ -1187,7 +1180,9 @@ struct FlashQKV { } } - inline void normalize_and_store(const FlashMS& fms, int stride_qkv, float * qkv, float * M, float * S) const { + template + inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const { + static_assert(q_step == FMS::q_step); if (M && S) { std::memcpy(M, fms.M, q_step*sizeof(float)); std::memcpy(S, fms.S, q_step*sizeof(float)); @@ -1365,8 +1360,8 @@ struct FlashQKfp32 { constexpr int kMaxQ = 8; static_assert(q_step < kMaxQ || q_step%kMaxQ == 0); DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; - if constexpr (std::is_same_v> || - std::is_same_v>) { + if constexpr (std::is_same_v> || + std::is_same_v>) { iqk_gemm_q8kv_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step); } else { iqk_gemm_legacy_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step); @@ -1389,8 +1384,8 @@ struct FlashQKfp32 { const block_q8 * q, const char * mask, FlashMS& fms) { GGML_ASSERT(nq < q_step); DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; - if constexpr (std::is_same_v> || - std::is_same_v>) { + if constexpr (std::is_same_v> || + std::is_same_v>) { iqk_gemm_q8kv_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step); } else { iqk_gemm_legacy_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step); @@ -1434,8 +1429,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); #endif fqkv.accumulate_qkv(vh, fms); - kh.next_block(); - vh.next_block(); + kh.next_block(k_step); + vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); @@ -1461,8 +1456,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); #endif fqkv.accumulate_qkv(n_left, vh, fms); - kh.next_block(); - vh.next_block(); + kh.next_block(k_step); + vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); @@ -1476,22 +1471,22 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, const float * q, const char * mask, float * qkv, float * M, float * S, char * qptr) { auto q8 = (typename KHelper::block_q8 *)qptr; - if constexpr (q_step > 1 && std::is_same_v>) { + if constexpr (q_step > 1 && std::is_same_v) { if (nq1 == q_step) { fms.init_qstep(); kh.reset_block(); vh.reset_block(); block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8]; - HelperQ80R8 khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); - auto q8r = (typename HelperQ80R8::block_q8 *)qptr; - HelperQ80::convert(q_step, stride_q, q, q8r); + HelperQ80R8 khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); + auto q8r = (typename HelperQ80R8::block_q8 *)qptr; + HelperQ80::convert(q_step, stride_q, q, q8r); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); + HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); fqkv.accumulate_qkv(vh, fms); - kh.next_block(); - vh.next_block(); + kh.next_block(k_step); + vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); @@ -1508,7 +1503,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80::convert(q_step, stride_q, q, q8); + HelperQ80::convert(q_step, stride_q, q, q8); #if FA_TIMING perf.accum_nolock(0, t1); #endif @@ -1525,8 +1520,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); fqkv.accumulate_qkv(vh, fms); #endif - kh.next_block(); - vh.next_block(); + kh.next_block(k_step); + vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } #if FA_TIMING @@ -1547,13 +1542,13 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80::convert(n_left, stride_q, q, q8); + HelperQ80::convert(n_left, stride_q, q, q8); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms); fqkv.accumulate_qkv(n_left, vh, fms); - kh.next_block(); - vh.next_block(); + kh.next_block(k_step); + vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); @@ -1591,14 +1586,14 @@ struct FlashAttn { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) { - if constexpr (std::is_same_v> || - std::is_same_v> || - std::is_same_v> || - std::is_same_v> || - std::is_same_v> || - std::is_same_v> || - std::is_same_v> || - std::is_same_v>) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v || + std::is_same_v> || + std::is_same_v>) { constexpr size_t kMaxOnStackSize = 576; //auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2)); @@ -1606,31 +1601,33 @@ struct FlashAttn { if (q_size > kMaxOnStackSize) { auto qptr = get_q_storage(q_size); if (false && nq1 >= 8) { - if constexpr (std::is_same_v>) { + if constexpr (std::is_same_v) { #if FA_TIMING auto t1 = Perf::cur_time(); HelperQ80R8 khr4(nk1, kh); Perf::instance().accum(4, t1); #else - HelperQ80R8 khr4(nk1, kh); + HelperQ80R8 khr4(nk1, kh); #endif - compute_helper_q, VHelper, FlashQKfp32>( + compute_helper_q, VHelper, FlashQKfp32>( khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); return; } - if constexpr (std::is_same_v>) { +#if GGML_IQK_FA_ALL_QUANTS + if constexpr (std::is_same_v>) { #if FA_TIMING auto t1 = Perf::cur_time(); HelperQ8KVR8 khr4(nk1, kh); Perf::instance().accum(4, t1); #else - HelperQ8KVR8 khr4(nk1, kh); + HelperQ8KVR8 khr4(nk1, kh); #endif - compute_helper_q, VHelper, FlashQKfp32>( + compute_helper_q, VHelper, FlashQKfp32>( khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); return; } +#endif } compute_helper_q>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); @@ -2040,8 +2037,8 @@ struct FlashAttnBF16 { FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(vh, fms); #endif - kh.next_block(); - vh.next_block(); + kh.next_block(k_step); + vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } #if FA_TIMING @@ -2066,8 +2063,8 @@ struct FlashAttnBF16 { for (int k1 = 0; k1 < nk1/k_step; ++k1) { FlashQKbf16::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(n_left, vh, fms); - kh.next_block(); - vh.next_block(); + kh.next_block(k_step); + vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); @@ -2180,7 +2177,7 @@ inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v, switch (type_v) { case GGML_TYPE_F16: { - HelperF16 vh(v, stride_v); + HelperF16 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #ifdef __AVX512BF16__ @@ -2190,28 +2187,28 @@ inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v, } break; #endif case GGML_TYPE_Q8_0: { - HelperQ80 vh(v, stride_v); + HelperQ80 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q8_KV: { - HelperQ8KV vh(v, stride_v); + HelperQ8KV vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q6_0: { - HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); - } break; - case GGML_TYPE_Q4_0: { - HelperQ40 vh(v, stride_v); + HelperQ60 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS + case GGML_TYPE_Q4_0: { + HelperQ40 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + } break; case GGML_TYPE_Q4_1: { - HelperQ41 vh(v, stride_v); + HelperQ41 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl vh(v, stride_v); + HelperIQ4nl vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #endif @@ -2229,36 +2226,36 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, bool result = false; switch (type_k) { case GGML_TYPE_F16: { - HelperF16 kh(k, stride_k); + HelperF16 kh(k, stride_k); result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q8_0: { - HelperQ80 kh(k, stride_k); + HelperQ80 kh(k, stride_k); result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q8_0_R8: { - HelperQ80R8 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); - } break; - case GGML_TYPE_Q8_KV: { - HelperQ8KV kh(k, stride_k); + HelperQ80R8 kh(k, stride_k); result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q6_0: { - HelperQ60 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); - } break; - case GGML_TYPE_Q4_0: { - HelperQ40 kh(k, stride_k); + HelperQ60 kh(k, stride_k); result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS + case GGML_TYPE_Q8_KV: { + HelperQ8KV kh(k, stride_k); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + } break; + case GGML_TYPE_Q4_0: { + HelperQ40 kh(k, stride_k); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + } break; case GGML_TYPE_Q4_1: { - HelperQ41 kh(k, stride_k); + HelperQ41 kh(k, stride_k); result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl kh(k, stride_k); + HelperIQ4nl kh(k, stride_k); result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; #endif diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index 75901008..b175987c 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -3039,7 +3039,7 @@ static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataI #endif template -inline std::pair mul_mat_kernel(int D, int int_typeA, int nq) { +inline std::pair mul_mat_kernel([[maybe_unused]] int D, int int_typeA, int nq) { auto typeA = ggml_type(int_typeA); constexpr int kMaxQ = 8; #define MAKE_FUNCS(mul_mat, n) \