Zen4 Flash Attention: WIP bf16

This commit is contained in:
Iwan Kawrakow
2024-09-04 13:09:24 +03:00
parent f17d0d72f5
commit 8218e77dec
3 changed files with 385 additions and 1 deletion

View File

@@ -2221,6 +2221,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "bf16") {
return GGML_TYPE_BF16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}

View File

@@ -306,6 +306,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "bf16") {
return GGML_TYPE_BF16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}

View File

@@ -6192,7 +6192,6 @@ struct HelperF16 final : public BaseHelper<step> {
load(l1+0, vk+0);
load(l1+1, vk+D/16);
}
};
template <int D, int step>
@@ -6697,6 +6696,345 @@ struct FlashAttn {
}
};
#ifdef __AVX512BF16__
template <int D, int step>
struct HelperBF16 final : public BaseHelper<step> {
    // Row loader for a bf16 K or V cache block; D is the head size (elements
    // per row), step the number of rows handled per block by BaseHelper.
    using Base = BaseHelper<step>;
    HelperBF16(const char * data, int stride) : Base(data, stride) {}
    // Load row l1 as D/32 registers of 32 packed bf16 values each.
    inline void load(int l1, __m512bh * vk) const {
        auto dr = Base::lblock(l1);
        for (int i = 0; i < D/32; ++i) vk[i] = __m512bh(_mm512_loadu_si512((const __m512i*)dr + i));
    }
    // Load 32 consecutive bf16 values of row l1 (starting at 16-element chunk i)
    // widened to two f32 vectors: bf16 -> f32 is a zero-extend to 32 bits
    // followed by a 16-bit left shift into the mantissa/exponent position.
    inline void load(int l1, int i, __m512& v1, __m512& v2) const {
        auto dr = Base::lblock(l1);
        v1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 0)), 16));
        v2 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 1)), 16));
    }
    // Load two consecutive rows: row l1 into vk[0..D/32), row l1+1 into vk[D/32..D/16).
    inline void load_2(int l1, __m512bh * vk) const {
        load(l1+0, vk+0);
        load(l1+1, vk+D/32);
    }
};
template <int D, int q_step, int k_step>
struct FlashAttnBF16 {
    // Tiled flash-attention kernel for bf16 K/V caches using AVX-512 BF16.
    // D = head size, q_step = Q rows processed per tile, k_step = K/V rows per block.
    // Uses the usual online-softmax recurrence: per Q row it keeps a running
    // max M[j] and running denominator S[j], rescaling the accumulated output
    // whenever a new block raises the max.
    static_assert(D%32 == 0 && D <= 256);
    static_assert(k_step%32 == 0);
    static_assert(q_step <= 4 || q_step%4 == 0);

    // scale: factor applied to raw K·Q products; softcap <= 0 disables the
    // tanh soft-capping path; h_inf caches the fp16 bit pattern of -inf used
    // to recognize fully masked positions in the mask rows.
    FlashAttnBF16(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}

    // Reset the per-row softmax state before processing a new Q tile.
    inline void init_qstep() {
        for (int j = 0; j < q_step; ++j) {
            S[j] = 0; M[j] = -INFINITY;
        }
    }

