mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-26 08:04:09 +00:00
Refactor iqk_mul_mat.cpp (#435)
* Refactor iqk: WIP * Refactor iqk: Factor out float GEMM (AVX2/AVX512) * Refactor iqk: Factor out GEMM for legacy quants (AVX2/AVX512) * Refactor iqk: Factor out GEMM for k-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for i-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for iqk-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for 1-bit quants (ABX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for iq1_bn, iq2_bn, iq2_bn_r4 * Refactor iqk: Factor out GEMM for repacked legacy quants * Refactor iqk: Factor out GEMM for q8_K_R8, q8_KV * Refactor iqk: Factor out GEMM for repacked i-quants * Refactor iqk: GEMM kernels are refactored on AVX2/AVX512 * Refactor iqk: factor out 1-bit quants (NEON) * Refactor iqk: factor out k-quants (NEON) * Refactor iqk: factor out floats (NEON) * Also iq4_xs belongs to k-quants * Refactor iqk: factor out iqk quants (NEON) * Refactor iqk: factor out legacy quants (NEON) * Refactor iqk: factor out repacked legacy quants (NEON) * Refactor iqk: factor out repacked k-quants (NEON) * Refactor iqk: factor out repacked iqk quants (NEON) * Refactor iqk: GEMM kernels are refactored on NEON * Refactor iqk: FA compiles If it works is a different story. Current compile time: 107.3 sesonds on the Ryzen-7950X * Refactor iqk: FA refactored (Zen4) Compile time for the FA files is now ~21 seconds on my Ryzen-7950X, so still slightly too long for my taste but much better than the 142 seconds we had before. * Adding forgotten file * Most helpers don't need to be templates Also hide Q4_0 and Q8_KV behind IQK_FA_ALL_QUANTS. Compilation time drops to 14 second on the Ryzen-5975WX * Fix bf16 * Refactor iqk: FA refactored (NEON) * Forgotten MMQ ref and typo (#431) * Adding forgotten iq5_k_r4 * Fix iq4_k_r4 on NEON * Fix iq4_ks on NEON It was broken before the refactoring (the shifts were not correctly applied). 
* Fix q8_0 on NEON * Fix q6_0 K cache --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Nexes the Elder <124105151+Nexesenex@users.noreply.github.com>
This commit is contained in:
@@ -258,8 +258,29 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
|
||||
if (GGML_IQK_MUL_MAT)
|
||||
message(STATUS "Using optimized iqk matrix multiplications")
|
||||
add_compile_definitions(GGML_USE_IQK_MULMAT)
|
||||
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
|
||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
|
||||
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
|
||||
iqk/iqk_flash_attn.cpp
|
||||
iqk/fa/iqk_fa_576_512.cpp
|
||||
iqk/fa/iqk_fa_192_128.cpp
|
||||
iqk/fa/iqk_fa_256_256.cpp
|
||||
iqk/fa/iqk_fa_128_128.cpp
|
||||
iqk/fa/iqk_fa_96_96.cpp
|
||||
iqk/fa/iqk_fa_64_64.cpp
|
||||
iqk/iqk_gemm_floats.cpp
|
||||
iqk/iqk_gemm_kquants.cpp
|
||||
iqk/iqk_gemm_iquants.cpp
|
||||
iqk/iqk_gemm_iqk_quants.cpp
|
||||
iqk/iqk_gemm_1bit.cpp
|
||||
iqk/iqk_gemm_legacy_quants.cpp)
|
||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
|
||||
iqk/iqk_flash_impl.h
|
||||
iqk/fa/iqk_fa_templates.h
|
||||
iqk/iqk_gemm_floats.h
|
||||
iqk/iqk_gemm_kquants.h
|
||||
iqk/iqk_gemm_iquants.h
|
||||
iqk/iqk_gemm_iqk_quants.h
|
||||
iqk/iqk_gemm_1bit.h
|
||||
iqk/iqk_gemm_legacy_quants.h)
|
||||
if (GGML_IQK_FLASH_ATTENTION)
|
||||
message(STATUS "Enabling IQK Flash Attention kernels")
|
||||
add_compile_definitions(GGML_IQK_FLASH_ATTENTION)
|
||||
|
||||
45
ggml/src/iqk/fa/iqk_fa_128_128.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_128_128.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_128_128) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_192_128.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_192_128.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_192_128) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_256_256.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_256_256.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_256_256) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
120
ggml/src/iqk/fa/iqk_fa_576_512.cpp
Normal file
120
ggml/src/iqk/fa/iqk_fa_576_512.cpp
Normal file
@@ -0,0 +1,120 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <int step_k, typename KHelper, typename VHelper>
|
||||
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
||||
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||
nq1 -= n;
|
||||
if (nq1 == 0) return true;
|
||||
q += n*stride_q;
|
||||
mask += n*stride_m;
|
||||
qkv += n*stride_qkv;
|
||||
if (M && S) { M += n; S += n; }
|
||||
return false;
|
||||
};
|
||||
if (nq1 >= 16) {
|
||||
int n_step = nq1/16;
|
||||
FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(16*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
int n_step = nq1/8;
|
||||
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(8*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 4) {
|
||||
int n_step = nq1/4;
|
||||
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(4*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 2) {
|
||||
int n_step = nq1/2;
|
||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(2*n_step)) return;
|
||||
}
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
}
|
||||
|
||||
template <int step_k>
|
||||
inline bool iqk_deepseek_helper(ggml_type type_k,
|
||||
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
if (type_k == GGML_TYPE_Q8_0) {
|
||||
HelperQ80 kh((const char *)k, stride_k);
|
||||
HelperQ80 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q8_0_R8) {
|
||||
HelperQ80R8<576> kh((const char *)k, stride_k);
|
||||
HelperQ80 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q6_0) {
|
||||
HelperQ60 kh((const char *)k, stride_k);
|
||||
HelperQ60 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
if (type_k == GGML_TYPE_Q8_KV) {
|
||||
HelperQ8KV<576> kh((const char *)k, stride_k);
|
||||
HelperQ8KV<512> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
if (type_k == GGML_TYPE_F16) {
|
||||
HelperF16 kh((const char *)k, stride_k);
|
||||
HelperF16 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
HelperBF16<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperBF16<512, step_k> vh((const char *)v, stride_v);
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
} else {
|
||||
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
IQK_FA_CASE(iqk_fa_576_512) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) {
|
||||
return false;
|
||||
}
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_64_64.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_64_64.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_64_64) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_96_96.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_96_96.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_96_96) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
2207
ggml/src/iqk/fa/iqk_fa_templates.h
Normal file
2207
ggml/src/iqk/fa/iqk_fa_templates.h
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT
|
||||
@@ -14,6 +16,7 @@
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
@@ -79,8 +82,6 @@ struct Perf {
|
||||
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
typedef struct {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
@@ -135,4 +136,694 @@ struct DataInfo {
|
||||
|
||||
typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
|
||||
|
||||
#define IQK_MAX_NY 8
|
||||
|
||||
#define IQK_SET_MUL_MAT_FUNCTIONS_T(kernel, Dequantizer, funcs) \
|
||||
funcs[0] = kernel<Dequantizer, 1>;\
|
||||
funcs[1] = kernel<Dequantizer, 2>;\
|
||||
funcs[2] = kernel<Dequantizer, 3>;\
|
||||
funcs[3] = kernel<Dequantizer, 4>;\
|
||||
funcs[4] = kernel<Dequantizer, 5>;\
|
||||
funcs[5] = kernel<Dequantizer, 6>;\
|
||||
funcs[6] = kernel<Dequantizer, 7>;\
|
||||
funcs[7] = kernel<Dequantizer, 8>;\
|
||||
|
||||
#define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
|
||||
funcs[0] = kernel<1>;\
|
||||
funcs[1] = kernel<2>;\
|
||||
funcs[2] = kernel<3>;\
|
||||
funcs[3] = kernel<4>;\
|
||||
funcs[4] = kernel<5>;\
|
||||
funcs[5] = kernel<6>;\
|
||||
funcs[6] = kernel<7>;\
|
||||
funcs[7] = kernel<8>;\
|
||||
|
||||
|
||||
// ==================================================================================================
|
||||
|
||||
static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
|
||||
const uint16_t * scales = (const uint16_t *)scales8;
|
||||
const uint32_t a0 = scales[0] | (scales[1] << 16);
|
||||
const uint32_t a1 = scales[2] | (scales[3] << 16);
|
||||
const uint32_t a2 = scales[4] | (scales[5] << 16);
|
||||
aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
|
||||
aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
|
||||
aux32[2] = a1 & 0x3f3f3f3f;
|
||||
aux32[0] = a0 & 0x3f3f3f3f;
|
||||
}
|
||||
|
||||
#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
|
||||
const uint64_t keven_signs[128] = {
|
||||
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
|
||||
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
|
||||
0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
|
||||
0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
|
||||
0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
|
||||
0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
|
||||
0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
|
||||
0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
|
||||
0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
|
||||
0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
|
||||
0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
|
||||
0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
|
||||
0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
|
||||
0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
|
||||
0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
|
||||
0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
|
||||
0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
|
||||
0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
|
||||
0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
|
||||
0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
|
||||
0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
|
||||
0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
|
||||
0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
|
||||
0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
|
||||
0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
|
||||
0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
|
||||
0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
|
||||
0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
|
||||
0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
|
||||
0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
|
||||
0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
|
||||
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef __AVX2__
|
||||
|
||||
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
||||
|
||||
static inline float hsum_float_4(__m128 x) {
|
||||
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
|
||||
x = _mm_add_ss(x, _mm_movehdup_ps(x));
|
||||
return _mm_cvtss_f32(x);
|
||||
}
|
||||
static inline float hsum_float_8(__m256 x) {
|
||||
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
|
||||
}
|
||||
static inline int hsum_i32_8(const __m256i a) {
|
||||
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
|
||||
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
|
||||
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
|
||||
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
||||
}
|
||||
static inline float hmax_float_8(__m256 x) {
|
||||
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
|
||||
return _mm_cvtss_f32(max4);
|
||||
}
|
||||
|
||||
static inline __m256 hsum_float_8x8(__m256 * accm) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
|
||||
//accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
|
||||
// _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
|
||||
}
|
||||
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
|
||||
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
|
||||
}
|
||||
|
||||
static inline __m128i load_iq4nl_values_128() {
|
||||
static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
|
||||
return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
|
||||
}
|
||||
|
||||
static inline __m256i load_iq4nl_values_256() {
|
||||
auto val128 = load_iq4nl_values_128();
|
||||
return MM256_SET_M128I(val128, val128);
|
||||
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
static inline __m512i load_iq4nl_values_512() {
|
||||
auto val256 = load_iq4nl_values_256();
|
||||
return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline __m128i load_iq4k_values_128() {
|
||||
return _mm_loadu_si128((const __m128i *)iq4k_values);
|
||||
}
|
||||
|
||||
static inline __m256i load_iq4k_values_256() {
|
||||
auto val128 = load_iq4k_values_128();
|
||||
return MM256_SET_M128I(val128, val128);
|
||||
}
|
||||
|
||||
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
|
||||
|
||||
constexpr static int nrc_y = nrc;
|
||||
|
||||
Q8(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
|
||||
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }
|
||||
#endif
|
||||
inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
|
||||
inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }
|
||||
inline float scale(int iy, int i) const { return y[iy][i].d; }
|
||||
|
||||
const block_q8 * y[nrc_y];
|
||||
};
|
||||
|
||||
template <int nrc> struct Q8_16 {
|
||||
|
||||
constexpr static int nrc_y = nrc;
|
||||
|
||||
Q8_16(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto ptr = (const float *)info.src1_row(iy);
|
||||
std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
|
||||
y[iy] = (const int8_t *)(ptr + 5);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
|
||||
#endif
|
||||
inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
|
||||
inline float scale(int iy, int k) const { return d[5*iy+k]; }
|
||||
inline float sum_row(int iy) const { return d[5*iy + 4]; }
|
||||
inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
|
||||
|
||||
float d[5*nrc_y];
|
||||
const int8_t * y[nrc_y];
|
||||
};
|
||||
|
||||
struct Scales8KBase {
|
||||
template <typename Q8>
|
||||
inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
|
||||
const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
const __m256i q8s = q8.load_bsums(iy, i);
|
||||
const __m256i prod = _mm256_madd_epi16(mins, q8s);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
|
||||
}
|
||||
}
|
||||
inline __m256i shuffle(__m128i mins) const {
|
||||
return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
|
||||
}
|
||||
const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
|
||||
_mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
|
||||
};
|
||||
|
||||
template <typename Block, bool per_row_scale = false, bool is_f16 = false>
|
||||
struct BaseDequantizer {
|
||||
BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
|
||||
inline void new_row(int ix) {
|
||||
if constexpr (per_row_scale) {
|
||||
if constexpr (is_f16) {
|
||||
const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
|
||||
d = GGML_FP16_TO_FP32(*dptr);
|
||||
x = (const Block *)(dptr + 1);
|
||||
} else {
|
||||
const float * dptr = (const float *)((const char *)vx + bx*ix);
|
||||
d = *dptr;
|
||||
x = (const Block *)(dptr + 1);
|
||||
}
|
||||
} else {
|
||||
x = (const Block *)((const char *)vx + bx*ix);
|
||||
}
|
||||
}
|
||||
|
||||
const void * vx;
|
||||
const size_t bx;
|
||||
const Block * x;
|
||||
|
||||
float d;
|
||||
};
|
||||
|
||||
template <typename Q8, typename Bits>
|
||||
static inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
|
||||
if (j == 0) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
|
||||
}
|
||||
#else
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
|
||||
const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
|
||||
const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
|
||||
const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
|
||||
sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
|
||||
}
|
||||
#else
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
|
||||
const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
|
||||
const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
|
||||
const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
|
||||
sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
|
||||
sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply-add for one 256-bit quant group on plain AVX2 (no VNNI).
// bits.values[k] hold the unpacked x-quants, scales[k] the per-block 16-bit scales,
// q8 supplies the activations. j selects the half of the super-block being processed
// (j == 0 -> q8 sub-blocks 0..3, j != 0 -> sub-blocks 4..7). For j == 0 the
// accumulators sumi[iy] are overwritten; for j != 0 they are accumulated into.
template <typename Q8, typename Bits>
static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    __m256i p[4];
    if (j == 0) {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            for (int k = 0; k < 4; ++k) {
                // _mm256_maddubs_epi16 treats its first operand as unsigned; use |x|
                // there and transfer the sign of x onto the q8 quants instead.
                auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
                p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k])));
            }
            // Fresh accumulator: sum of all four partial products.
            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
        }
    } else {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            for (int k = 0; k < 4; ++k) {
                auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
                p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k])));
            }
            // Accumulate on top of the j == 0 result.
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2]));
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3]));
        }
    }
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
|
||||
// Permutation indices used to re-interleave the two 512-bit vectors produced when
// low and high nibbles are extracted from a 64-byte load: permute1 gathers the
// lower four 64-bit lanes of each source, permute2 the upper four.
struct BlockPermuter {
    const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
};
|
||||
|
||||
// Unpacks 4-bit quants into four __m512i vectors of byte values in [0, 15] (AVX512 path).
struct Q4Bits {
    // Unpack 256 quants (128 bytes) so that, after the cross-lane permute, values[0..3]
    // are in the same order as the corresponding q8 quants.
    inline void prepare(const uint8_t * q4) {
        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
        auto tmp1 = _mm512_and_si512(q4bits, ml);                       // low nibbles
        auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); // high nibbles
        values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
        values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
        tmp1 = _mm512_and_si512(q4bits, ml);
        tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
        values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
    }
    // Variant for quants stored in blocks of 64: no permute needed, low nibbles of
    // each load land in values[0]/values[2], high nibbles in values[1]/values[3].
    inline void prepare64(const uint8_t * q4) {
        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
        values[0] = _mm512_and_si512(q4bits, ml);
        values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
        values[2] = _mm512_and_si512(q4bits, ml);
        values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
    }
    // Alternative 64-block layout: each 256-bit load holds one block's low nibbles;
    // the high nibbles go into the upper half of the same 512-bit register.
    inline void prepare64a(const uint8_t * q4) {
        for (int k = 0; k < 4; ++k) {
            auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
            values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
            values[k] = _mm512_and_si512(values[k], ml);
        }
    }
    __m512i values[4];
    const __m512i ml = _mm512_set1_epi8(0xf); // low-nibble mask
    const BlockPermuter perm;
};
|
||||
|
||||
struct Q2Bits {
|
||||
inline void prepare(const uint8_t * q2) {
|
||||
|
||||
auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
|
||||
auto tmp = _mm512_srli_epi16(q2bits, 2);
|
||||
|
||||
values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
|
||||
values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
|
||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
|
||||
values[0] = _mm512_and_si512(values[0], ml);
|
||||
values[2] = _mm512_and_si512(values[2], ml);
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0x03);
|
||||
BlockPermuter perm;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
// Unpacks 2-bit quants into four __m256i vectors of byte values in [0, 3] (AVX2 path).
struct Q2Bits {
    // j selects which 32-byte group of the quants to unpack; each byte of the load
    // yields four 2-bit values, peeled off two bits at a time.
    inline void prepare(const uint8_t * q2, int j) {
        auto packed = _mm256_loadu_si256((const __m256i *)q2 + j);
        values[0] = _mm256_and_si256(packed, ml);
        packed    = _mm256_srli_epi16(packed, 2);
        values[1] = _mm256_and_si256(packed, ml);
        packed    = _mm256_srli_epi16(packed, 2);
        values[2] = _mm256_and_si256(packed, ml);
        packed    = _mm256_srli_epi16(packed, 2);
        values[3] = _mm256_and_si256(packed, ml);
    }
    __m256i values[4];
    const __m256i ml = _mm256_set1_epi8(0x03); // 2-bit mask
};
|
||||
|
||||
// Unpacks 4-bit quants into four __m256i vectors of byte values in [0, 15] (AVX2 path).
struct Q4Bits {
    // Standard layout: j selects which 64-byte half of the quants to unpack.
    // values[0]/values[1] come from the first 32 bytes (low/high nibbles),
    // values[2]/values[3] from the second.
    inline void prepare(const uint8_t * q4, int j) {
        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
        values[0] = _mm256_and_si256(q4bits, ml);
        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
        values[2] = _mm256_and_si256(q4bits, ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
    }
    // Layout for quants stored in blocks of 64: low nibbles go to values[0]/values[1],
    // high nibbles to values[2]/values[3] (note the swapped values[1]/values[2] order
    // relative to prepare()).
    inline void prepare64(const uint8_t * q4, int j) {
        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
        values[0] = _mm256_and_si256(q4bits, ml);
        values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
        values[1] = _mm256_and_si256(q4bits, ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
    }
    // Layout for quants stored in blocks of 16: one 128-bit load per 32 quants.
    inline void prepare16(const uint8_t * q4, int j) {
        values[0] = dequant16(q4 + 64*j +  0);
        values[1] = dequant16(q4 + 64*j + 16);
        values[2] = dequant16(q4 + 64*j + 32);
        values[3] = dequant16(q4 + 64*j + 48);
    }
    // Expands 16 bytes into 32 nibble values: low nibbles in the lower 128-bit lane,
    // high nibbles in the upper lane.
    inline __m256i dequant16(const uint8_t * qs) const {
        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
        const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
        return _mm256_and_si256(ml, aux256);
    }
    __m256i values[4];
    const __m256i ml = _mm256_set1_epi8(0xf); // low-nibble mask
};
|
||||
|
||||
#endif
|
||||
|
||||
#else
|
||||
// ------------------------------------ __aarch64__ --------------------------------------------------
|
||||
|
||||
// NEON-side accessor for nrc rows of q8 activations. Wraps the per-row block
// pointers obtained from DataInfo and provides typed loads of quants and block sums.
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {

    constexpr static int nrc_y = nrc;

    Q8(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
    }

    // 32 quants (two 16-byte vectors) of row iy, super-block i, group j.
    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
    // 64 quants (four 16-byte vectors) at once.
    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
    // All 16 per-block sums as two int16x8 vectors.
    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
    // Block sums pair-wise reduced to 8 values (for quant types with 8 scale blocks).
    inline int16x8_t load_bsums8(int iy, int i) const {
        auto q8s = vld1q_s16_x2(y[iy][i].bsums);
        return vpaddq_s16(q8s.val[0], q8s.val[1]);
    }
    // Super-block scale of row iy, block i.
    inline float scale(int iy, int i) const { return y[iy][i].d; }

    const block_q8 * y[nrc_y];
};
|
||||
|
||||
// Common base for dequantizers: resolves the block pointer (and, for quant types
// that store a per-row scale before the blocks, the row scale d) for a given row.
// has_row_scale: the row begins with a scale preceding the block data.
// scale_is_f16:  that leading scale is ggml_half rather than float.
template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
struct BaseDequantizer {
    BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
    inline void new_row(int ix) {
        if constexpr (has_row_scale) {
            if constexpr (scale_is_f16) {
                const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
                d = GGML_FP16_TO_FP32(*dptr);
                x = (const block_q *)(dptr + 1); // blocks start right after the scale
            } else {
                const float * dptr = (const float *)((const char *)vx + ix*bx);
                d = *dptr;
                x = (const block_q *)(dptr + 1);
            }
        } else {
            x = (const block_q *)((const char *)vx + ix*bx);
        }
    }
    const void * vx;     // base pointer of the quantized matrix
    const block_q * x;   // blocks of the current row (set by new_row)
    const size_t bx;     // row stride in bytes
    const int nrc;
    float d;             // row scale; only meaningful when has_row_scale
};
|
||||
|
||||
// NEON unpacking of 4-bit quants into two uint8x16x4_t holding byte values in [0, 15].
struct Q4bits {
    const uint8x16_t m4b = vdupq_n_u8(0xf);
    uint8x16x4_t b1, b2;
    // Interleaved order: low nibbles -> val[0]/val[1], high nibbles -> val[2]/val[3].
    inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
        b.val[0] = vandq_u8(val[0], m4b);
        b.val[2] = vshrq_n_u8(val[0], 4);
        b.val[1] = vandq_u8(val[1], m4b);
        b.val[3] = vshrq_n_u8(val[1], 4);
    }
    // Sequential order: each source vector contributes its low then high nibbles.
    inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
        b.val[0] = vandq_u8(val[0], m4b);
        b.val[1] = vshrq_n_u8(val[0], 4);
        b.val[2] = vandq_u8(val[1], m4b);
        b.val[3] = vshrq_n_u8(val[1], 4);
    }
    // 128 quants via two 32-byte loads, interleaved order.
    inline void prepare(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x2(qs);
        prepare4(b1, q4bits.val);
        q4bits = vld1q_u8_x2(qs+32);
        prepare4(b2, q4bits.val);
    }
    // Same result as prepare() but with a single 64-byte load.
    inline void prepare_v2(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        prepare4(b1, q4bits.val+0);
        prepare4(b2, q4bits.val+2);
    }
    // Layout for quants in blocks of 64: all low nibbles in b1, all high nibbles in b2.
    inline void prepare64(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        b1.val[0] = vandq_u8(q4bits.val[0], m4b);
        b1.val[1] = vandq_u8(q4bits.val[1], m4b);
        b1.val[2] = vandq_u8(q4bits.val[2], m4b);
        b1.val[3] = vandq_u8(q4bits.val[3], m4b);
        b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
        b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
        b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
        b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
    }
    // Layout for quants in blocks of 16, sequential order, two 32-byte loads.
    inline void prepare16(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x2(qs);
        prepare4_16(b1, q4bits.val);
        q4bits = vld1q_u8_x2(qs+32);
        prepare4_16(b2, q4bits.val);
    }
    // Same result as prepare16() but with a single 64-byte load.
    inline void prepare16_v2(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        prepare4_16(b1, q4bits.val+0);
        prepare4_16(b2, q4bits.val+2);
    }
};
|
||||
|
||||
// NEON unpacking of 2-bit quants into two uint8x16x4_t holding byte values in [0, 3].
struct Q2bits {
    // NOTE: named m4b to mirror Q4bits, but this is a 2-bit mask (0x03).
    const uint8x16_t m4b = vdupq_n_u8(0x03);
    uint8x16x4_t b1, b2;
    // Peels four 2-bit planes from a 32-byte load: bits 0-1 -> b1.val[0..1],
    // bits 2-3 -> b1.val[2..3], bits 4-5 -> b2.val[0..1], bits 6-7 -> b2.val[2..3].
    inline void prepare(const uint8_t * qs) {
        auto q2bits = vld1q_u8_x2(qs);
        b1.val[0] = vandq_u8(q2bits.val[0], m4b);
        b1.val[1] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b1.val[2] = vandq_u8(q2bits.val[0], m4b);
        b1.val[3] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b2.val[0] = vandq_u8(q2bits.val[0], m4b);
        b2.val[1] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b2.val[2] = vandq_u8(q2bits.val[0], m4b);
        b2.val[3] = vandq_u8(q2bits.val[1], m4b);
    }
};
|
||||
|
||||
// Dot product of 8 sub-blocks (NEON) for quant types with 8 scales per super-block.
// qx_1/qx_2 hold the unpacked x-quants (4 blocks each), scales.val[j] holds 4
// widened scales; the result is accumulated into sumi via vmlaq_s32.
template <typename Q8>
static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
        const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
    auto mzero = vdupq_n_s32(0);
    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
            vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
            vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
    // Pairwise add collapses the per-block lane sums so each lane is one block.
    auto p12 = vpaddq_s32(p1, p2);

    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
            vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1
    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
            vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2
    auto p34 = vpaddq_s32(p3, p4);

