// Review notes for the patch below. The patch text itself is left untouched; only this
// preamble is added.
//
// NOTE(review): the diff appears to have lost its line breaks and some angle-bracket text
// (e.g. "template static inline", "compute_helper>(", "FlashQKbf16::convert" with no
// template arguments). Presumably an extraction artifact -- confirm against the original
// commit before applying.
//
// What the hunks change in ggml/src/iqk/iqk_mul_mat.cpp:
//
// 1. struct FlashQKbf16: new overloads that take Q as ggml_bf16_t instead of float.
//    They have no stride_q parameter; Q row m1 is read at q + m1*D.
//    - mult_mask_kq_one(l1, m1, stride_m, q, mask, qv, vkh, fms)
//        Handles K indices l1 and l1+1 for Q row m1. Both fms.cache entries are first set
//        to -INFINITY. If both mask values equal fms.h_inf the function returns without
//        loading Q. Otherwise the Q row is loaded into qv[0..D/32) and, for each key whose
//        mask value is not fms.h_inf, the dot product is accumulated with _mm512_dpbf16_ps
//        and reduced with _mm512_reduce_add_ps into
//        fms.cache[k_step*m1 + l1 + {0,1}]. The second key uses vkh[i + D/32].
//    - mult_mask_kq_4(l1, m1, stride_m, q, mask, qv, vkh, fms)
//        Same scheme for four keys l1..l1+3; key k uses vkh[i + k*(D/32)].
//    - multiply_mask_kq(kh, stride_m, q, mask, fms)
//        For D <= 128 it walks K in steps of 4 (kh.load_4 + mult_mask_kq_4), otherwise in
//        steps of 2 (kh.load_2 + mult_mask_kq_one), visiting all q_step Q rows for each
//        loaded K group. Afterwards fms.update_M_S(j, vk) is called for every j < q_step.
//    - convert(stride_q, q, bf16)
//        Converts q_step rows of D floats (source rows stride_q floats apart) into a packed
//        destination whose rows are D bf16 values apart, 32 values per iteration via
//        _mm512_cvtne2ps_pbh(val2, val1). The argument order is expected to place val1 in
//        the lower half of the result -- confirm against the intrinsics guide.
//    All of these loops run to D/32, so they assume D is a multiple of 32.
//
// 2. FlashAttnBF16::compute: the single compute_helper call is replaced by explicit loops.
//    - For each full group of q_step Q rows: fms.init_qstep(), kh/vh.reset_block(), Q is
//      converted once into the local q_bf16[q_step*D] buffer, then for each of the
//      nk1/k_step K blocks the bf16 multiply_mask_kq overload and fqkv.accumulate_qkv are
//      called, the helpers advance with next_block(), and the mask pointer advances by
//      k_step*sizeof(ggml_half). fqkv.normalize_and_store writes the result, after which
//      q, mask and qkv advance by q_step rows.
//    - The remaining n_left = nq1 - q_step*(nq1/q_step) rows are handled with the overloads
//      that take a row count and the original float Q (no bf16 conversion on this path).
//
// 3. iqk_flash_attn_noalibi: for D = 64 and D = 96 the second template argument of
//    iqk_flash_helper_T changes from 4 to 8 in both switch statements (presumably q_step;
//    verify against the template declaration, which is outside this diff).
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 93d8b222..84514ddc 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6798,6 +6798,29 @@ struct FlashQKbf16 { } } + static inline void mult_mask_kq_one(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask, + __m512bh * qv, const __m512bh * vkh, FlashMS& fms) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY; + if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) { + return; + } + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i)); + if (mp[l1+0] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]); + fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + } + if (mp[l1+1] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]); + fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); + } + } + static inline void mult_mask_kq_4(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, __m512bh * qv, const __m512bh * vkh, FlashMS& fms) { // q index is q_step*i1 + m1 @@ -6822,6 +6845,26 @@ struct FlashQKbf16 { } } + static inline void mult_mask_kq_4(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask, + __m512bh * qv, const __m512bh * vkh, FlashMS& fms) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = + fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY; + if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) { + 
return; + } + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i)); + for (int k = 0; k < 4; ++k) { + if (mp[l1+k] == fms.h_inf) continue; + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]); + fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum); + } + } + template static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, FlashMS& fms) { @@ -6851,6 +6894,35 @@ struct FlashQKbf16 { } } + template + static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, + const char * mask, FlashMS& fms) { + { + __m512bh qv[D/32]; + if constexpr (D <= 128) { + __m512bh vkh[D/8]; + for (int l1 = 0; l1 < k_step; l1 += 4) { + kh.load_4(l1, vkh); + for (int j = 0; j < q_step; ++j) { + mult_mask_kq_4(l1, j, stride_m, q, mask, qv, vkh, fms); + } + } + } else { + __m512bh vkh[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int j = 0; j < q_step; ++j) { + mult_mask_kq_one(l1, j, stride_m, q, mask, qv, vkh, fms); + } + } + } + } + __m512 vk[k_step/16]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk); + } + } + template static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, FlashMS& fms) { @@ -6869,6 +6941,19 @@ struct FlashQKbf16 { fms.update_M_S(j, vk); } } + + static inline void convert(int stride_q, const float * q, ggml_bf16_t * bf16) { + auto qr = q; + for (int j = 0; j < q_step; ++j) { + for (int i = 0; i < D/32; ++i) { + auto val1 = _mm512_loadu_ps(qr + 32*i); + auto val2 = _mm512_loadu_ps(qr + 32*i + 16); + _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1)); + } + qr += stride_q; + bf16 += D; + } + } }; template @@ -6882,8 +6967,41 @@ struct FlashAttnBF16 { template void compute(KHelper& kh, VHelper& vh, 
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { - compute_helper>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + ggml_bf16_t q_bf16[q_step*D]; + for (int i1 = 0; i1 < nq1/q_step; ++i1) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + FlashQKbf16::convert(stride_q, q, q_bf16); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); + fqkv.accumulate_qkv(vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, stride_qkv, qkv); + + q += q_step*stride_q; + mask += q_step*stride_m; + qkv += q_step*stride_qkv; + } + int n_left = nq1 - q_step*(nq1/q_step); + if (n_left > 0) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + FlashQKbf16::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); + fqkv.accumulate_qkv(n_left, vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + } } FlashMS fms; @@ -7018,9 +7136,9 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) { switch (D) { case 64: - iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: iqk_flash_helper_T<128, 8, 32>(nq1, nk1, 
stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: @@ -7035,12 +7153,12 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k switch (D) { case 64: - iqk_flash_helper_T< 64, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;