mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
FA: much better bf16 kv-cache speed for large contexts
We now hit 122 t/s for LLaMA-3.1-8B (quantized as iq4_xs and run-time-repacked) with a context of 32768. If I recall correctly, the previous best for such a large context was ~90 t/s. There are non-negligible improvements at contexts of 16384 and 8192 as well: 173.4 and 214 t/s, respectively.
This commit is contained in:
@@ -13187,13 +13187,15 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
}
|
||||
F16::Data v1, v2;
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
vh.load(l1, i, v1, v2);
|
||||
F16::Data v1, v2, v3, v4;
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
vh.load(l1+0, i, v1, v2);
|
||||
vh.load(l1+1, i, v3, v4);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = F16::set1(fms.cache[k_step*j + l1]);
|
||||
vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs);
|
||||
vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs);
|
||||
auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
|
||||
auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
|
||||
vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
|
||||
vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
@@ -13221,13 +13223,15 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
}
|
||||
F16::Data v1, v2;
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
vh.load(l1, i, v1, v2);
|
||||
F16::Data v1, v2, v3, v4;
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
vh.load(l1+0, i, v1, v2);
|
||||
vh.load(l1+1, i, v3, v4);
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto vs = F16::set1(fms.cache[k_step*j + l1]);
|
||||
vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs);
|
||||
vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs);
|
||||
auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
|
||||
auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
|
||||
vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
|
||||
vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
@@ -14149,14 +14153,25 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <int D, int q_step, int k_step>
|
||||
template <int D, int k_step>
|
||||
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv) {
|
||||
HelperBF16<D, k_step> kh(k, stride_k);
|
||||
HelperBF16<D, k_step> vh(v, stride_v);
|
||||
if (nq1 >= q_step) {
|
||||
FlashAttnBF16<D, q_step, k_step> fa(scale, softcap);
|
||||
if (nk1 >= 4096) {
|
||||
if (nq1 >= 64) {
|
||||
FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
}
|
||||
else if (nq1 >= 16) {
|
||||
FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
FlashAttnBF16<D, 8, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
} else {
|
||||
FlashAttnBF16<D, 1, k_step> fa(scale, softcap);
|
||||
@@ -14176,10 +14191,12 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
HelperF16<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
case GGML_TYPE_BF16: {
|
||||
HelperBF16<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#endif
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
@@ -14215,12 +14232,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
HelperF16<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#endif
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
@@ -14286,13 +14301,13 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
switch (D) {
|
||||
case 64:
|
||||
iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 96:
|
||||
iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 256:
|
||||
iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user