diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index 3cabf2b9..ff5090cc 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -719,6 +719,135 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
 
 #endif
 
+// inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
+//     make_q4_scales(data, utmp);
+//     const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+//     const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
+//     accum_mins(mins128, q8, i, c, accd);
+//     const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+//     return MM256_SET_M128I(sc128, sc128);
+// }
+//
+// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+//     d = GGML_FP16_TO_FP32(x[i].d);
+//     bits.prepare(x[i].qs);
+//     auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+//     scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
+//     scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+// }
+
+
+struct Q4Bits_AVX2 {
+    inline void prepare(const uint8_t * q4, int j) {
+        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+        values[0] = _mm256_and_si256(q4bits, ml);
+        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+        values[2] = _mm256_and_si256(q4bits, ml);
+        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+    }
+    __m256i values[4];
+    const __m256i ml = _mm256_set1_epi8(0xf);
+};
+
+struct DequantizerQ4K_AVX2 final : public BaseDequantizer<block_q4_K> {
+    DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+    template <typename Q8>
+    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+        d = GGML_FP16_TO_FP32(x[i].d);
+        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+    }
+    inline void prepare(int i, int j) {
+        bits.prepare(x[i].qs, j);
+    }
+
+    Q4Bits_AVX2 bits;
+    Scales8K s8k;
+};
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n % QK_K == 0);
+    const int nb = n / QK_K;
+
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+
+    Dequantizer deq(vx, bx);
+
+    uint32_t utmp[4];
+    __m256 accd[nrc_y];
+    __m256 scales[2];
+    float d8[8*nrc_y];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        deq.new_row(ix);
+
+        for (int i = 0; i < nb; ++i) {
+
+            deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
+            auto vm = _mm256_cvtph_ps(_mm_set1_epi16(deq.x[i].dmin));
+            make_q4_scales(deq.x[i].scales, utmp);
+            auto mins = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(utmp + 2)))));
+            mins = _mm256_mul_ps(_mm256_set1_ps(-1.f), mins);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
+                auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
+                auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
+                _mm256_storeu_ps(d8 + 8*iy, dy);
+                auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
+                auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
+                auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
+                accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
+            }
+
+            auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
+            scales[0] = _mm256_set_m128(_mm256_castps256_ps128(all_scales), _mm256_castps256_ps128(all_scales));
+            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+            scales[1] = _mm256_set_m128(scales_h, scales_h);
+
+            for (int j = 0; j < QK_K/128; ++j) {
+
+                deq.prepare(i, j);
+
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    const block_q8_2_x4& y = q8.y[iy][2*i+j];
+#ifdef z_HAVE_FANCY_SIMD
+                    auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
+                    auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
+                    auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
+                    auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
+                    sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+                    sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+                    sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+                    auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
+                    auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
+                    auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
+                    auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
+                    sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+                    sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+                    sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+                    sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
+#endif
+                    auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
+                    auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
+                    accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
+                }
+
+            }
+
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(accd[iy]));
+        }
+
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
@@ -1781,6 +1910,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array
             set_functions<DequantizerQ3K>(kernels);
             break;
         case GGML_TYPE_Q4_K:
-            set_functions<DequantizerQ4K>(kernels);
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
+            //set_functions<DequantizerQ4K>(kernels);
             break;
        case GGML_TYPE_Q5_K:
            set_functions<DequantizerQ5K>(kernels);
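
Side note on the scale decoding in mul_mat_qX_K_q8_2_X4_T above: the per-sub-block d values of a block_q8_2_x4 are loaded as 16-bit integers, zero-extended to 32 bits, shifted left by 16 and reinterpreted as floats, i.e. each stored 16-bit scale is treated as the upper half of an IEEE-754 fp32 (bf16-style storage). Below is a minimal scalar sketch of that conversion under exactly that assumption; the helper name scale16_to_f32 and the sample values are illustrative only and are not part of the patch.

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Scalar equivalent of the kernel's
    //   _mm256_castsi256_ps(_mm256_slli_epi32(<zero-extended 16-bit scales>, 16)):
    // place the stored 16 bits into the high half of a 32-bit word and
    // reinterpret that word as a float.
    static inline float scale16_to_f32(uint16_t h) {
        uint32_t bits = (uint32_t)h << 16;   // matches _mm256_slli_epi32(..., 16)
        float f;
        std::memcpy(&f, &bits, sizeof(f));   // matches _mm256_castsi256_ps
        return f;
    }

    int main() {
        // 0x3F80 is the upper half of 1.0f, 0x4048 the upper half of 3.125f.
        std::printf("%g %g\n", scale16_to_f32(0x3F80), scale16_to_f32(0x4048));
        return 0;
    }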