    auto pall = vpaddq_s32(p12, p34); // one 32-bit sum per block, 4 blocks
    sumi = vmlaq_s32(sumi, scales.val[j], pall);
}
|
||||
|
||||
// Dot product of 16 sub-blocks (NEON) for quant types with 16 scales per super-block.
// Unlike compute_8_blocks, each ggml_vdotq_s32 result is kept at half-block
// granularity, so scales.val[2*j+0/1] carry 8 scales per call.
template <typename Q8>
static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
        const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {

    auto mzero = vdupq_n_s32(0);
    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
    auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
            ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
            ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,
    auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
    sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);

    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
            ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
            ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
    auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
    sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
}
|
||||
|
||||
// Applies packed sign bits to quant vectors (NEON). Each call to apply_signs_1
// consumes the next two sign bytes (16 sign bits) from signs16; the shuffle
// advances by `step` so successive calls walk through the sign bytes in order.
struct SignHelper {

    inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }

    inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
        // Broadcast the current pair of sign bytes across the vector ...
        auto aux = vqtbl1q_u8(signs16, shuffle);
        // ... turn each selected bit into -1 (bit set) or +1 (bit clear) ...
        auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
        // ... and multiply the quants by the resulting +/-1 vector.
        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
        shuffle = vaddq_u8(shuffle, step);
    }

    const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); // one bit per lane
    const uint8x16_t m1 = vdupq_n_u8(1);
    const uint8x16_t step = vdupq_n_u8(2); // advance two sign bytes per call
    uint8x16_t shuffle;
};
|
||||
|
||||
// Generic NEON GEMM kernel: multiplies nrc_x rows of a k-quant matrix (vx, row
// stride bx) with nrc_y rows of q8_K activations and stores the results via info.
// The Dequantizer supplies unpacking (prepare), scale handling (new_block /
// process_scales) and reports its scale granularity via num_blocks().
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n % QK_K == 0);
    const int nb = n / QK_K; // super-blocks per row

    Q8<nrc_y, block_q8_K> q8(info);

    Dequantizer deq(vx, bx, nrc_y);

    for (int ix = 0; ix < nrc_x; ++ix) {

        deq.new_row(ix);

        float32x4_t acc[nrc_y];
        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);

        for (int i = 0; i < nb; ++i) {

            int32x4_t sumi[nrc_y];
            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);

            if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
                // Dequantizer folds the scales into the quants itself; each half of
                // the super-block is prepared and accumulated in turn.
                deq.process_scales(i, q8, acc);
                deq.prepare(i, 0);
                deq.compute(q8, i, 0, sumi);
                deq.prepare(i, 1);
                deq.compute(q8, i, 1, sumi);
            } else {
                if constexpr (Dequantizer::num_blocks() == 8) {
                    auto scales = deq.new_block(i, q8, acc);
                    deq.prepare(i, 0);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
                    deq.prepare(i, 1);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
                }
                else if constexpr (Dequantizer::num_blocks() == 16) {
                    auto scales = deq.new_block(i, q8, acc);
                    deq.prepare(i, 0);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
                    deq.prepare(i, 1);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
                }
                else {
                    GGML_ASSERT(false); // unsupported scale granularity
                }
            }

            // Fold this super-block's integer sums into the float accumulators,
            // scaled by the product of the x-row and q8 super-block scales.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
            }
        }

        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(acc[iy]));
        }
    }
}
|
||||
|
||||
// Dot product of 8 interleaved x-quant vectors (4 rows each) with 32 y-quants.
// Lane k of the result is the running sum for row k; vdotq_laneq_s32 requires
// immediate lane indices, hence the fully unrolled form.
static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
    auto sumi = vdupq_n_s32(0);
    sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
    sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
    sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
    sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
    sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
    sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
    sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
    sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
    return sumi;
}
|
||||
|
||||
// Variant of interleaved_dotq that keeps two separate accumulators (one per
// 16-quant group) instead of merging them; callers scale the halves separately.
static IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
    int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
    return sumi;
}
|
||||
|
||||
// Overload for a single 16-quant y vector against 4 interleaved x vectors.
static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
    auto sumi = vdupq_n_s32(0);
    sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
    sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
    sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
    sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
    return sumi;
}
|
||||
|
||||
// Maps packed iq4_nl nibbles through the non-linear codebook `values` (via table
// lookup) into 8 vectors of int8 quants. The trailing comments give the quant
// positions within the 4 interleaved rows.
static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
    qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   //  0...3 from the 4 rows
    qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
    qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   //  4...7
    qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
    qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  //  8..11
    qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
    qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
    qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
}
|
||||
|
||||
// Same codebook lookup as prepare_iq4_nl_quants, for the 8-row repacked layout:
// two packed input vectors yield four quant vectors (low/high nibbles of each).
static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
    qx[0] = vqtbl1q_s8(values, vandq_u8(  bits.val[0], m4));
    qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
    qx[2] = vqtbl1q_s8(values, vandq_u8(  bits.val[1], m4));
    qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
