mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
FA: slightly faster V*softmax(K*Q)) on Zen4
We now get 130.9 t/s for a context of 32k tokens.
This commit is contained in:
@@ -13666,6 +13666,10 @@ struct HelperBF16 final : public BaseHelper<step> {
|
||||
load(l1+2, vk+2*D/32);
|
||||
load(l1+3, vk+3*D/32);
|
||||
}
|
||||
|
||||
inline void load_8(int l1, __m512bh * vk) const {
|
||||
for (int k = 0; k < 8; ++k) load(l1 + k, vk + k*D/32);
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
@@ -13818,6 +13822,29 @@ struct FlashQKbf16 {
|
||||
_mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum));
|
||||
}
|
||||
|
||||
static IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
|
||||
//accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
|
||||
// _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
|
||||
}
|
||||
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
|
||||
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
|
||||
}
|
||||
|
||||
static inline void mult_mask_kq_8(int l1, int m1, const ggml_bf16_t * q,
|
||||
__m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
|
||||
auto qr = q + m1*D;
|
||||
for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
|
||||
__m256 sum[8];
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
|
||||
sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
|
||||
}
|
||||
_mm256_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_8x8(sum));
|
||||
}
|
||||
|
||||
static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q,
|
||||
__m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
|
||||
auto qr = q + m1*D;
|
||||
@@ -13836,10 +13863,15 @@ struct FlashQKbf16 {
|
||||
{
|
||||
__m512bh qv[D/32];
|
||||
if constexpr (D <= 128) {
|
||||
__m512bh vkh[D/8];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 4) {
|
||||
kh.load_4(l1, vkh);
|
||||
for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
|
||||
//__m512bh vkh[D/8];
|
||||
//for (int l1 = 0; l1 < k_step; l1 += 4) {
|
||||
// kh.load_4(l1, vkh);
|
||||
// for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
|
||||
//}
|
||||
__m512bh vkh[D/4];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 8) {
|
||||
kh.load_8(l1, vkh);
|
||||
for (int j = 0; j < q_step; ++j) mult_mask_kq_8(l1, j, q, qv, vkh, fms);
|
||||
}
|
||||
} else {
|
||||
__m512bh vkh[D/16];
|
||||
@@ -14159,6 +14191,22 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (nk1%64 == 0) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
switch (D) {
|
||||
case 64:
|
||||
iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 96:
|
||||
iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 256:
|
||||
iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
switch (D) {
|
||||
case 64:
|
||||
|
||||
Reference in New Issue
Block a user