mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Zen4 Flash Attention: improving bf16
This commit is contained in:
@@ -6756,6 +6756,13 @@ struct HelperBF16 final : public BaseHelper<step> {
|
||||
load(l1+0, vk+0);
|
||||
load(l1+1, vk+D/32);
|
||||
}
|
||||
|
||||
// Calls load() four times, once for each index l1, l1+1, l1+2 and l1+3.
// The call for offset r writes to vk + r*D/32; presumably each load() fills
// D/32 vectors so the four results sit back to back in vk -- TODO confirm
// against load(), which is defined outside this view.
inline void load_4(int l1, __m512bh * vk) const {
    for (int r = 0; r < 4; ++r) {
        load(l1 + r, vk + r*D/32);
    }
}
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
@@ -6791,16 +6798,50 @@ struct FlashQKbf16 {
|
||||
}
|
||||
}
|
||||
|
||||
static inline void mult_mask_kq_4(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
|
||||
__m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
|
||||
fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] =
|
||||
fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY;
|
||||
if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) {
|
||||
return;
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
auto val1 = _mm512_loadu_ps(qr + 32*i);
|
||||
auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
|
||||
qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
|
||||
}
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
if (mp[l1+k] == fms.h_inf) continue;
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
|
||||
fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q,
|
||||
const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
{
|
||||
__m512bh qv[D/32];
|
||||
__m512bh vkh[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
kh.load_2(l1, vkh);
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vkh, fms);
|
||||
if constexpr (D <= 128) {
|
||||
__m512bh vkh[D/8];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 4) {
|
||||
kh.load_4(l1, vkh);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
mult_mask_kq_4(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
__m512bh vkh[D/16];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
kh.load_2(l1, vkh);
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
mult_mask_kq_one(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user