From 8218e77deccadcac0d1835287237143bf25e8ecd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 4 Sep 2024 13:09:24 +0300 Subject: [PATCH] Zen4 Flash Attnetion: WIP bf16 --- common/common.cpp | 3 + examples/llama-bench/llama-bench.cpp | 3 + ggml/src/iqk/iqk_mul_mat.cpp | 380 ++++++++++++++++++++++++++- 3 files changed, 385 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index c86d364f..6c298d2d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2221,6 +2221,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "f16") { return GGML_TYPE_F16; } + if (s == "bf16") { + return GGML_TYPE_BF16; + } if (s == "q8_0") { return GGML_TYPE_Q8_0; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 813d7bae..fc77be50 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -306,6 +306,9 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "f16") { return GGML_TYPE_F16; } + if (s == "bf16") { + return GGML_TYPE_BF16; + } if (s == "q8_0") { return GGML_TYPE_Q8_0; } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 511eea01..6490fc13 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6192,7 +6192,6 @@ struct HelperF16 final : public BaseHelper { load(l1+0, vk+0); load(l1+1, vk+D/16); } - }; template @@ -6697,6 +6696,345 @@ struct FlashAttn { } }; +#ifdef __AVX512BF16__ + +template +struct HelperBF16 final : public BaseHelper { + using Base = BaseHelper; + HelperBF16(const char * data, int stride) : Base(data, stride) {} + inline void load(int l1, __m512bh * vk) const { + auto dr = Base::lblock(l1); + for (int i = 0; i < D/32; ++i) vk[i] = __m512bh(_mm512_loadu_si512((const __m512i*)dr + i)); + } + + inline void load(int l1, int i, __m512& v1, __m512& v2) const { + auto dr = Base::lblock(l1); + v1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 0)), 16)); + v2 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 1)), 16)); + } + + inline void load_2(int l1, __m512bh * vk) const { + load(l1+0, vk+0); + load(l1+1, vk+D/32); + } +}; + +template +struct FlashAttnBF16 { + static_assert(D%32 == 0 && D <= 256); + static_assert(k_step%32 == 0); + static_assert(q_step <= 4 || q_step%4 == 0); + + FlashAttnBF16(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {} + + inline void init_qstep() { + for (int j = 0; j < q_step; ++j) { + S[j] = 0; M[j] = -INFINITY; + } + } + + static inline void mult_mask_kq_one(ggml_half h_inf, int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, + __m512bh * qv, const __m512bh * vkh, float * cache) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY; + if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) { + return; + } + auto qr = q + m1*stride_q; + for (int i = 0; i < D/32; ++i) { + auto val1 = _mm512_loadu_ps(qr + 32*i); + auto val2 = _mm512_loadu_ps(qr + 32*i + 16); + qv[i] = _mm512_cvtne2ps_pbh(val2, val1); + } + if (mp[l1+0] != h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]); + cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + } + if (mp[l1+1] != h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]); + cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); + } + } + + inline void update_M_S(int j, __m512 * vk) { + if (softcap <= 0.0f) { + for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l)); + } else { + auto v_softcap = _mm512_set1_ps(softcap); + for (int l = 0; l < k_step/16; ++l) { + auto val = _mm512_loadu_ps(cache + k_step*j + 16*l); + vk[l] = _mm512_mul_ps(v_softcap, v_tanh(_mm512_mul_ps(vscale, val))); + } + } + + float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk); + if (smax == -INFINITY) { + std::memset(cache + k_step*j, 0, k_step*sizeof(float)); + need_scaling[j] = M[j] == -INFINITY ? 2 : 0; + return; + } + need_scaling[j] = 0; + if (smax > M[j]) { + if (M[j] > -INFINITY) { + float m = expf(M[j] - smax); + vms[j] = _mm512_set1_ps(m); + need_scaling[j] = 1; + S[j] *= m; + } else { + need_scaling[j] = 2; + S[j] = 0; + } + M[j] = smax; + } + auto vm = _mm512_set1_ps(M[j]); + for (int l = 0; l < k_step/16; ++l) { + vk[l] = v_expf(_mm512_sub_ps(vk[l], vm)); + _mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]); + } + S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk); + } + + inline void normalize_and_store(int j, const float * R, float * qkv) const { + GGML_ASSERT(S[j] > 0); + auto norm = _mm512_set1_ps(1/S[j]); + for (int i = 0; i < D/16; ++i) { + auto r = _mm512_loadu_ps(R + 16*i); + _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r)); + } + } + + inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const { + auto R = qkv_cache; + for (int j = 0; j < nq1; ++j) { + normalize_and_store(j, R, qkv); + qkv += stride_qkv; + R += D; + } + } + + inline void normalize_and_store(int stride_qkv, float * qkv) const { + auto R = qkv_cache; + for (int j = 0; j < q_step; ++j) { + normalize_and_store(j, R, qkv); + qkv += stride_qkv; + R += D; + } + } + + template + static inline void mult_mask_kq(ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q, + const char * mask, const __m512bh * vkh, float * cache) { + __m512bh qv[D/32]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int m1 = 0; m1 < q_step; ++m1) { + mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache); + } + } + } + + template + static inline void mult_mask_kq(int nq, ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q, + const char * mask, const __m512bh * vkh, float * cache) { + __m512bh qv[D/32]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int m1 = 0; m1 < nq; ++m1) { + mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache); + } + } + } + + template + static inline void mult_mask_kq(ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q, + const char * mask, __m512bh * vkh, float * cache) { + __m512bh qv[D/32]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int m1 = 0; m1 < q_step; ++m1) { + mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache); + } + } + } + + template + static inline void mult_mask_kq(int nq, ggml_half h_inf, const KHelper& kh, int stride_q, int stride_m, const float * q, + const char * mask, __m512bh * vkh, float * cache) { + __m512bh qv[D/32]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int m1 = 0; m1 < nq; ++m1) { + mult_mask_kq_one(h_inf, l1, m1, stride_q, stride_m, q, mask, qv, vkh, cache); + } + } + } + + template + inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) { + { + __m512bh vkh[D/16]; + mult_mask_kq(h_inf, kh, stride_q, stride_m, q, mask, vkh, cache); + } + __m512 vk[k_step/16]; + for (int j = 0; j < q_step; ++j) { + update_M_S(j, vk); + } + } + + template + inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) { + { + __m512bh vkh[D/16]; + mult_mask_kq(nq, h_inf, kh, stride_q, stride_m, q, mask, vkh, cache); + } + __m512 vk[k_step/16]; + for (int j = 0; j < nq; ++j) { + update_M_S(j, vk); + } + } + + // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 + // Hence, for now, we will not handle head sizes of 80 and 112 + inline void accumulate_qkv(const HelperBF16& vh) { + __m512 vk[2*q_step]; + for (int i = 0; i < D/16; i += 2) { + for (int j = 0; j < q_step; ++j) { + if (need_scaling[j] == 2) { + vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); + } else { + auto R = qkv_cache + D*j; + vk[2*j+0] = _mm512_loadu_ps(R + 16*i); + vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); + if (need_scaling[j] == 1) { + vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]); + vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]); + } + } + } + __m512 v1, v2; + for (int l1 = 0; l1 < k_step; ++l1) { + vh.load(l1, i, v1, v2); + for (int j = 0; j < q_step; ++j) { + auto vs = _mm512_set1_ps(cache[k_step*j + l1]); + vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); + vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } + } + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + _mm512_storeu_ps(R + 16*i, vk[2*j+0]); + _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); + } + } + } + + template = 2>> + inline void accumulate_qkv(int nq1, const VHelper& vh) { + GGML_ASSERT(nq1 < q_step); + __m512 vk[2*q_step]; + for (int i = 0; i < D/16; i += 2) { + for (int j = 0; j < nq1; ++j) { + if (need_scaling[j] == 2) { + vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); + } else { + auto R = qkv_cache + D*j; + vk[2*j+0] = _mm512_loadu_ps(R + 16*i); + vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); + if (need_scaling[j] == 1) { + vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]); + vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]); + } + } + } + __m512 v1, v2; + for (int l1 = 0; l1 < k_step; ++l1) { + vh.load(l1, i, v1, v2); + for (int j = 0; j < nq1; ++j) { + auto vs = _mm512_set1_ps(cache[k_step*j + l1]); + vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); + vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } + } + for (int j = 0; j < nq1; ++j) { + auto R = qkv_cache + D*j; + _mm512_storeu_ps(R + 16*i, vk[2*j+0]); + _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); + } + } + } + + template + void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + const float * q, const char * mask, float * qkv) { + for (int i1 = 0; i1 < nq1/q_step; ++i1) { + init_qstep(); + kh.reset_block(); + vh.reset_block(); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + multiply_mask_kq(kh, stride_q, stride_m, q, mr); + accumulate_qkv(vh); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + normalize_and_store(stride_qkv, qkv); + + q += q_step*stride_q; + mask += q_step*stride_m; + qkv += q_step*stride_qkv; + } + int n_left = nq1 - q_step*(nq1/q_step); + if (n_left > 0) { + init_qstep(); + kh.reset_block(); + vh.reset_block(); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr); + accumulate_qkv(n_left, vh); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + normalize_and_store(n_left, stride_qkv, qkv); + } + } + + float cache[q_step*k_step]; + float qkv_cache[D*q_step]; + float S[q_step], M[q_step]; + int need_scaling[q_step]; + __m512 vms[q_step]; + const __m512 vscale; + const float softcap; + const ggml_half h_inf; + + typedef __m512 (*combine_t)(__m512, __m512); + typedef float (*reduce_t)(__m512); + template + static inline float reduce_T(const __m512 * vals) { + float result; + if constexpr (k_step/16 == 1) { + result = Op(vals[0]); + } + else if constexpr (k_step/16 == 2) { + result = Op(Op_combine(vals[0], vals[1])); + } + else { + auto vmax = Op_combine(vals[0], vals[1]); + for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]); + result = Op(vmax); + } + return result; + } +}; +#endif + template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv) { @@ -6710,6 +7048,23 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str } } +#ifdef __AVX512BF16__ +template +inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, + const float * q, const char * k, const char * v, const char * mask, + float scale, float softcap, float * qkv) { + HelperBF16 kh(k, stride_k); + HelperBF16 vh(v, stride_v); + if (nq1 >= q_step) { + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } +} +#endif + template inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, @@ -6766,7 +7121,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, } inline bool flash_attn_is_supported(ggml_type type) { +#ifdef __AVX512BF16__ + return type == GGML_TYPE_F16 || type == GGML_TYPE_BF16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1; +#else return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1; +#endif } } @@ -6799,6 +7158,25 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k stride_q /= sizeof(float); // q stride as float +#ifdef __AVX512BF16__ + if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) { + switch (D) { + case 64: + iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 96: + iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 128: + iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 256: + iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + default: + return false; + } + + return true; + } +#endif + switch (D) { case 64: iqk_flash_helper_T< 64, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;