2282
ggml/src/iqk/iqk_gemm_1bit.cpp
Normal file
2282
ggml/src/iqk/iqk_gemm_1bit.cpp
Normal file
File diff suppressed because it is too large
Load Diff
11
ggml/src/iqk/iqk_gemm_1bit.h
Normal file
11
ggml/src/iqk/iqk_gemm_1bit.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
#endif
|
||||
1048
ggml/src/iqk/iqk_gemm_floats.cpp
Normal file
1048
ggml/src/iqk/iqk_gemm_floats.cpp
Normal file
File diff suppressed because it is too large
Load Diff
13
ggml/src/iqk/iqk_gemm_floats.h
Normal file
13
ggml/src/iqk/iqk_gemm_floats.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
|
||||
|
||||
void iqk_gemm_default_floats(int D, int nq, const char * vx, size_t bx, DataInfo& info, int k_step);
|
||||
|
||||
#endif
|
||||
3289
ggml/src/iqk/iqk_gemm_iqk_quants.cpp
Normal file
3289
ggml/src/iqk/iqk_gemm_iqk_quants.cpp
Normal file
File diff suppressed because it is too large
Load Diff
11
ggml/src/iqk/iqk_gemm_iqk_quants.h
Normal file
11
ggml/src/iqk/iqk_gemm_iqk_quants.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
#endif
|
||||
2252
ggml/src/iqk/iqk_gemm_iquants.cpp
Normal file
2252
ggml/src/iqk/iqk_gemm_iquants.cpp
Normal file
File diff suppressed because it is too large
Load Diff
11
ggml/src/iqk/iqk_gemm_iquants.h
Normal file
11
ggml/src/iqk/iqk_gemm_iquants.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
#endif
|
||||
3121
ggml/src/iqk/iqk_gemm_kquants.cpp
Normal file
3121
ggml/src/iqk/iqk_gemm_kquants.cpp
Normal file
File diff suppressed because it is too large
Load Diff
13
ggml/src/iqk/iqk_gemm_kquants.h
Normal file
13
ggml/src/iqk/iqk_gemm_kquants.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);
|
||||
|
||||
#endif
|
||||
2763
ggml/src/iqk/iqk_gemm_legacy_quants.cpp
Normal file
2763
ggml/src/iqk/iqk_gemm_legacy_quants.cpp
Normal file
File diff suppressed because it is too large
Load Diff
14
ggml/src/iqk/iqk_gemm_legacy_quants.h
Normal file
14
ggml/src/iqk/iqk_gemm_legacy_quants.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
#include <utility>
|
||||
|
||||
bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);
|
||||
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
207
ggml/src/iqk/iqk_utils.h
Normal file
207
ggml/src/iqk/iqk_utils.h
Normal file
@@ -0,0 +1,207 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT
|
||||
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
// copy-pasted from Justine Tunney's contribution to llama.cpp
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
static inline float32x4_t v_expf(float32x4_t x) {
    // Range reduction: n = round(x / ln2) extracted via the 0x1.8p23 rounding trick.
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
    const float32x4_t n = vsubq_f32(z, r);
    // b = x - n*ln2, computed in two pieces for accuracy.
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
                                    vdupq_n_f32(0x1.7f7d1cp-20f));
    // 2^n assembled directly in the exponent bits.
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    // c marks lanes with |n| > 126 that need the slow overflow/underflow path.
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    const float32x4_t u = vmulq_f32(b, b);
    // Polynomial approximation of expm1(b).
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    // Fast path: no lane is near the exponent limits.
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    // Slow path: split the 2^n scaling into two factors to avoid over/underflow.
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
|
||||
// fp16 overload: widen each half to f32, exponentiate, and narrow back.
static inline float16x8_t v_expf(float16x8_t x) {
    auto lo = vcvt_f32_f16(vget_low_f16(x));
    auto hi = vcvt_f32_f16(vget_high_f16(x));
    return vcombine_f16(vcvt_f16_f32(v_expf(lo)), vcvt_f16_f32(v_expf(hi)));
}
|
||||
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1); lanes with x > 10 are clamped to 1
// to avoid overflow in exp(2x).
static inline float32x4_t v_tanh(float32x4_t x) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
    const float32x4_t exp_two_x = v_expf(two_x);
    const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
    const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
    // Select 1 where the mask is set, the computed ratio elsewhere.
    return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
    //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
