mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-01 01:24:08 +00:00
FA: don't use large Q steps on AVX2 for fp16 K-cache
This commit is contained in:
@@ -13793,16 +13793,23 @@ template <int D, int k_step, typename KHelper, typename VHelper>
 inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
                         const float * q, const char * mask, float scale, float softcap, float * qkv) {
 
-    if (nk1 >= 4096) {
-        if (nq1 >= 32) {
-            FlashAttn<D, 32, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-            return;
-        }
-        else if (nq1 >= 8) {
-            FlashAttn<D, 8, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-            return;
+#if defined __AVX2__ && !HAVE_FANCY_SIMD
+    constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>;
+#else
+    constexpr bool kUseLargeStepsQ = true;
+#endif
+    if constexpr (kUseLargeStepsQ) {
+        if (nk1 >= 4096) {
+            if (nq1 >= 32) {
+                FlashAttn<D, 32, k_step> fa(scale, softcap);
+                fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+                return;
+            }
+            else if (nq1 >= 8) {
+                FlashAttn<D, 8, k_step> fa(scale, softcap);
+                fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+                return;
+            }
         }
     }
     if (nq1 >= 8) {
Reference in New Issue
Block a user