mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-26 01:19:20 +00:00
AVX2 Flash Attention: add ability to use Q4_1 for kv-cache
This commit is contained in:
@@ -6731,9 +6731,7 @@ struct HelperQ40 final : public BaseHelper<step> {
|
||||
const __m128i mask = _mm_set1_epi8(0xf);
|
||||
const __m128i m8 = _mm_set1_epi8(-8);
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
template <int D, int step>
|
||||
struct HelperQ41 final : public BaseHelper<step> {
|
||||
static_assert(step == QK4_1);
|
||||
@@ -6741,33 +6739,51 @@ struct HelperQ41 final : public BaseHelper<step> {
|
||||
HelperQ41(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
|
||||
inline void load(int l1, __m512 * vk) const {
|
||||
inline void load(int l1, F16::Data * vk) const {
|
||||
auto dl = (const block_q4_1 *)Base::lblock(l1);
|
||||
for (int i = 0; i < D/32; ++i) {
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
|
||||
auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].m));
|
||||
auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
|
||||
auto vm = F16::set1(GGML_FP16_TO_FP32(dl[i].m));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
|
||||
auto ql = _mm_and_si128(q, mask);
|
||||
auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
vk[2*i+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
|
||||
vk[2*i+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
|
||||
#else
|
||||
auto ql16 = _mm256_cvtepi8_epi16(ql);
|
||||
auto qh16 = _mm256_cvtepi8_epi16(qh);
|
||||
vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))), vm);
|
||||
vk[4*i+1] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))), vm);
|
||||
vk[4*i+2] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))), vm);
|
||||
vk[4*i+3] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))), vm);
|
||||
vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(ql)), vm);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
inline void load(int l1, int i, __m512& v1, __m512& v2) const {
|
||||
auto dl = (const block_q4_1 *)Base::lblock(l1) + i/2;
|
||||
auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
|
||||
auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->m));
|
||||
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
|
||||
int j = F16::block_size*i;
|
||||
auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
|
||||
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
|
||||
auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
auto ql = _mm_and_si128(q, mask);
|
||||
auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
|
||||
v2 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
|
||||
#else
|
||||
if (j%QK4_1) q = _mm_srli_epi16(q, 4);
|
||||
auto q16 = _mm256_cvtepi8_epi16(_mm_and_si128(q, mask));
|
||||
v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
|
||||
v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void load_2(int l1, __m512 * vk) const {
|
||||
inline void load_2(int l1, F16::Data * vk) const {
|
||||
load(l1+0, vk+0);
|
||||
load(l1+1, vk+D/16);
|
||||
load(l1+1, vk+D/F16::block_size);
|
||||
}
|
||||
|
||||
const __m128i mask = _mm_set1_epi8(0xf);
|
||||
@@ -7597,12 +7613,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
HelperQ40<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#endif
|
||||
#endif
|
||||
default: break;
|
||||
}
|
||||
@@ -7628,12 +7642,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
HelperQ40<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
#endif
|
||||
#endif
|
||||
default: break;
|
||||
}
|
||||
@@ -7642,13 +7654,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
|
||||
inline bool flash_attn_is_supported(ggml_type type) {
|
||||
#ifdef __AVX2__
|
||||
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0) return true;
|
||||
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
|
||||
#ifdef __AVX512BF16__
|
||||
if (type == GGML_TYPE_BF16) return true;
|
||||
#endif
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if (type == GGML_TYPE_Q4_1) return true;
|
||||
#endif
|
||||
#else
|
||||
if (type == GGML_TYPE_F16) return true;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user