diff --git a/ggml/src/iqk/iqk_gemm_floats.cpp b/ggml/src/iqk/iqk_gemm_floats.cpp index bcab5fbf..664c734b 100644 --- a/ggml/src/iqk/iqk_gemm_floats.cpp +++ b/ggml/src/iqk/iqk_gemm_floats.cpp @@ -566,6 +566,67 @@ bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array, QFT>(D, cx, bx, ik*nrc_k, info); + } + info.cur_y += nrc_q; + } + if (qrem > 0) { + switch (qrem) { + case 1: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, cx, bx, ik*nrc_k, info); + } + } break; + case 2: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, cx, bx, ik*nrc_k, info); + } + } break; + case 3: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, cx, bx, ik*nrc_k, info); + } + } break; +#ifdef HAVE_FANCY_SIMD + case 4: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, cx, bx, ik*nrc_k, info); + } + } break; + case 5: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, cx, bx, ik*nrc_k, info); + } + } break; + case 6: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, cx, bx, ik*nrc_k, info); + } + } break; + case 7: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, cx, bx, ik*nrc_k, info); + } + } break; +#endif + } + } +} + #else // ----------------------------------- __aarch64__ ----------------------------------------------- diff --git a/ggml/src/iqk/iqk_gemm_floats.h b/ggml/src/iqk/iqk_gemm_floats.h index 4c414c44..aba514f6 100644 --- a/ggml/src/iqk/iqk_gemm_floats.h +++ b/ggml/src/iqk/iqk_gemm_floats.h @@ -8,4 +8,6 @@ bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array& kernels); +void iqk_gemm_default_floats(int D, int nq, const char * vx, size_t bx, DataInfo& info, int k_step); + #endif diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index 44540c75..75901008 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -2849,4 +2849,272 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array +static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%32 == 0); + if (nrc_y == 1 && nrc_x == 1) { + auto dx = (const float *)vx; + auto dy = (const float *)info.src1_row(0); +#ifdef HAVE_FANCY_SIMD + auto sy = (const int32_t *)(dy + 1); + auto x = (const int8_t *)(dx + 2); + auto y = (const int8_t *)(dy + 2); + auto isum = _mm512_setzero_si512(); + for (int i = 0; i < n/64; ++i) { + auto qx = _mm512_loadu_si512((const __m512i *)x + i); + auto qy = _mm512_loadu_si512((const __m512i *)y + i); + isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy); + } + auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1)); + for (int i = 2*(n/64); i < n/32; ++i) { + auto qx = _mm256_loadu_si256((const __m256i *)x + i); + auto qy = _mm256_loadu_si256((const __m256i *)y + i); + isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy); + } + info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0])); +#else + auto x = (const int8_t *)(dx + 2); + auto y = (const int8_t *)(dy + 2); + auto isum = _mm256_setzero_si256(); + for (int i = 0; i < n/32; ++i) { + auto qx = _mm256_loadu_si256((const __m256i *)x + i); + auto qy = _mm256_loadu_si256((const __m256i *)y + i); + auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx)); + isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot)); + } + info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum)); +#endif + return; + } + __m256i qx[2]; + __m256i acc[2*nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#else + __m256i sx[2]; + auto m1 = _mm256_set1_epi16(1); +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr+1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ++ix) { + auto dx = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dx + 2); + for (int i = 0; i < n/64; ++i) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127)); +#else + qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); +#else + auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); + acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot)); +#endif + } + } + } + if (int i = 2*(n/64); i < n/32) { +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i); + sx[0] = _mm256_sign_epi8(qx[0], qx[0]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); +#else + auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); + acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot)); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1])); +#ifdef HAVE_FANCY_SIMD + info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy])); +#else + info.store(ix, iy, dx[0]*dy[iy]*sumi); +#endif + acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256(); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m512i qx[4]; + __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {}; + float dy[nrc_y]; + int32_t sy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -64*iptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[8]; + float dx[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int kx = 0; kx < 8; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) { + qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)), + _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1); + } + auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]); + qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128)); + qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + if constexpr (nrc_y <= 4) { + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } else { + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + } + auto scales_x = _mm256_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + if constexpr (nrc_y <= 4) { + auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); + } else { + acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[iy] = _mm512_setzero_si512(); + } + } + } +} +#endif +#endif + +template +inline std::pair mul_mat_kernel(int D, int int_typeA, int nq) { + auto typeA = ggml_type(int_typeA); + constexpr int kMaxQ = 8; +#define MAKE_FUNCS(mul_mat, n) \ + if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\ + else {\ + switch (n) {\ + case 1: return std::make_pair(mul_mat, 1>, 1);\ + case 2: return std::make_pair(mul_mat, 2>, 2);\ + case 3: return std::make_pair(mul_mat, 3>, 3);\ + case 4: return std::make_pair(mul_mat, 4>, 4);\ + case 5: return std::make_pair(mul_mat, 5>, 5);\ + case 6: return std::make_pair(mul_mat, 6>, 6);\ + case 7: return std::make_pair(mul_mat, 7>, 7);\ + }\ + } +#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \ + if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ);\ + else {\ + switch (n) {\ + case 1: return std::make_pair(mul_mat<1>, 1);\ + case 2: return std::make_pair(mul_mat<2>, 2);\ + case 3: return std::make_pair(mul_mat<3>, 3);\ + case 4: return std::make_pair(mul_mat<4>, 4);\ + case 5: return std::make_pair(mul_mat<5>, 5);\ + case 6: return std::make_pair(mul_mat<6>, 6);\ + case 7: return std::make_pair(mul_mat<7>, 7);\ + }\ + } + if (typeA == GGML_TYPE_Q8_KV) { +#ifdef __aarch64__ + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#else + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); +#ifdef HAVE_FANCY_SIMD + if (D%32 == 0 && k_step%8 == 0) { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq); + } else { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + } +#endif + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#endif + } + else if (typeA == GGML_TYPE_Q8_KV_R8) { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq); + } + GGML_ABORT("Fatal error"); +} + +inline std::pair mul_mat_kernel(int D, int int_typeA, int nq, int k_step) { + switch (k_step) { + case 32: return mul_mat_kernel< 32>(D, int_typeA, nq); + case 64: return mul_mat_kernel< 64>(D, int_typeA, nq); + case 128: return mul_mat_kernel<128>(D, int_typeA, nq); + default: GGML_ABORT("Fatal error"); + } +} + +} + +void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step) { + auto [mul_mat, nrc_q] = mul_mat_kernel(D, type_k, nq, k_step); + for (int iq = 0; iq < nq/nrc_q; ++iq) { + mul_mat(D, k, stride_k, info, k_step); + info.cur_y += nrc_q; + } + int iq = nrc_q*(nq/nrc_q); + if (iq < nq) { + auto [mul_mat1, nrc_q1] = mul_mat_kernel(D, type_k, nq - iq, k_step); + GGML_ASSERT(nrc_q1 == nq - iq); + mul_mat1(D, k, stride_k, info, k_step); + } +} + #endif diff --git a/ggml/src/iqk/iqk_gemm_kquants.h b/ggml/src/iqk/iqk_gemm_kquants.h index 96c5f2ca..071d2e50 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.h +++ b/ggml/src/iqk/iqk_gemm_kquants.h @@ -8,4 +8,6 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16); +void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step); + #endif diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp index 64ae0c2f..4c64ac59 100644 --- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp @@ -2635,4 +2635,129 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array +inline std::pair mul_mat_kernel(int int_typeA, int nq) { + auto typeA = ggml_type(int_typeA); + constexpr int kMaxQ = 8; +#define MAKE_FUNCS(mul_mat, n) \ + if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\ + else {\ + switch (n) {\ + case 1: return std::make_pair(mul_mat, 1>, 1);\ + case 2: return std::make_pair(mul_mat, 2>, 2);\ + case 3: return std::make_pair(mul_mat, 3>, 3);\ + case 4: return std::make_pair(mul_mat, 4>, 4);\ + case 5: return std::make_pair(mul_mat, 5>, 5);\ + case 6: return std::make_pair(mul_mat, 6>, 6);\ + case 7: return std::make_pair(mul_mat, 7>, 7);\ + }\ + } +#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \ + if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ);\ + else {\ + switch (n) {\ + case 1: return std::make_pair(mul_mat<1>, 1);\ + case 2: return std::make_pair(mul_mat<2>, 2);\ + case 3: return std::make_pair(mul_mat<3>, 3);\ + case 4: return std::make_pair(mul_mat<4>, 4);\ + case 5: return std::make_pair(mul_mat<5>, 5);\ + case 6: return std::make_pair(mul_mat<6>, 6);\ + case 7: return std::make_pair(mul_mat<7>, 7);\ + }\ + } + if (typeA == GGML_TYPE_Q8_0) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); + MAKE_FUNCS(mul_mat_qX_1_q8_2_T, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 4); + MAKE_FUNCS(mul_mat_qX_0_q8_0_T, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); + MAKE_FUNCS(mul_mat_qX_1_q8_2_T, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); + MAKE_FUNCS(mul_mat_qX_1_q8_2_T(nullptr, 0); +} + +inline std::pair mul_mat_kernel(int int_typeA, int nq, int k_step) { + switch (k_step) { + case 32: return mul_mat_kernel< 32>(int_typeA, nq); + case 64: return mul_mat_kernel< 64>(int_typeA, nq); + case 128: return mul_mat_kernel<128>(int_typeA, nq); + default: GGML_ABORT("Fatal error"); + } +} +} + +void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step) { + auto [mul_mat, nrc_q] = mul_mat_kernel(type_k, nq, k_step); + for (int iq = 0; iq < nq/nrc_q; ++iq) { + mul_mat(D, k, stride_k, info, k_step); + info.cur_y += nrc_q; + } + int iq = nrc_q*(nq/nrc_q); + if (iq < nq) { + auto [mul_mat1, nrc_q1] = mul_mat_kernel(type_k, nq - iq, k_step); + GGML_ASSERT(nrc_q1 == nq - iq); + mul_mat1(D, k, stride_k, info, k_step); + } +} + #endif diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.h b/ggml/src/iqk/iqk_gemm_legacy_quants.h index 7e37ddad..a472d9bb 100644 --- a/ggml/src/iqk/iqk_gemm_legacy_quants.h +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.h @@ -5,7 +5,10 @@ #ifdef IQK_IMPLEMENT #include +#include bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16); +void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step); + #endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 2a3c850c..e22c4b00 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -487,191 +487,6 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n namespace { -template -static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(n%32 == 0); - if (nrc_y == 1 && nrc_x == 1) { - auto dx = (const float *)vx; - auto dy = (const float *)info.src1_row(0); -#ifdef HAVE_FANCY_SIMD - auto sy = (const int32_t *)(dy + 1); - auto x = (const int8_t *)(dx + 2); - auto y = (const int8_t *)(dy + 2); - auto isum = _mm512_setzero_si512(); - for (int i = 0; i < n/64; ++i) { - auto qx = _mm512_loadu_si512((const __m512i *)x + i); - auto qy = _mm512_loadu_si512((const __m512i *)y + i); - isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy); - } - auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1)); - for (int i = 2*(n/64); i < n/32; ++i) { - auto qx = _mm256_loadu_si256((const __m256i *)x + i); - auto qy = _mm256_loadu_si256((const __m256i *)y + i); - isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy); - } - info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0])); -#else - auto x = (const int8_t *)(dx + 2); - auto y = (const int8_t *)(dy + 2); - auto isum = _mm256_setzero_si256(); - for (int i = 0; i < n/32; ++i) { - auto qx = _mm256_loadu_si256((const __m256i *)x + i); - auto qy = _mm256_loadu_si256((const __m256i *)y + i); - auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx)); - isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot)); - } - info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum)); -#endif - return; - } - __m256i qx[2]; - __m256i acc[2*nrc_y] = {}; - float dy[nrc_y]; -#ifdef HAVE_FANCY_SIMD - int32_t sy[nrc_y]; -#else - __m256i sx[2]; - auto m1 = _mm256_set1_epi16(1); -#endif - const int8_t * q8y[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) { - auto dptr = (const float *)info.src1_row(iy); - dy[iy] = dptr[0]; -#ifdef HAVE_FANCY_SIMD - auto iptr = (const int32_t *)(dptr+1); - sy[iy] = -127*iptr[0]; -#endif - q8y[iy] = (const int8_t *)(dptr + 2); - } - for (int ix = 0; ix < nrc_x; ++ix) { - auto dx = (const float *)((const char *)vx + ix*bx); - auto q8x = (const int8_t *)(dx + 2); - for (int i = 0; i < n/64; ++i) { - for (int j = 0; j < 2; ++j) { -#ifdef HAVE_FANCY_SIMD - qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127)); -#else - qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j); - sx[j] = _mm256_sign_epi8(qx[j], qx[j]); -#endif - } - for (int iy = 0; iy < nrc_y; ++iy) { - for (int j = 0; j < 2; ++j) { -#ifdef HAVE_FANCY_SIMD - acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); -#else - auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); - acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot)); -#endif - } - } - } - if (int i = 2*(n/64); i < n/32) { -#ifdef HAVE_FANCY_SIMD - qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127)); -#else - qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i); - sx[0] = _mm256_sign_epi8(qx[0], qx[0]); -#endif - for (int iy = 0; iy < nrc_y; ++iy) { -#ifdef HAVE_FANCY_SIMD - acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); -#else - auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); - acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot)); -#endif - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1])); -#ifdef HAVE_FANCY_SIMD - info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy])); -#else - info.store(ix, iy, dx[0]*dy[iy]*sumi); -#endif - acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256(); - } - } -} - -#ifdef HAVE_FANCY_SIMD -template -static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); - GGML_ASSERT(n%32 == 0); - __m512i qx[4]; - __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {}; - float dy[nrc_y]; - int32_t sy[nrc_y]; - const int8_t * q8y[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) { - auto dptr = (const float *)info.src1_row(iy); - dy[iy] = dptr[0]; - auto iptr = (const int32_t *)(dptr + 1); - sy[iy] = -64*iptr[0]; - q8y[iy] = (const int8_t *)(dptr + 2); - } - const int8_t * q8x[8]; - float dx[8]; - for (int ix = 0; ix < nrc_x; ix += 8) { - for (int kx = 0; kx < 8; ++kx) { - auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); - dx[kx] = dptr[0]; - q8x[kx] = (const int8_t *)(dptr + 2); - } - for (int i = 0; i < n/32; ++i) { - for (int kx = 0; kx < 4; ++kx) { - qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)), - _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1); - } - auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]); - auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]); - auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]); - auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]); - qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128)); - qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128)); - qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128)); - qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128)); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); - if constexpr (nrc_y <= 4) { - acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - } else { - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - } - } - } - auto scales_x = _mm256_loadu_ps(dx); - for (int iy = 0; iy < nrc_y; ++iy) { - if constexpr (nrc_y <= 4) { - auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy])); - auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1)); - auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3)); - auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); - info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); - info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); - acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); - } else { - acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy])); - auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1)); - auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3)); - auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); - info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); - info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); - acc[iy] = _mm512_setzero_si512(); - } - } - } -} -#endif - bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { (void)Ny; @@ -1281,6 +1096,7 @@ template struct HelperQ8KV final : public BaseHelper { using Base = BaseHelper; using block_q8 = block_q8_KV; + constexpr static ggml_type type = GGML_TYPE_Q8_KV; constexpr static int block_size_q = D; HelperQ8KV(const char * data, int stride) : Base(data, stride) {} @@ -1308,6 +1124,7 @@ struct HelperQ8KV final : public BaseHelper { template struct HelperQ80 final : public BaseHelper { using Base = BaseHelper; + constexpr static ggml_type type = GGML_TYPE_Q8_0; #ifdef HAVE_FANCY_SIMD using block_q8 = block_q8_2; constexpr static int block_size_q = QK8_2; @@ -1491,6 +1308,7 @@ namespace { template struct HelperQ80R8 : public BaseHelper { using Base = BaseHelper; + constexpr static ggml_type type = GGML_TYPE_Q8_0_R8; #ifdef __AVX2__ constexpr static int block_size_q = QK8_2; using block_q8 = block_q8_2; @@ -1613,6 +1431,7 @@ struct HelperQ80R8 : public BaseHelper { template struct HelperQ8KVR8 : public BaseHelper { using Base = BaseHelper; + constexpr static ggml_type type = GGML_TYPE_Q8_KV_R8; constexpr static int block_size_q = D; using block_q8 = block_q8_KV; @@ -1708,6 +1527,7 @@ struct HelperQ8KVR8 : public BaseHelper { template struct HelperQ40 final : public BaseHelper { using Base = BaseHelper; + constexpr static ggml_type type = GGML_TYPE_Q4_0; #if defined __AVX2__ using block_q8 = block_q8_2; constexpr static int block_size_q = QK8_2; @@ -1758,6 +1578,7 @@ template struct HelperQ41 final : public BaseHelper { using Base = BaseHelper; using block_q8 = block_q8_2; + constexpr static ggml_type type = GGML_TYPE_Q4_1; constexpr static int block_size_q = QK8_2; HelperQ41(const char * data, int stride) : Base(data, stride) {} @@ -1800,6 +1621,7 @@ struct HelperQ41 final : public BaseHelper { template struct HelperIQ4nl final : public BaseHelper { using Base = BaseHelper; + constexpr static ggml_type type = GGML_TYPE_IQ4_NL; #ifdef __aarch64__ using block_q8 = block_q8_0; HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {} @@ -1854,6 +1676,7 @@ struct HelperIQ4nl final : public BaseHelper { template struct HelperQ60 final : public BaseHelper { + constexpr static ggml_type type = GGML_TYPE_Q6_0; #ifdef __aarch64__ using block_q8 = block_q8_0; constexpr static int block_size_q = QK8_0; @@ -2409,28 +2232,13 @@ struct FlashQKfp32 { static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { #ifdef HAVE_FANCY_SIMD - constexpr int nrc_q = 8; constexpr int nrc_k = 8; #else - // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4 - constexpr int nrc_q = 4; constexpr int nrc_k = 8; #endif - constexpr int qrem = q_step - nrc_q*(q_step/nrc_q); - constexpr int krem = k_step - nrc_k*(k_step/nrc_k); - static_assert(krem == 0); + static_assert(k_step%nrc_k == 0); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; - for (int iq = 0; iq < q_step/nrc_q; ++iq) { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - info.cur_y += nrc_q; - } - if constexpr (qrem > 0) { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } + iqk_gemm_default_floats(D, q_step, kh.block, kh.stride, info, k_step); F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < q_step; ++j) { fms.update_M_S(j, vk, mask + stride_m*j); @@ -2473,64 +2281,10 @@ struct FlashQKfp32 { template static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { -#ifdef HAVE_FANCY_SIMD - constexpr int nrc_q = 8; constexpr int nrc_k = 8; -#else - // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4 - constexpr int nrc_q = 4; - constexpr int nrc_k = 8; -#endif static_assert(k_step%nrc_k == 0); - int qrem = nq - nrc_q*(nq/nrc_q); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; - for (int iq = 0; iq < nq/nrc_q; ++iq) { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - info.cur_y += nrc_q; - } - if (qrem > 0) { - switch (qrem) { - case 1: { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } break; - case 2: { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } break; - case 3: { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } break; -#ifdef HAVE_FANCY_SIMD - case 4: { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } break; - case 5: { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } break; - case 6: { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } break; - case 7: { - for (int ik = 0; ik < k_step/nrc_k; ++ik) { - mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); - } - } break; -#endif - } - } + iqk_gemm_default_floats(D, nq, kh.block, kh.stride, info, k_step); F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < nq; ++j) { fms.update_M_S(j, vk, mask + stride_m*j); @@ -2603,136 +2357,17 @@ struct FlashQKfp32 { } #endif - template - static inline std::pair mul_mat_kernel(int nq) { - constexpr int kMaxQ = 8; -#define MAKE_FUNCS(mul_mat, n) \ - if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\ - else {\ - switch (n) {\ - case 1: return std::make_pair(mul_mat, 1>, 1);\ - case 2: return std::make_pair(mul_mat, 2>, 2);\ - case 3: return std::make_pair(mul_mat, 3>, 3);\ - case 4: return std::make_pair(mul_mat, 4>, 4);\ - case 5: return std::make_pair(mul_mat, 5>, 5);\ - case 6: return std::make_pair(mul_mat, 6>, 6);\ - case 7: return std::make_pair(mul_mat, 7>, 7);\ - }\ - } -#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \ - if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ);\ - else {\ - switch (n) {\ - case 1: return std::make_pair(mul_mat<1>, 1);\ - case 2: return std::make_pair(mul_mat<2>, 2);\ - case 3: return std::make_pair(mul_mat<3>, 3);\ - case 4: return std::make_pair(mul_mat<4>, 4);\ - case 5: return std::make_pair(mul_mat<5>, 5);\ - case 6: return std::make_pair(mul_mat<6>, 6);\ - case 7: return std::make_pair(mul_mat<7>, 7);\ - }\ - } - if constexpr (std::is_same_v>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); - if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); - if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); - MAKE_FUNCS(mul_mat_qX_1_q8_2_T, 1); - if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 2); - if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 4); - MAKE_FUNCS(mul_mat_qX_0_q8_0_T>) { -#ifdef __aarch64__ - if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); - if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1); - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); -#else - if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); -#ifdef HAVE_FANCY_SIMD - if constexpr (D%32 == 0 && k_step%8 == 0) { - if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16); - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq); - } else { - if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); - } -#endif - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); -#endif - } - else if constexpr (std::is_same_v>) { -#ifdef __aarch64__ - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq); -#else - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_2, nq); -#endif - } - else if constexpr (std::is_same_v>) { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq); - } - else if constexpr (std::is_same_v>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); - if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); - if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); - MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); - if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); - if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); - MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_1_q8_1>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0(nullptr, 0); - } - template static inline void mul_mask_kq(const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS& fms) { constexpr int kMaxQ = 8; static_assert(q_step < kMaxQ || q_step%kMaxQ == 0); - auto [mul_mat, nrc_q] = mul_mat_kernel(q_step); DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; - for (int iq = 0; iq < q_step/nrc_q; ++iq) { - mul_mat(D, kh.block, kh.stride, info, k_step); - info.cur_y += nrc_q; + if constexpr (std::is_same_v> || + std::is_same_v>) { + iqk_gemm_q8kv_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step); + } else { + iqk_gemm_legacy_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step); } #ifdef __aarch64__ float32x4_t vk[k_step/4]; @@ -2750,17 +2385,13 @@ struct FlashQKfp32 { template static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS& fms) { - auto [mul_mat, nrc_q] = mul_mat_kernel(nq); + GGML_ASSERT(nq < q_step); DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; - for (int iq = 0; iq < nq/nrc_q; ++iq) { - mul_mat(D, kh.block, kh.stride, info, k_step); - info.cur_y += nrc_q; - } - int iq = nrc_q*(nq/nrc_q); - if (iq < nq) { - auto [mul_mat1, nrc_q1] = mul_mat_kernel(nq - iq); - GGML_ASSERT(nrc_q1 == nq - iq); - mul_mat1(D, kh.block, kh.stride, info, k_step); + if constexpr (std::is_same_v> || + std::is_same_v>) { + iqk_gemm_q8kv_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step); + } else { + iqk_gemm_legacy_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step); } #ifdef __aarch64__ float32x4_t vk[k_step/4];