diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5a8cbce2..2e7a723a 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6548,17 +6548,29 @@ struct HelperF16 final : public BaseHelper<step> {
     }
 };
 
-#if defined __AVX2__
 template <int D, int step>
 struct HelperQ80 final : public BaseHelper<step> {
     static_assert(step == QK8_0);
     using Base = BaseHelper<step>;
-    //using F16 = HelperF16<D, step>;
     HelperQ80(const char * data, int stride) : Base(data, stride) {}
 
     inline void load(int l1, F16::Data * vk) const {
         auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
         if constexpr (D >= 128) {
+#ifdef __aarch64__
+            for (int ib = 0; ib < D/128; ++ib) {
+                const auto& b8 = dl[ib];
+                auto d = (const float16_t *)b8.d;
+                for (int i = 0; i < 4; ++i) {
+                    auto di = vdupq_n_f16(d[i]);
+                    auto qs = vld1_s8_x4(b8.qs + 32*i);
+                    vk[16*ib+4*i+0] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+                    vk[16*ib+4*i+1] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+                    vk[16*ib+4*i+2] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[2])));
+                    vk[16*ib+4*i+3] = vmulq_f16(di, vcvtq_f16_s16(vmovl_s8(qs.val[3])));
+                }
+            }
+#else
             F16::Data vd[4];
             for (int ib = 0; ib < D/128; ++ib) {
                 const auto& b8 = dl[ib];
@@ -6585,12 +6597,21 @@ struct HelperQ80 final : public BaseHelper<step> {
                     vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16)))));
                     vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24)))));
                 }
-#endif
+#endif // HAVE_FANCY_SIMD
             }
+#endif // __aarch64__
         } else {
             for (int i = 0; i < D/32; ++i) {
                 const auto& b8 = dl[i/4];
                 int ii = i%4;
+#ifdef __aarch64__
+                auto vd = F16::set1(b8.d[ii]);
+                auto qs = vld1_s8_x4(b8.qs + 32*i);
+                vk[4*i+0] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+                vk[4*i+1] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+                vk[4*i+2] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[2])));
+                vk[4*i+3] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[3])));
+#else
                 auto vd = F16::set1(GGML_FP16_TO_FP32(b8.d[ii]));
 #ifdef HAVE_FANCY_SIMD
                 vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
@@ -6600,18 +6621,23 @@ struct HelperQ80 final : public BaseHelper<step> {
                 vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 8)))));
                 vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+16)))));
                 vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+24)))));
+#endif
 #endif
             }
         }
     }
 
     inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
-        // Say D = 256 -> i is 0, 2, 4, 6, 8, ..., 28, 30. 128/8 = 16 -> we use 1st block of 128 for i = 0, 2, ..., 14, second for i = 16, 18, ..., 30
-        //    i = 0, 2 -> ii = 0, i = 4, 6 -> ii = 1, i = 8, 10 -> ii = 2, i = 12, 14 -> ii = 3, i = 16, 18 -> ii = 0, etc.
-        //    i*F16::block_size/128
         int j = F16::block_size*i;
         auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0);
         int ii = (j/QK8_0)%4;
+#ifdef __aarch64__
+        const float16_t * d = (const float16_t *)dl->d;
+        auto vd = F16::set1(d[ii]);
+        auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32);
+        v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+        v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
         auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii]));
 #ifdef HAVE_FANCY_SIMD
         v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
@@ -6619,6 +6645,7 @@ struct HelperQ80 final : public BaseHelper<step> {
 #else
         v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32)))));
         v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8)))));
+#endif
 #endif
     }
 
@@ -6637,6 +6664,19 @@ struct HelperQ40 final : public BaseHelper<step> {
 
     inline void load(int l1, F16::Data * vk) const {
         auto dl = (const block_q4_0 *)Base::lblock(l1);
+#ifdef __aarch64__
+        for (int i = 0; i < D/32; ++i) {
+            auto& b4 = dl[i];
+            auto vd = vdupq_n_f16(*(const float16_t *)&b4.d);
+            auto qs = vld1q_u8(b4.qs);
+            auto ql = vaddq_s8(vandq_u8(qs, mask), m8);
+            auto qh = vaddq_s8(vshrq_n_u8(qs, 4), m8);
+            vk[4*i+0] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(ql))));
+            vk[4*i+1] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(ql))));
+            vk[4*i+2] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qh))));
+            vk[4*i+3] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qh))));
+        }
+#else
         if constexpr (D >= 128) {
             ggml_half aux[4];
             F16::Data vd[4];
@@ -6703,11 +6743,20 @@ struct HelperQ40 final : public BaseHelper<step> {
 #endif
             }
         }
+#endif
     }
 
     inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
         int j = F16::block_size*i;
         auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0;
+#ifdef __aarch64__
+        auto vd = F16::set1(*(const float16_t *)&dl->d);
+        auto q = vld1q_u8(dl->qs);
+        q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+        q = vaddq_s8(q, m8);
+        v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
+        v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
+#else
         auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
         auto q = _mm_loadu_si128((const __m128i *)dl->qs);
 #ifdef HAVE_FANCY_SIMD
@@ -6720,6 +6769,7 @@ struct HelperQ40 final : public BaseHelper<step> {
         auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_and_si128(q, mask), m8));
         v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
         v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
 #endif
     }
 
@@ -6728,8 +6778,13 @@ struct HelperQ40 final : public BaseHelper<step> {
         load(l1+1, vk+D/F16::block_size);
     }
 
+#ifdef __AVX2__
     const __m128i mask = _mm_set1_epi8(0xf);
     const __m128i m8 = _mm_set1_epi8(-8);
+#else
+    const uint8x16_t mask = vdupq_n_u8(0xf);
+    const int8x16_t m8 = vdupq_n_s8(-8);
+#endif
 };
 
 template <int D, int step>
@@ -6738,10 +6793,20 @@ struct HelperQ41 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
 
     HelperQ41(const char * data, int stride) : Base(data, stride) {}
-
     inline void load(int l1, F16::Data * vk) const {
         auto dl = (const block_q4_1 *)Base::lblock(l1);
         for (int i = 0; i < D/32; ++i) {
+#ifdef __aarch64__
+            auto vd = F16::set1(*(const float16_t *)&dl[i].d);
+            auto vm = F16::set1(*(const float16_t *)&dl[i].m);
+            auto q = vld1q_u8(dl[i].qs);
+            auto ql = vandq_u8(q, mask);
+            auto qh = vshrq_n_u8(q, 4);
+            vk[4*i+0] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(ql))));
+            vk[4*i+1] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(ql))));
+            vk[4*i+2] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(qh))));
+            vk[4*i+3] = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(qh))));
+#else
             auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
             auto vm = F16::set1(GGML_FP16_TO_FP32(dl[i].m));
             auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
@@ -6758,6 +6823,7 @@ struct HelperQ41 final : public BaseHelper<step> {
             vk[4*i+2] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))), vm);
             vk[4*i+3] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))), vm);
             vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(ql)), vm);
+#endif
 #endif
         }
     }
@@ -6765,6 +6831,14 @@ struct HelperQ41 final : public BaseHelper<step> {
     inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
         int j = F16::block_size*i;
         auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
+#ifdef __aarch64__
+        auto vd = F16::set1(*(const float16_t *)&dl->d);
+        auto vm = F16::set1(*(const float16_t *)&dl->m);
+        auto q = vld1q_u8(dl->qs);
+        q = (j%QK4_1) ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+        v1 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(q))));
+        v2 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(q))));
+#else
         auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
         auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
         auto q = _mm_loadu_si128((const __m128i *)dl->qs);
@@ -6778,6 +6852,7 @@ struct HelperQ41 final : public BaseHelper<step> {
         auto q16 = _mm256_cvtepi8_epi16(_mm_and_si128(q, mask));
         v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
         v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
+#endif
 #endif
     }
 
@@ -6786,9 +6861,12 @@ struct HelperQ41 final : public BaseHelper<step> {
         load(l1+1, vk+D/F16::block_size);
     }
 
+#ifdef __aarch64__
+    const uint8x16_t mask = vdupq_n_u8(0xf);
+#else
     const __m128i mask = _mm_set1_epi8(0xf);
-};
 #endif
+};
 
 template <int q_step, int k_step>
 struct FlashMS {
@@ -7604,7 +7682,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
             HelperF16<D, k_step> vh(v, stride_v);
             iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
-#ifdef __AVX2__
         case GGML_TYPE_Q8_0: {
             HelperQ80<D, k_step> vh(v, stride_v);
             iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -7617,7 +7694,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
             HelperQ41<D, k_step> vh(v, stride_v);
             iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
-#endif
         default: break;
     }
 }
@@ -7633,7 +7709,6 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
             HelperF16<D, k_step> kh(k, stride_k);
             iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
-#ifdef __AVX2__
         case GGML_TYPE_Q8_0: {
             HelperQ80<D, k_step> kh(k, stride_k);
             iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -7646,20 +7721,15 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
             HelperQ41<D, k_step> kh(k, stride_k);
             iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
-#endif
         default: break;
     }
 }
 
 inline bool flash_attn_is_supported(ggml_type type) {
-#ifdef __AVX2__
     if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
 #ifdef __AVX512BF16__
     if (type == GGML_TYPE_BF16) return true;
-#endif
-#else
-    if (type == GGML_TYPE_F16) return true;
 #endif
     return false;
 }
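
Note on the pattern used throughout the new `#ifdef __aarch64__` branches: the NEON path broadcasts each block's fp16 scale, widens the int8 quants to int16, converts to fp16, and multiplies, staying in fp16 end to end instead of detouring through fp32 the way the AVX2/AVX512 path must. A minimal standalone sketch of that conversion for one q8_0 block follows; it is an illustration, not code from the patch, and the struct here is a hypothetical mirror of ggml's block_q8_0 layout (one fp16 scale followed by 32 int8 quants):

    // Sketch: dequantize one q8_0 block (32 weights) into four float16x8_t.
    // Assumes an aarch64 toolchain with fp16 vector arithmetic enabled
    // (e.g. -march=armv8.2-a+fp16).
    #include <arm_neon.h>
    #include <stdint.h>

    struct block_q8_0_sketch {   // hypothetical mirror of ggml's block_q8_0
        uint16_t d;              // scale, stored as IEEE fp16
        int8_t   qs[32];         // quantized weights
    };

    static inline void dequant_q8_0_f16(const block_q8_0_sketch * b, float16x8_t * out) {
        // Broadcast the fp16 scale directly, with no round trip through fp32.
        float16x8_t vd = vdupq_n_f16(*(const float16_t *)&b->d);
        // Load all 32 int8 quants as four 8-lane vectors.
        int8x8x4_t qs = vld1_s8_x4(b->qs);
        // Widen s8 -> s16, convert to fp16, scale: out = d * q.
        out[0] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
        out[1] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
        out[2] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[2])));
        out[3] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[3])));
    }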
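
The q4_0/q4_1 branches add one unpacking step before the same widen-convert-multiply sequence: each byte packs two 4-bit quants, so the low and high nibbles are extracted with a mask and a shift, and for q4_0 the values are recentred by -8 (q4_1 instead applies the block minimum via vfmaq_f16). A sketch of the q4_0 case under the same assumptions, with explicit vreinterpret casts added so the sketch is self-contained under strict NEON typing:

    // Sketch: dequantize one q4_0 block (32 weights packed in 16 bytes),
    // value = d * (q - 8).
    struct block_q4_0_sketch {   // hypothetical mirror of ggml's block_q4_0
        uint16_t d;              // scale, stored as IEEE fp16
        uint8_t  qs[16];         // 32 x 4-bit quants, two per byte
    };

    static inline void dequant_q4_0_f16(const block_q4_0_sketch * b, float16x8_t * out) {
        const uint8x16_t mask = vdupq_n_u8(0xf);
        const int8x16_t  m8   = vdupq_n_s8(-8);
        float16x8_t vd = vdupq_n_f16(*(const float16_t *)&b->d);
        uint8x16_t  qs = vld1q_u8(b->qs);
        // Low nibbles hold quants 0..15, high nibbles quants 16..31;
        // adding -8 shifts the unsigned nibble range [0, 15] to [-8, 7].
        int8x16_t ql = vaddq_s8(vreinterpretq_s8_u8(vandq_u8(qs, mask)), m8);
        int8x16_t qh = vaddq_s8(vreinterpretq_s8_u8(vshrq_n_u8(qs, 4)), m8);
        out[0] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(ql))));
        out[1] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(ql))));
        out[2] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qh))));
        out[3] = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qh))));
    }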