From 62809e70975e932dd2b038da20a18cc5a14ecffd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 11 Sep 2024 18:39:05 +0300 Subject: [PATCH] AVX2 Flash Attention: add ability to use Q4_0 for kv-cache --- ggml/src/iqk/iqk_mul_mat.cpp | 84 ++++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 22 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 142e673d..b5135cfb 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6580,10 +6580,10 @@ struct HelperQ80 final : public BaseHelper { vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2)); vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3)); for (int i = 0; i < 4; ++i) { - vk[8*ib+4*i+0] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 0))))); - vk[8*ib+4*i+1] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 8))))); - vk[8*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16))))); - vk[8*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24))))); + vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 0))))); + vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 8))))); + vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16))))); + vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24))))); } #endif } @@ -6609,15 +6609,16 @@ struct HelperQ80 final : public BaseHelper { // Say D = 256 -> i is 0, 2, 4, 6, 8, ..., 28, 30. 128/8 = 16 -> we use 1st block of 128 for i = 0, 2, ..., 14, second for i = 16, 18, ..., 30 // i = 0, 2 -> ii = 0, i = 4, 6 -> ii = 1, i = 8, 10 -> ii = 2, i = 12, 14 -> ii = 3, i = 16, 18 -> ii = 0, etc. // i*F16::block_size/128 - auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + i/(128/F16::block_size); - int ii = (i*F16::block_size/32)%4; + int j = F16::block_size*i; + auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0); + int ii = (j/QK8_0)%4; auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii])); #ifdef HAVE_FANCY_SIMD v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0)))); v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1)))); #else - v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+((i+0)*F16::block_size)%32))))); - v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+((i+1)*F16::block_size)%32))))); + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32))))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8))))); #endif } @@ -6626,9 +6627,7 @@ struct HelperQ80 final : public BaseHelper { load(l1+1, vk+D/F16::block_size); } }; -#endif -#ifdef HAVE_FANCY_SIMD template struct HelperQ40 final : public BaseHelper { static_assert(step == QK4_0); @@ -6636,11 +6635,11 @@ struct HelperQ40 final : public BaseHelper { HelperQ40(const char * data, int stride) : Base(data, stride) {} - inline void load(int l1, __m512 * vk) const { + inline void load(int l1, F16::Data * vk) const { auto dl = (const block_q4_0 *)Base::lblock(l1); if constexpr (D >= 128) { ggml_half aux[4]; - __m512 vd[4]; + F16::Data vd[4]; for (int ib = 0; ib < D/128; ++ib) { for (int i = 0; i < 4; ++i) { auto& b4 = dl[4*ib+i]; @@ -6648,11 +6647,21 @@ struct HelperQ40 final : public BaseHelper { auto q = _mm_loadu_si128((const __m128i *)b4.qs); auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8); auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); +#ifdef HAVE_FANCY_SIMD vk[8*ib+2*i+0] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)); vk[8*ib+2*i+1] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)); +#else + auto ql16 = _mm256_cvtepi8_epi16(ql); + auto qh16 = _mm256_cvtepi8_epi16(qh); + vk[16*ib+4*i+0] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))); + vk[16*ib+4*i+1] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))); + vk[16*ib+4*i+2] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))); + vk[16*ib+4*i+3] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))); +#endif } auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)aux)); auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1); +#ifdef HAVE_FANCY_SIMD auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1); vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0)); vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1)); @@ -6662,38 +6671,69 @@ struct HelperQ40 final : public BaseHelper { vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+0]); vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+1]); } +#else + vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0)); + vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1)); + vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2)); + vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3)); + for (int i = 0; i < 4; ++i) { + vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+0]); + vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+1]); + vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+2]); + vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+3]); + } +#endif } } else { for (int i = 0; i < D/32; ++i) { - auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d)); + auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d)); auto q = _mm_loadu_si128((const __m128i *)dl[i].qs); auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8); auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); +#ifdef HAVE_FANCY_SIMD vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh))); +#else + auto ql16 = _mm256_cvtepi8_epi16(ql); + auto qh16 = _mm256_cvtepi8_epi16(qh); + vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16)))); + vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1)))); + vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16)))); + vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1)))); +#endif } } } - inline void load(int l1, int i, __m512& v1, __m512& v2) const { - auto dl = (const block_q4_0 *)Base::lblock(l1) + i/2; - auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d)); + inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { + int j = F16::block_size*i; + auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0; + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); +#ifdef HAVE_FANCY_SIMD auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8); auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh))); +#else + if (j%QK4_0) q = _mm_srli_epi16(q, 4); + auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_and_si128(q, mask), m8)); + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16)))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1)))); +#endif } - inline void load_2(int l1, __m512 * vk) const { + inline void load_2(int l1, F16::Data * vk) const { load(l1+0, vk+0); - load(l1+1, vk+D/16); + load(l1+1, vk+D/F16::block_size); } const __m128i mask = _mm_set1_epi8(0xf); const __m128i m8 = _mm_set1_epi8(-8); }; +#endif +#ifdef HAVE_FANCY_SIMD template struct HelperQ41 final : public BaseHelper { static_assert(step == QK4_1); @@ -7553,11 +7593,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ80 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; -#ifdef HAVE_FANCY_SIMD case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; +#ifdef HAVE_FANCY_SIMD case GGML_TYPE_Q4_1: { HelperQ41 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -7584,11 +7624,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; -#ifdef HAVE_FANCY_SIMD case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; +#ifdef HAVE_FANCY_SIMD case GGML_TYPE_Q4_1: { HelperQ41 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -7602,12 +7642,12 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, inline bool flash_attn_is_supported(ggml_type type) { #ifdef __AVX2__ - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0) return true; #ifdef __AVX512BF16__ if (type == GGML_TYPE_BF16) return true; #endif #ifdef HAVE_FANCY_SIMD - if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true; + if (type == GGML_TYPE_Q4_1) return true; #endif #else if (type == GGML_TYPE_F16) return true;