diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 44a5df4e..96d581cc 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -950,7 +950,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float               = quantize_row_q3_K,
         .from_float_ref           = (ggml_from_float_t) quantize_row_q3_K_ref,
         .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
+#ifdef __AVX2__
+        .vec_dot_type             = GGML_TYPE_Q8_2_X4,
+#else
         .vec_dot_type             = GGML_TYPE_Q8_K,
+#endif
         .nrows                    = 1,
         .row_meta_size            = 0,
     },
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index 8d949a4f..ecc17092 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -879,11 +879,59 @@ struct DequantizerQ6K_AVX2 final : public BaseDequantizer<block_q6_K> {
             us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
         }
     }
+    inline __m256i make_scales(int i) const {
+        return _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)x[i].scales));
+    }
 
     const __m256i mh = _mm256_set1_epi8(0x30);
 
     Q4Bits_AVX2 bits;
 };
 
+struct SimpleBits {
+    __m256i values[4];
+};
+
+struct DequantizerQ3K_AVX2 final : public BaseDequantizer<block_q3_K> {
+    DequantizerQ3K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+    inline void prepare(int i, int j) {
+        hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].hmask) : _mm256_srli_epi16(hbits, 4);
+        auto q2bits = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+        bits.values[0] = _mm256_and_si256(q2bits, ml);
+        bits.values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+        bits.values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+        bits.values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
+        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
+        //bits.values[0] = _mm256_sub_epi8(bits.values[0], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)));
+        //bits.values[1] = _mm256_sub_epi8(bits.values[1], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)));
+        //bits.values[2] = _mm256_sub_epi8(bits.values[2], _mm256_xor_si256(mh, _mm256_and_si256(hbits, mh)));
+        //bits.values[3] = _mm256_sub_epi8(bits.values[3], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)));
+    }
+    inline void prepare_signed(int i, int j, __m256i * us) {
+        prepare(i, j);
+        for (int k = 0; k < 4; ++k) {
+            bits.values[k] = _mm256_sub_epi8(bits.values[k], mh);
+            us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+        }
+        //for (int k = 0; k < 4; ++k) {
+        //    us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+        //}
+    }
+    inline __m256i make_scales(int i) const {
+        return _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x[i].scales));
+    }
+
+    ScaleQ3 sc3;
+
+    __m256i hbits;
+    SimpleBits bits;
+    const __m256i ml = _mm256_set1_epi8(3);
+    const __m256i mh = _mm256_set1_epi8(4);
+};
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
@@ -911,7 +959,7 @@ static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
             deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
             auto vd = _mm256_set1_ps(deq.d);
-            auto sc16 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)deq.x[i].scales)), shuff);
+            auto sc16 = _mm256_shuffle_epi8(deq.make_scales(i), shuff);
             scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
             scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
             for (int iy = 0; iy < nrc_y; ++iy) {
@@ -2188,6 +2236,118 @@ void iqk_convert_q6_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     }
 }
 
+//struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
+//    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+//
+//    template <typename Q8>
+//    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+//        d = GGML_FP16_TO_FP32(x[i].d);
+//        hbits.load(x[i].hmask);
+//        process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);
+//    }
+//    inline void prepare(int i, int j) {
+//        bits.prepare(x[i].qs, j);
+//        hbits.apply(bits, j == 0);
+//    }
+//
+//    Q2Bits bits;
+//    HighBit3 hbits;
+//    ScaleQ3 sc3;
+//
+//    const __m128i m32 = _mm_set1_epi8(-32);
+//};
+
+void iqk_convert_q3_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_q3_K * x8[8];
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    float all_s[64];
+    uint32_t block[8];
+    __m256i values[8];
+
+    ScaleQ3 sc3;
+    auto ml = _mm256_set1_epi8(0x03);
+    auto mh = _mm256_set1_epi8(0x04);
+
+    union { __m256i vec; int16_t val[16]; } helper;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                float d = GGML_FP16_TO_FP32(x8[k][i].d);
+                auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask);
+                for (int i128 = 0; i128 < 2; ++i128) {
+                    auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128);
+                    values[4*i128+0] = _mm256_and_si256(q2bits, ml);
+                    values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+                    values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+                    values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+                    values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+                    values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+                    values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh));
+                    values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
+                    values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh);
+                    values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh);
+                    values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh);
+                    values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh);
+                    hbits = _mm256_srli_epi16(hbits, 4);
+                }
+                helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales));
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32]));
+                    auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1));
+                    q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
+                    q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
+                    auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
+                    auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
+                    auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
+                    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
+                    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+                    auto max4  = _mm_cvtepi32_ps(imax4);
+                    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+                    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+                    float max = _mm_cvtss_f32(max4) / 127;
+                    all_s[8*ib32+k] = d*max;
+                    if (max > 1e-9f) {
+                        auto scale = _mm256_set1_ps(1/max);
+                        auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+                        auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+                        auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+                        auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+                        i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+                        i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+                        i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+                        i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+                        i0 = _mm256_packs_epi32(i0, i1);
+                        i2 = _mm256_packs_epi32(i2, i3);
+                        i0 = _mm256_packs_epi16(i0, i2);
+                        i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+                        _mm256_storeu_si256((__m256i *)block, i0);
+                    } else {
+                        _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
+                    }
+                    auto qs = (uint32_t *)y[ib32].qs;
+                    for (int l = 0; l < 4; ++l) {
+                        qs[8*l + k +  0] = block[l + 0];
+                        qs[8*l + k + 32] = block[l + 4];
+                    }
+                }
+            }
+            for (int ib32 = 0; ib32 < 8; ++ib32) {
+                _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
+            }
+            y += QK_K/32;
+        }
+    }
+}
+
 } // namespace
@@ -2197,7 +2357,8 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
             set_functions<DequantizerQ2K>(kernels);
             break;
         case GGML_TYPE_Q3_K:
-            set_functions<DequantizerQ3K>(kernels);
+            //set_functions<DequantizerQ3K>(kernels);
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ3K_AVX2, kernels);
             break;
         case GGML_TYPE_Q4_K:
             IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
             break;
@@ -2272,6 +2434,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
         case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+        case GGML_TYPE_Q3_K  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q4_K  : return nrc_y >= 32 ? GGML_TYPE_Q8_1   : type;
         case GGML_TYPE_Q5_K  : return nrc_y >= 32 ? GGML_TYPE_Q8_1   : type;
         case GGML_TYPE_Q6_K  : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -347,7 +348,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
         //case GGML_TYPE_BF16_R16:
         //    return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
         //case GGML_TYPE_Q2_K:
-        //case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
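For reference, the 3-bit unpacking that DequantizerQ3K_AVX2::prepare() and iqk_convert_q3_k_q8_0_r8() vectorize looks roughly as follows in scalar form (a sketch only; unpack_q3 is an illustrative name, not part of the patch). Each quant takes its two low bits from the qs plane and its third bit from the hmask bit plane; the subsequent subtraction of mh = 4 maps the result into [-4, 3]:

    #include <cstdint>

    // Scalar sketch of the q3_K unpacking done 32 values at a time by
    // DequantizerQ3K_AVX2::prepare()/prepare_signed(): 2 low bits from qs,
    // 1 high bit from hmask, then subtract 4 to get a value in [-4, 3].
    // Function name and argument layout are illustrative.
    static inline int8_t unpack_q3(const uint8_t * qs, const uint8_t * hmask,
                                   int l, int shift, uint8_t hbit) {
        int lo = (qs[l] >> shift) & 3;       // 2-bit field selected by 'shift'
        int hi = (hmask[l] & hbit) ? 4 : 0;  // 3rd bit from the high-bit plane
        return (int8_t)((lo | hi) - 4);
    }

prepare_signed() additionally keeps the absolute values via _mm256_sign_epi8(v, v), presumably so the dot-product kernel can feed them as the unsigned operand of a maddubs-style multiply-add.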
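Both make_scales() helpers and the converter rely on ScaleQ3 to unpack the 16 six-bit block scales that q3_K stores in its 12 scale bytes. Assuming the standard ggml q3_K scale packing (low 4 bits in bytes 0..7, top 2 bits in bytes 8..11, bias of 32), a scalar sketch with an illustrative name:

    #include <cstdint>

    // Scalar sketch of the q3_K scale decode that ScaleQ3::make_scales()
    // performs with SIMD shuffles: 16 six-bit scales packed into 12 bytes,
    // stored with a bias of 32 so the packed values are unsigned.
    static void decode_q3_K_scales(const uint8_t * packed, int8_t * sc) {
        for (int j = 0; j < 16; ++j) {
            int lo = j < 8 ? packed[j] & 0xF : packed[j - 8] >> 4;  // 4 low bits
            int hi = (packed[8 + (j % 4)] >> (2 * (j / 4))) & 3;    // 2 high bits
            sc[j] = (int8_t)((lo | (hi << 4)) - 32);
        }
    }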
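iqk_convert_q3_k_q8_0_r8() folds these block scales into the quants as int16 products (the _mm256_mullo_epi16 calls cannot overflow, since |q3| <= 4 and |scale| <= 32) and then re-quantizes each group of 32 values to q8_0, mapping the group maximum to 127 exactly as the max/127 and 1/max lines in the patch do. A scalar sketch of that per-group step (illustrative names):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdlib>

    // Scalar sketch of the per-32 re-quantization in iqk_convert_q3_k_q8_0_r8().
    // v[j] = block_scale * q3[j] (exact integers, held as int16 in the AVX2 code);
    // d_super is the fp16 super-block scale d of the q3_K block. Returns the new
    // q8_0 group scale that the patch stores into block_q8_0_r8::d.
    static float requantize_q8_0(const int16_t * v, float d_super, int8_t * q) {
        int amax = 0;
        for (int j = 0; j < 32; ++j) amax = std::max(amax, std::abs((int)v[j]));
        float max = amax / 127.f;                // same max/127 as the patch
        float id  = max > 1e-9f ? 1 / max : 0.f; // all-zero group -> zero quants
        for (int j = 0; j < 32; ++j) q[j] = (int8_t)std::nearbyint(id * v[j]);
        return d_super * max;
    }

The quants of the eight interleaved rows are then stored in 4-byte chunks (the qs[8*l + k] writes), which is what gives the _r8 layout its contiguous loads in the q8_0_r8 GEMM path.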