|
||||
//inline float32x4_t v_tanh(float16x8_t x) {
|
||||
// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
|
||||
// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
|
||||
// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
|
||||
//}
|
||||
// silu(x) = x / (1 + exp(-x))
static inline float32x4_t v_silu(float32x4_t x) {
    const float32x4_t one       = vdupq_n_f32(1.0f);
    const float32x4_t exp_neg_x = v_expf(vsubq_f32(vdupq_n_f32(0.0f), x));
    return vdivq_f32(x, vaddq_f32(one, exp_neg_x));
}
|
||||
// tanh-approximation GELU: gelu(x) = x * sigmoid(arg) with arg = c2*x*(1 + c1*x^2).
// c1, c2 are the caller-supplied approximation constants. Lanes with x > 10
// saturate to x (sigmoid(arg) is ~1 there).
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
    arg = vmulq_f32(arg, vmulq_f32(x, c2));
    float32x4_t exp_arg = v_expf(arg);
    float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
    uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
    return vbslq_f32(mask, x, gelu);
}
|
||||
|
||||
#endif // __ARM_NEON
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
|
||||
// copy-pasted from Justine Tunney's contribution to llama.cpp
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
static inline __m512 v_expf(__m512 x) {
    // Range reduction: n = round(x / ln2) via the 0x1.8p23 rounding trick.
    const __m512 r = _mm512_set1_ps(0x1.8p23f);
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    const __m512 n = _mm512_sub_ps(z, r);
    // b = x - n*ln2, in two pieces for accuracy.
    const __m512 b =
        _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                         _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
    // d marks lanes with |n| > 192 that need the over/underflow fixup.
    const __mmask16 d =
        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
    const __m512 u = _mm512_mul_ps(b, b);
    // Polynomial approximation of exp(b).
    const __m512 j = _mm512_fmadd_ps(
        _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                        _mm512_set1_ps(0x1.573e2ep-5f)),
                        u,
                        _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                        _mm512_set1_ps(0x1.fffdb6p-2f))),
        u,
        _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
    // scalef does the 2^n scaling directly (no exponent-bit assembly needed on AVX512).
    const __m512 res = _mm512_scalef_ps(j, n);
    if (_mm512_kortestz(d, d))
        return res;
    // Fixup for extreme lanes: force +inf (overflow) or 0 (underflow).
    const __m512 zero = _mm512_setzero_ps();
    const __m512 alt = _mm512_mask_blend_ps(
        _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
    return _mm512_mask_blend_ps(d, res, alt);
}
|
||||
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1); lanes with x > 10 are clamped to 1
// to avoid overflow in exp(2x).
static inline __m512 v_tanh(__m512 x) {
    const __m512 one = _mm512_set1_ps(1.0f);
    const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
    const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
    const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
    return _mm512_mask_blend_ps(mask, res, one);
}
|
||||
// tanh-approximation GELU: gelu(x) = x * sigmoid(arg) with arg = c2*x*(1 + c1*x^2).
// Lanes where arg > 30 use ratio = 1 (sigmoid saturated), giving gelu(x) = x.
static inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) {
    const __m512 one = _mm512_set1_ps(1.0f);
    __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
    //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
    arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
    const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
    const __m512 exp_arg = v_expf(arg);
    const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
    return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
}
|
||||
static inline __m512 v_silu(__m512 x) {
|
||||
const __m512 one = _mm512_set1_ps(1);
|
||||
const __m512 zero = _mm512_setzero_ps();
|
||||
const __m512 neg_x = _mm512_sub_ps(zero, x);
|
||||
const __m512 exp_neg_x = v_expf(neg_x);
|
||||
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
|
||||
return _mm512_div_ps(x, one_plus_exp_neg_x);
|
||||
}
|
||||
#endif // __AVX512F__ && __AVX512DQ__
|
||||
|
||||
#if defined(__AVX2__) && defined(__FMA__)
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
static inline __m256 v_expf(__m256 x) {
    // Range reduction: n = round(x / ln2) via the 0x1.8p23 rounding trick.
    const __m256 r = _mm256_set1_ps(0x1.8p23f);
    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
    const __m256 n = _mm256_sub_ps(z, r);
    // b = x - n*ln2, in two pieces for accuracy.
    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
    // 2^n assembled directly in the exponent bits.
    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
    const __m256 k = _mm256_castsi256_ps(
        _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
    // c marks lanes with |n| > 126 that need the slow overflow/underflow path.
    const __m256i c = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(126), _CMP_GT_OQ));
    const __m256 u = _mm256_mul_ps(b, b);
    // Polynomial approximation of expm1(b).
    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                     _mm256_set1_ps(0x1.573e2ep-5f)), u,
                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
                                     u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
    // Fast path: no lane near the exponent limits.
    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
        return _mm256_fmadd_ps(j, k, k);
    // Slow path: split the 2^n scaling into two factors to avoid over/underflow.
    const __m256i g = _mm256_and_si256(
        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
        _mm256_set1_epi32(0x82000000u));
    const __m256 s1 =
        _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
    const __m256i d = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(192), _CMP_GT_OQ));
    return _mm256_or_ps(
        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
        _mm256_andnot_ps(
            _mm256_castsi256_ps(d),
            _mm256_or_ps(
                _mm256_and_ps(_mm256_castsi256_ps(c),
                              _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
|
||||
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1); lanes with x > 10 are clamped to 1
// to avoid overflow in exp(2x).
static inline __m256 v_tanh(__m256 x) {
    const __m256 one = _mm256_set1_ps(1.0f);
    const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
    const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
    const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
    return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
}
|
||||
// tanh-approximation GELU: gelu(x) = x * sigmoid(arg) with arg = c2*x*(1 + c1*x^2).
// Lanes with x > 10 saturate to x. (NOTE(review): the AVX512 version masks on
// arg > 30 instead of x > 10 — both saturate the sigmoid, but the thresholds differ.)
static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
    const __m256 one = _mm256_set1_ps(1.0f);
    const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
    __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
    arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
    __m256 exp_arg = v_expf(arg);
    __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
    return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
}
|
||||
static inline __m256 v_silu(__m256 x) {
|
||||
const __m256 one = _mm256_set1_ps(1);
|
||||
const __m256 zero = _mm256_setzero_ps();
|
||||
const __m256 neg_x = _mm256_sub_ps(zero, x);
|
||||
const __m256 exp_neg_x = v_expf(neg_x);
|
||||
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
|
||||
return _mm256_div_ps(x, one_plus_exp_neg_x);
|
||||
}
|
||||
|
||||
#endif // __AVX2__
|
||||
|
||||
#endif // IQK_IMPLEMENT
|
||||
Reference in New Issue
Block a user