diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 9b83b3f2..33b2a790 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6251,7 +6251,6 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn int32x4_t accd[nrc_y]; - const auto m1 = vdupq_n_u8(1); const auto mask2 = vdupq_n_s8(3); for (int ix = 0; ix < nrc_x; ++ix) { @@ -7181,10 +7180,14 @@ struct HelperQ41 final : public BaseHelper { }; template -struct HelperIQ4NL final : public BaseHelper { +struct HelperIQ4nl final : public BaseHelper { using Base = BaseHelper; +#ifdef __aarch64__ + using block_q8 = block_q8_0; +#else using block_q8 = block_q8_1; - HelperIQ4NL(const char * data, int stride) : Base(data, stride) {} +#endif + HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {} // Needed for v * softmax(k * q) inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { @@ -7194,7 +7197,7 @@ struct HelperIQ4NL final : public BaseHelper { auto vd = F16::set1(*(const float16_t *)&dl->d); auto q = vld1q_u8(dl->qs); q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask); - q = vaddq_s8(q, m8); + q = vqtbl1q_s8(values, q); v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q)))); v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q)))); #else @@ -7214,11 +7217,12 @@ struct HelperIQ4NL final : public BaseHelper { #endif } -#ifdef __AVX2__ +#ifdef __aarch64__ + const uint8x16_t mask = vdupq_n_u8(0xf); + const int8x16_t values; +#else const __m128i mask = _mm_set1_epi8(0xf); const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values); -#else - const uint8x16_t mask = vdupq_n_u8(0xf); #endif }; @@ -7746,10 +7750,10 @@ struct FlashQKfp32 { mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); #endif } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); + mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); #else mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); #endif @@ -7853,17 +7857,17 @@ struct FlashQKfp32 { #endif } } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { #ifdef __aarch64__ - case 1: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; + case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; #else case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; @@ -8014,7 +8018,7 @@ struct FlashAttn { void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { if constexpr (std::is_same_v> || std::is_same_v> || - std::is_same_v> || std::is_same_v>) { + std::is_same_v> || std::is_same_v>) { compute_helper_q>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else { @@ -8356,7 +8360,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4NL vh(v, stride_v); + HelperIQ4nl vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; default: break; @@ -8387,7 +8391,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4NL kh(k, stride_k); + HelperIQ4nl kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; default: break;