mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 18:40:14 +00:00
Refactor iqk: FA refactored (Zen4)
Compile time for the FA files is now ~21 seconds on my Ryzen-7950X, so still slightly too long for my taste but much better than the 142 seconds we had before.
This commit is contained in:
@@ -260,6 +260,12 @@ if (GGML_IQK_MUL_MAT)
|
||||
add_compile_definitions(GGML_USE_IQK_MULMAT)
|
||||
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
|
||||
iqk/iqk_flash_attn.cpp
|
||||
iqk/fa/iqk_fa_576_512.cpp
|
||||
iqk/fa/iqk_fa_192_128.cpp
|
||||
iqk/fa/iqk_fa_256_256.cpp
|
||||
iqk/fa/iqk_fa_128_128.cpp
|
||||
iqk/fa/iqk_fa_96_96.cpp
|
||||
iqk/fa/iqk_fa_64_64.cpp
|
||||
iqk/iqk_gemm_floats.cpp
|
||||
iqk/iqk_gemm_kquants.cpp
|
||||
iqk/iqk_gemm_iquants.cpp
|
||||
@@ -268,6 +274,7 @@ if (GGML_IQK_MUL_MAT)
|
||||
iqk/iqk_gemm_legacy_quants.cpp)
|
||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
|
||||
iqk/iqk_flash_impl.h
|
||||
iqk/fa/iqk_fa_templates.h
|
||||
iqk/iqk_gemm_floats.h
|
||||
iqk/iqk_gemm_kquants.h
|
||||
iqk/iqk_gemm_iquants.h
|
||||
|
||||
45
ggml/src/iqk/fa/iqk_fa_128_128.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_128_128.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_128_128) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_192_128.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_192_128.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_192_128) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_256_256.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_256_256.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_256_256) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
118
ggml/src/iqk/fa/iqk_fa_576_512.cpp
Normal file
118
ggml/src/iqk/fa/iqk_fa_576_512.cpp
Normal file
@@ -0,0 +1,118 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <int step_k, typename KHelper, typename VHelper>
|
||||
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
||||
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||
nq1 -= n;
|
||||
if (nq1 == 0) return true;
|
||||
q += n*stride_q;
|
||||
mask += n*stride_m;
|
||||
qkv += n*stride_qkv;
|
||||
if (M && S) { M += n; S += n; }
|
||||
return false;
|
||||
};
|
||||
if (nq1 >= 16) {
|
||||
int n_step = nq1/16;
|
||||
FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(16*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
int n_step = nq1/8;
|
||||
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(8*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 4) {
|
||||
int n_step = nq1/4;
|
||||
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(4*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 2) {
|
||||
int n_step = nq1/2;
|
||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(2*n_step)) return;
|
||||
}
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
}
|
||||
|
||||
template <int step_k>
|
||||
inline bool iqk_deepseek_helper(ggml_type type_k,
|
||||
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
if (type_k == GGML_TYPE_Q8_0) {
|
||||
HelperQ80<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ80<512, step_k> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q8_0_R8) {
|
||||
HelperQ80R8<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ80<512, step_k> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q6_0) {
|
||||
HelperQ60<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ60<512, step_k> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q8_KV) {
|
||||
HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_F16) {
|
||||
HelperF16<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperF16<512, step_k> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
HelperBF16<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperBF16<512, step_k> vh((const char *)v, stride_v);
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
} else {
|
||||
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
IQK_FA_CASE(iqk_fa_576_512) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) {
|
||||
return false;
|
||||
}
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_64_64.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_64_64.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_64_64) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
45
ggml/src/iqk/fa/iqk_fa_96_96.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_96_96.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_96_96) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
2287
ggml/src/iqk/fa/iqk_fa_templates.h
Normal file
2287
ggml/src/iqk/fa/iqk_fa_templates.h
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user