NEON Flash Attention: add support for Q8_0, Q4_0, Q4_1

This commit is contained in:
Iwan Kawrakow
2024-09-12 06:26:37 +02:00
parent c920195edd
commit 2dee479c44

View File

@@ -6548,17 +6548,29 @@ struct HelperF16 final : public BaseHelper<step> {
}
};
#if defined __AVX2__
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
static_assert(step == QK8_0);
using Base = BaseHelper<step>;
//using F16 = HelperF16<D, step>;
HelperQ80(const char * data, int stride) : Base(data, stride) {}
inline void load(int l1, F16::Data * vk) const {
auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
if constexpr (D >= 128) {
#ifdef __aarch64__
for (int ib = 0; ib < D/128; ++ib) {
const auto& b8 = dl[ib];
auto d = (const float16_t *)b8.d;
for (int i = 0; i < 4; ++i) {
auto di = vdupq_n_f16(d[i]);
auto qs = vld1_s8_x4(b8.qs + 32*i);
vk[16*ib+4*i+0] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
vk[16*ib+4*i+1] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
vk[16*ib+4*i+2] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[2])));
vk[16*ib+4*i+3] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[3])));
}
}
#else
F16::Data vd[4];
for (int ib = 0; ib < D/128; ++ib) {
const auto& b8 = dl[ib];
@@ -6585,12 +6597,21 @@ struct HelperQ80 final : public BaseHelper<step> {
vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16)))));
vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24)))));
}
#endif
#endif // HAVE_FANCY_SIMD
}
#endif // __aarch64__
} else {
for (int i = 0; i < D/32; ++i) {
const auto& b8 = dl[i/4];
int ii = i%4;
#ifdef __aarch64__
auto vd = F16::set1(b8.d[ii]);
auto qs = vld1_s8_x4(b8.qs + 32*i);
vk[4*i+0] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
vk[4*i+1] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
vk[4*i+2] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[2])));
vk[4*i+3] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[3])));
#else
auto vd = F16::set1(GGML_FP16_TO_FP32(b8.d[ii]));
#ifdef HAVE_FANCY_SIMD
vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
@@ -6600,18 +6621,23 @@ struct HelperQ80 final : public BaseHelper<step> {
vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 8)))));
vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+16)))));
vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+24)))));
#endif
#endif
}
}
}
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
// Say D = 256 -> i is 0, 2, 4, 6, 8, ..., 28, 30. 128/8 = 16 -> we use 1st block of 128 for i = 0, 2, ..., 14, second for i = 16, 18, ..., 30
// i = 0, 2 -> ii = 0, i = 4, 6 -> ii = 1, i = 8, 10 -> ii = 2, i = 12, 14 -> ii = 3, i = 16, 18 -> ii = 0, etc.
// i*F16::block_size/128
int j = F16::block_size*i;
auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0);
int ii = (j/QK8_0)%4;
#ifdef __aarch64__
const float16_t * d = (const float16_t *)dl->d;
auto vd = F16::set1(d[ii]);
auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii]));
#ifdef HAVE_FANCY_SIMD
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
@@ -6619,6 +6645,7 @@ struct HelperQ80 final : public BaseHelper<step> {
#else
v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32)))));
v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8)))));
#endif
#endif
}
@@ -6637,6 +6664,19 @@ struct HelperQ40 final : public BaseHelper<step> {
inline void load(int l1, F16::Data * vk) const {
auto dl = (const block_q4_0 *)Base::lblock(l1);
#ifdef __aarch64__
for (int i = 0; i < D/32; ++i) {
auto& b4 = dl[i];
auto vd = vdupq_n_f16(*(const float16_t *)&b4.d);
auto qs = vld1q_u8(b4.qs);
auto ql = vaddq_s8(vandq_u8(qs, mask), m8);
auto qh = vaddq_s8(vshrq_n_u8(qs, 4), m8);
vk[4*i+0] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(ql))));
vk[4*i+1] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(ql))));
vk[4*i+2] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qh))));
vk[4*i+3] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qh))));
}
#else
if constexpr (D >= 128) {
ggml_half aux[4];
F16::Data vd[4];
@@ -6703,11 +6743,20 @@ struct HelperQ40 final : public BaseHelper<step> {
#endif
}
}
#endif
}
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
int j = F16::block_size*i;
auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0;
#ifdef __aarch64__
auto vd = F16::set1(*(const float16_t *)&dl->d);
auto q = vld1q_u8(dl->qs);
q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
q = vaddq_s8(q, m8);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
#ifdef HAVE_FANCY_SIMD
@@ -6720,6 +6769,7 @@ struct HelperQ40 final : public BaseHelper<step> {
auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_and_si128(q, mask), m8));
v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
#endif
#endif
}
@@ -6728,8 +6778,13 @@ struct HelperQ40 final : public BaseHelper<step> {
load(l1+1, vk+D/F16::block_size);
}
#ifdef __AVX2__
const __m128i mask = _mm_set1_epi8(0xf);
const __m128i m8 = _mm_set1_epi8(-8);
#else
const uint8x16_t mask = vdupq_n_u8(0xf);
const int8x16_t m8 = vdupq_n_s8(-8);
#endif
};
template <int D, int step>
@@ -6738,10 +6793,20 @@ struct HelperQ41 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
HelperQ41(const char * data, int stride) : Base(data, stride) {}
inline void load(int l1, F16::Data * vk) const {
auto dl = (const block_q4_1 *)Base::lblock(l1);
for (int i = 0; i < D/32; ++i) {
#ifdef __aarch64__
auto vd = F16::set1(*(const float16_t *)&dl[i].d);
auto vm = F16::set1(*(const float16_t *)&dl[i].m);
auto q = vld1q_u8(dl[i].qs);
auto ql = vandq_u8(q, mask);
auto qh = vshrq_n_u8(q, 4);
vk[4*i+0] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(ql))));
vk[4*i+1] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(ql))));
vk[4*i+2] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(qh))));
vk[4*i+3] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(qh))));
#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
auto vm = F16::set1(GGML_FP16_TO_FP32(dl[i].m));
auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
@@ -6758,6 +6823,7 @@ struct HelperQ41 final : public BaseHelper<step> {
vk[4*i+2] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))), vm);
vk[4*i+3] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))), vm);
vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(ql)), vm);
#endif
#endif
}
}
@@ -6765,6 +6831,14 @@ struct HelperQ41 final : public BaseHelper<step> {
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
int j = F16::block_size*i;
auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
#ifdef __aarch64__
auto vd = F16::set1(*(const float16_t *)&dl->d);
auto vm = F16::set1(*(const float16_t *)&dl->m);
auto q = vld1q_u8(dl->qs);
q = (j%QK4_1) ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
v1 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(q))));
v2 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(q))));
#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
@@ -6778,6 +6852,7 @@ struct HelperQ41 final : public BaseHelper<step> {
auto q16 = _mm256_cvtepi8_epi16(_mm_and_si128(q, mask));
v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
#endif
#endif
}
@@ -6786,9 +6861,12 @@ struct HelperQ41 final : public BaseHelper<step> {
load(l1+1, vk+D/F16::block_size);
}
#ifdef __aarch64__
const uint8x16_t mask = vdupq_n_u8(0xf);
#else
const __m128i mask = _mm_set1_epi8(0xf);
};
#endif
};
template <int q_step, int k_step>
struct FlashMS {
@@ -7604,7 +7682,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperF16<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#ifdef __AVX2__
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -7617,7 +7694,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ41<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#endif
default: break;
}
}
@@ -7633,7 +7709,6 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperF16<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
#ifdef __AVX2__
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -7646,20 +7721,15 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ41<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
#endif
default: break;
}
}
inline bool flash_attn_is_supported(ggml_type type) {
#ifdef __AVX2__
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
#else
if (type == GGML_TYPE_F16) return true;
#endif
return false;
}