mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
Slightly faster FA for bf16 KV cache
~2-3% sort of thing. Sadly, when we go beyond 8k tokens, the advantage kind of goes away.
This commit is contained in:
@@ -13864,6 +13864,11 @@ struct FlashQKbf16 {
|
||||
}
|
||||
}
|
||||
|
||||
static inline __m128 hsum_float_4x4(__m128 * a) {
|
||||
for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2]));
|
||||
return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1]));
|
||||
}
|
||||
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q,
|
||||
const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
@@ -13893,6 +13898,34 @@ struct FlashQKbf16 {
|
||||
}
|
||||
}
|
||||
|
||||
static inline void mult_mask_kq_4(int l1, int m1, const ggml_bf16_t * q,
|
||||
__m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
|
||||
auto qr = q + m1*D;
|
||||
for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
|
||||
__m128 sum[4];
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
|
||||
auto aux = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
|
||||
sum[k] = _mm_add_ps(_mm256_castps256_ps128(aux), _mm256_extractf128_ps(aux, 1));
|
||||
}
|
||||
//auto sum4 = _mm_mask_blend_ps(m8, hsum_float_4x4(sum), _mm_set1_ps(-INFINITY));
|
||||
//_mm_storeu_ps(fms.cache + k_step*m1 + l1, sum4);
|
||||
_mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum));
|
||||
}
|
||||
|
||||
static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q,
|
||||
__m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
|
||||
auto qr = q + m1*D;
|
||||
for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i));
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
|
||||
fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
|
||||
vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
|
||||
fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
|
||||
const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
@@ -13902,23 +13935,44 @@ struct FlashQKbf16 {
|
||||
__m512bh vkh[D/8];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 4) {
|
||||
kh.load_4(l1, vkh);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
mult_mask_kq_4(l1, j, stride_m, q, mask, qv, vkh, fms);
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
|
||||
}
|
||||
} else {
|
||||
__m512bh vkh[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
kh.load_2(l1, vkh);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
mult_mask_kq_one(l1, j, stride_m, q, mask, qv, vkh, fms);
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
|
||||
}
|
||||
}
|
||||
}
|
||||
__m512 vk[k_step/16];
|
||||
F16::Data vk[k_step/16];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_m, const ggml_bf16_t * q,
|
||||
const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
{
|
||||
__m512bh qv[D/32];
|
||||
if constexpr (D <= 128) {
|
||||
__m512bh vkh[D/8];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 4) {
|
||||
kh.load_4(l1, vkh);
|
||||
for (int j = 0; j < nq; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
|
||||
}
|
||||
} else {
|
||||
__m512bh vkh[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
kh.load_2(l1, vkh);
|
||||
for (int j = 0; j < nq; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
|
||||
}
|
||||
}
|
||||
}
|
||||
F16::Data vk[k_step/16];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13953,6 +14007,19 @@ struct FlashQKbf16 {
|
||||
bf16 += D;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void convert(int nq, int stride_q, const float * q, ggml_bf16_t * bf16) {
|
||||
auto qr = q;
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
auto val1 = _mm512_loadu_ps(qr + 32*i);
|
||||
auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
|
||||
_mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1));
|
||||
}
|
||||
qr += stride_q;
|
||||
bf16 += D;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
@@ -13991,9 +14058,10 @@ struct FlashAttnBF16 {
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
FlashQKbf16<D, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
|
||||
FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
|
||||
fqkv.accumulate_qkv(n_left, vh, fms);
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
@@ -14158,7 +14226,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 256:
|
||||
case 256
|
||||
iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
default:
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user