    // Compute K·Q for Q row m1 against K rows l1 and l1+1 (already loaded into
    // vkh as bf16), honoring the attention mask: masked positions keep -inf in
    // cache. qv is scratch for the bf16-converted Q row.
    // NOTE(review): the Q row is re-converted to bf16 on every call even though
    // it depends only on m1, i.e. once per (l1, m1) pair instead of once per m1
    // — presumably a WIP inefficiency; confirm before relying on it.
    static inline void mult_mask_kq_one(ggml_half h_inf, int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
            __m512bh * qv, const __m512bh * vkh, float * cache) {
        // q index is q_step*i1 + m1
        // k index is k_step*k1 + l1
        const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
        // Default both slots to -inf (masked); overwritten below if unmasked.
        cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
        // Bit-pattern compare of the fp16 mask entries against -inf.
        if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
            return;
        }
        auto qr = q + m1*stride_q;
        // Convert the f32 Q row to bf16 pairs (cvtne2 packs val2:val1 high:low).
        for (int i = 0; i < D/32; ++i) {
            auto val1 = _mm512_loadu_ps(qr + 32*i);
            auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
            qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
        }
        if (mp[l1+0] != h_inf) {
            auto vsum = _mm512_setzero_ps();
            // bf16 dot-product accumulate, then horizontal add -> scalar K·Q.
            for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
            cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
        }
        if (mp[l1+1] != h_inf) {
            auto vsum = _mm512_setzero_ps();
            for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
            cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
        }
    }

    // Online-softmax update for Q row j over the current k-block: scale (and
    // optionally tanh-softcap) the raw K·Q row, update M[j]/S[j], and replace
    // the row in cache with exp(score - M[j]) weights. need_scaling[j]:
    //   0 = accumulator unchanged, 1 = multiply accumulator by vms[j],
    //   2 = zero the accumulator (first non-masked block for this row).
    inline void update_M_S(int j, __m512 * vk) {
        if (softcap <= 0.0f) {
            for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
        } else {
            auto v_softcap = _mm512_set1_ps(softcap);
            for (int l = 0; l < k_step/16; ++l) {
                auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
                vk[l] = _mm512_mul_ps(v_softcap, v_tanh(_mm512_mul_ps(vscale, val)));
            }
        }
        float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
        if (smax == -INFINITY) {
            // Entire block masked: contribute zero weights; zero the
            // accumulator only if nothing has been accumulated yet.
            std::memset(cache + k_step*j, 0, k_step*sizeof(float));
            need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
            return;
        }
        need_scaling[j] = 0;
        if (smax > M[j]) {
            if (M[j] > -INFINITY) {
                // Running max increased: rescale prior sum and accumulator by exp(old-new).
                float m = expf(M[j] - smax);
                vms[j] = _mm512_set1_ps(m);
                need_scaling[j] = 1;
                S[j] *= m;
            } else {
                need_scaling[j] = 2;
                S[j] = 0;
            }
            M[j] = smax;
        }
        auto vm = _mm512_set1_ps(M[j]);
        for (int l = 0; l < k_step/16; ++l) {
            vk[l] = v_expf(_mm512_sub_ps(vk[l], vm));
            _mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]);
        }
        S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
    }

    // Write one output row: qkv = R / S[j].
    inline void normalize_and_store(int j, const float * R, float * qkv) const {
        GGML_ASSERT(S[j] > 0);
        auto norm = _mm512_set1_ps(1/S[j]);
        for (int i = 0; i < D/16; ++i) {
            auto r = _mm512_loadu_ps(R + 16*i);
            _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
        }
    }
    // Normalize and store the first nq1 rows of qkv_cache (partial tile).
    inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const {
        auto R = qkv_cache;
        for (int j = 0; j < nq1; ++j) {
            normalize_and_store(j, R, qkv);
            qkv += stride_qkv;
            R += D;
        }
    }
    // Normalize and store a full q_step tile.
    inline void normalize_and_store(int stride_qkv, float * qkv) const {
        auto R = qkv_cache;
        for (int j = 0; j < q_step; ++j) {
            normalize_and_store(j, R, qkv);
            qkv += stride_qkv;
            R += D;
        }
    }

    // NOTE(review): the next two overloads (taking `const __m512bh * vkh`) are
    // never chosen by overload resolution — multiply_mask_kq passes a
    // non-const pointer, which prefers the non-const overloads below — and if
    // instantiated they could not call kh.load_2 through a const pointer.
    // They look like leftover duplicates; TODO confirm and remove.
    template <typename KHelper>
    static inline void mult_mask_kq(ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q,
            const char * mask, const __m512bh * vkh, float * cache) {
        __m512bh qv[D/32];
        for (int l1 = 0; l1 < k_step; l1 += 2) {
            kh.load_2(l1, vkh);
            for (int m1 = 0; m1 < q_step; ++m1) {
                mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache);
            }
        }
    }
    template <typename KHelper>
    static inline void mult_mask_kq(int nq, ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q,
            const char * mask, const __m512bh * vkh, float * cache) {
        __m512bh qv[D/32];
        for (int l1 = 0; l1 < k_step; l1 += 2) {
            kh.load_2(l1, vkh);
            for (int m1 = 0; m1 < nq; ++m1) {
                mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache);
            }
        }
    }
    // K·Q for a full q_step tile: for each pair of K rows, load them once into
    // vkh and run all Q rows against them.
    template <typename KHelper>
    static inline void mult_mask_kq(ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q,
            const char * mask, __m512bh * vkh, float * cache) {
        __m512bh qv[D/32];
        for (int l1 = 0; l1 < k_step; l1 += 2) {
            kh.load_2(l1, vkh);
            for (int m1 = 0; m1 < q_step; ++m1) {
                mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache);
            }
        }
    }
    // Same as above for a partial tile of nq Q rows.
    template <typename KHelper>
    static inline void mult_mask_kq(int nq, ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q,
            const char * mask, __m512bh * vkh, float * cache) {
        __m512bh qv[D/32];
        for (int l1 = 0; l1 < k_step; l1 += 2) {
            kh.load_2(l1, vkh);
            for (int m1 = 0; m1 < nq; ++m1) {
                mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache);
            }
        }
    }

    // Full-tile K·Q pass followed by the online-softmax update of each row.
    // vkh (D/16 registers = two bf16 K rows) is scoped so it is released
    // before the k_step/16 f32 scratch registers are needed.
    template <typename KHelper>
    inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
        {
            __m512bh vkh[D/16];
            mult_mask_kq(h_inf, kh, stride_q, stride_m, q, mask, vkh, cache);
        }
        __m512 vk[k_step/16];
        for (int j = 0; j < q_step; ++j) {
            update_M_S(j, vk);
        }
    }
    // Partial-tile variant for the last nq < q_step rows.
    template <typename KHelper>
    inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
        {
            __m512bh vkh[D/16];
            mult_mask_kq(nq, h_inf, kh, stride_q, stride_m, q, mask, vkh, cache);
        }
        __m512 vk[k_step/16];
        for (int j = 0; j < nq; ++j) {
            update_M_S(j, vk);
        }
    }

    // Accumulate softmax-weights x V into qkv_cache for a full tile, applying
    // the per-row need_scaling correction (0/1/2) decided in update_M_S.
    // Works on two 16-wide column chunks of the output at a time.
    // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
    // Hence, for now, we will not handle head sizes of 80 and 112
    inline void accumulate_qkv(const HelperBF16<D, k_step>& vh) {
        __m512 vk[2*q_step];
        for (int i = 0; i < D/16; i += 2) {
            for (int j = 0; j < q_step; ++j) {
                if (need_scaling[j] == 2) {
                    vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
                } else {
                    auto R = qkv_cache + D*j;
                    vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
                    vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
                    if (need_scaling[j] == 1) {
                        vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
                        vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
                    }
                }
            }
            __m512 v1, v2;
            for (int l1 = 0; l1 < k_step; ++l1) {
                vh.load(l1, i, v1, v2);
                for (int j = 0; j < q_step; ++j) {
                    // cache holds the exp-weights written back by update_M_S.
                    auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
                    vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
                    vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
                }
            }
            for (int j = 0; j < q_step; ++j) {
                auto R = qkv_cache + D*j;
                _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
                _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
            }
        }
    }
    // Partial-tile variant (nq1 < q_step rows).
    // NOTE(review): `class = std::enable_if<Nq >= 2>` names the enable_if type
    // itself rather than its ::type member, so this default argument is
    // well-formed for any Nq and never actually disables the overload — TODO
    // confirm whether std::enable_if_t was intended.
    template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
    inline void accumulate_qkv(int nq1, const VHelper& vh) {
        GGML_ASSERT(nq1 < q_step);
        __m512 vk[2*q_step];
        for (int i = 0; i < D/16; i += 2) {
            for (int j = 0; j < nq1; ++j) {
                if (need_scaling[j] == 2) {
                    vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
                } else {
                    auto R = qkv_cache + D*j;
                    vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
                    vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
                    if (need_scaling[j] == 1) {
                        vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
                        vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
                    }
                }
            }
            __m512 v1, v2;
            for (int l1 = 0; l1 < k_step; ++l1) {
                vh.load(l1, i, v1, v2);
                for (int j = 0; j < nq1; ++j) {
                    auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
                    vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
                    vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
                }
            }
            for (int j = 0; j < nq1; ++j) {
                auto R = qkv_cache + D*j;
                _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
                _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
            }
        }
    }

    // Driver: iterate over nq1/q_step full Q tiles (plus one partial tile),
    // and for each tile sweep all nk1/k_step K/V blocks, then normalize.
    // Strides are in elements of the respective type (q/qkv float, mask fp16).
    template <typename KHelper, typename VHelper>
    void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
            const float * q, const char * mask, float * qkv) {
        for (int i1 = 0; i1 < nq1/q_step; ++i1) {
            init_qstep();
            kh.reset_block();
            vh.reset_block();
            auto mr = mask;
            for (int k1 = 0; k1 < nk1/k_step; ++k1) {
                multiply_mask_kq(kh, stride_q, stride_m, q, mr);
                accumulate_qkv(vh);
                kh.next_block();
                vh.next_block();
                mr += k_step*sizeof(ggml_half);
            }
            normalize_and_store(stride_qkv, qkv);
            q    += q_step*stride_q;
            mask += q_step*stride_m;
            qkv  += q_step*stride_qkv;
        }
        int n_left = nq1 - q_step*(nq1/q_step);
        if (n_left > 0) {
            init_qstep();
            kh.reset_block();
            vh.reset_block();
            auto mr = mask;
            for (int k1 = 0; k1 < nk1/k_step; ++k1) {
                multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr);
                accumulate_qkv(n_left, vh);
                kh.next_block();
                vh.next_block();
                mr += k_step*sizeof(ggml_half);
            }
            normalize_and_store(n_left, stride_qkv, qkv);
        }
    }

    float cache[q_step*k_step];   // K·Q scores, then softmax weights, for the current tile/block
    float qkv_cache[D*q_step];    // running (un-normalized) output accumulator
    float S[q_step], M[q_step];   // online-softmax denominators and running maxima per Q row
    int need_scaling[q_step];     // 0 = keep, 1 = rescale by vms[j], 2 = zero accumulator
    __m512 vms[q_step];           // broadcast rescale factors exp(M_old - M_new)
    const __m512 vscale;          // broadcast KQ scale
    const float softcap;          // tanh soft-cap; <= 0 means disabled
    const ggml_half h_inf;        // fp16 bit pattern of -inf for mask tests

    // Horizontal reduction over k_step/16 registers: Op_combine folds vectors
    // pairwise, Op reduces the final vector to a scalar (max or add).
    typedef __m512 (*combine_t)(__m512, __m512);
    typedef float (*reduce_t)(__m512);
    template <reduce_t Op, combine_t Op_combine>
    static inline float reduce_T(const __m512 * vals) {
        float result;
        if constexpr (k_step/16 == 1) {
            result = Op(vals[0]);
        }
        else if constexpr (k_step/16 == 2) {
            result = Op(Op_combine(vals[0], vals[1]));
        }
        else {
            auto vmax = Op_combine(vals[0], vals[1]);
            for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
            result = Op(vmax);
        }
        return result;
    }
};
#endif
template <int D, int q_step, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv) {
@@ -6710,6 +7048,23 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
}
}
#ifdef __AVX512BF16__
template <int D, int q_step, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv) {
HelperBF16<D, k_step> kh(k, stride_k);
HelperBF16<D, k_step> vh(v, stride_v);
if (nq1 >= q_step) {
FlashAttnBF16<D, q_step, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
FlashAttnBF16<D, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
}
#endif
template <int D, int q_step, int k_step, typename KHelper>
inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
@@ -6766,7 +7121,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
}
// K/V cache types the flash-attention path can consume; bf16 requires
// AVX-512 BF16 support in the build.
inline bool flash_attn_is_supported(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
            return true;
#ifdef __AVX512BF16__
        case GGML_TYPE_BF16:
            return true;
#endif
        default:
            return false;
    }
}
}
@@ -6799,6 +7158,25 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
stride_q /= sizeof(float); // q stride as float
#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) {
switch (D) {
case 64:
iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
return true;
}
#endif
switch (D) {
case 64:
iqk_flash_helper_T< 64, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;