mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-25 17:09:22 +00:00
* Refactor iqk: WIP * Refactor iqk: Factor out float GEMM (AVX2/AVX512) * Refactor iqk: Factor out GEMM for legacy quants (AVX2/AVX512) * Refactor iqk: Factor out GEMM for k-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for i-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for iqk-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for 1-bit quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for iq1_bn, iq2_bn, iq2_bn_r4 * Refactor iqk: Factor out GEMM for repacked legacy quants * Refactor iqk: Factor out GEMM for q8_K_R8, q8_KV * Refactor iqk: Factor out GEMM for repacked i-quants * Refactor iqk: GEMM kernels are refactored on AVX2/AVX512 * Refactor iqk: factor out 1-bit quants (NEON) * Refactor iqk: factor out k-quants (NEON) * Refactor iqk: factor out floats (NEON) * Also iq4_xs belongs to k-quants * Refactor iqk: factor out iqk quants (NEON) * Refactor iqk: factor out legacy quants (NEON) * Refactor iqk: factor out repacked legacy quants (NEON) * Refactor iqk: factor out repacked k-quants (NEON) * Refactor iqk: factor out repacked iqk quants (NEON) * Refactor iqk: GEMM kernels are refactored on NEON * Refactor iqk: FA compiles. Whether it works is a different story. Current compile time: 107.3 seconds on the Ryzen-7950X * Refactor iqk: FA refactored (Zen4). Compile time for the FA files is now ~21 seconds on my Ryzen-7950X, so still slightly too long for my taste but much better than the 142 seconds we had before. * Adding forgotten file * Most helpers don't need to be templates. Also hide Q4_0 and Q8_KV behind IQK_FA_ALL_QUANTS. Compilation time drops to 14 seconds on the Ryzen-5975WX * Fix bf16 * Refactor iqk: FA refactored (NEON) * Forgotten MMQ ref and typo (#431) * Adding forgotten iq5_k_r4 * Fix iq4_k_r4 on NEON * Fix iq4_ks on NEON. It was broken before the refactoring (the shifts were not correctly applied).
* Fix q8_0 on NEON * Fix q6_0 K cache --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Nexes the Elder <124105151+Nexesenex@users.noreply.github.com>
46 lines
1.6 KiB
C++
46 lines
1.6 KiB
C++
#include "iqk/iqk_config.h"
|
|
|
|
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
|
|
|
#include "iqk/fa/iqk_fa_templates.h"
|
|
|
|
// Flash-attention dispatcher for the <192, 128, N> instantiations of
// iqk_flash_helper_T (the 192/128 pair is presumably the K/V head sizes,
// going by the function name -- confirm against the caller).
//
// N is chosen from the divisibility of nk: 128 when nk is a multiple of 128,
// 64 when it is a multiple of 64 only, and 32 otherwise. When built with
// AVX512-BF16, a bf16 K cache takes a dedicated path that only uses 64 / 32.
//
// Returns false when a bf16 K cache is paired with a non-bf16 V cache (this
// check exists only under __AVX512BF16__). Returns true after the bf16 path,
// and otherwise forwards the result of the type-aware helper overload.
IQK_FA_CASE(iqk_fa_192_128) {

    // The cache type tags arrive as plain ints; view them as ggml types.
    const auto k_type = ggml_type(int_type_k);
    const auto v_type = ggml_type(int_type_v);

    // From here on the q stride is expressed in floats, not bytes.
    stride_q /= sizeof(float);

    // The helpers take K, V and the mask as byte pointers.
    auto k_data    = (const char *)k;
    auto v_data    = (const char *)v;
    auto mask_data = (const char *)mask;

#ifdef __AVX512BF16__
    if (k_type == GGML_TYPE_BF16) {
        // A bf16 K cache cannot be mixed with any other V cache type.
        if (v_type != GGML_TYPE_BF16) return false;
        if (nk % 64 != 0) {
            iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, k_data, v_data, mask_data, scale, softcap, qkv, M, S);
        } else {
            iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, k_data, v_data, mask_data, scale, softcap, qkv, M, S);
        }
        return true;
    }
#endif

    // Generic path: the helper receives the K/V types and its result is
    // returned as-is. A non-multiple of 64 can never be a multiple of 128,
    // so testing the smaller step first selects the same N as testing 128 first.
    if (nk % 64 != 0) {
        return iqk_flash_helper_T<192, 128, 32>(k_type, v_type, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, k_data, v_data, mask_data, scale, softcap, qkv, M, S);
    }
    if (nk % 128 != 0) {
        return iqk_flash_helper_T<192, 128, 64>(k_type, v_type, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, k_data, v_data, mask_data, scale, softcap, qkv, M, S);
    }
    return iqk_flash_helper_T<192, 128, 128>(k_type, v_type, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
            q, k_data, v_data, mask_data, scale, softcap, qkv, M, S);
}
|
|
|
|
#endif
|