diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp index 166c6274..b20d74ec 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -376,14 +376,20 @@ struct DequantizerIQ2KL final : public BaseDequantizer inline void prepare(int i) { __m512i ql[2], qs[4]; __mmask64 mask[2]; - // TODO: optimize this - for (int k = 0; k < 2; ++k) { - auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0); - auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1); - auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1); - auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2); - ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1)); - } + //// TODO: optimize this + //for (int k = 0; k < 2; ++k) { + // auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0); + // auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1); + // auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1); + // auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2); + // ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1)); + //} + auto lbits = _mm512_loadu_si512((const __m512i *)x[i].qs); + ql[0] = _mm512_and_si512(lbits, m4); + ql[1] = _mm512_and_si512(_mm512_srli_epi16(lbits, 4), m4); + auto tmp = _mm512_permutex2var_epi64(ql[0], permute1, ql[1]); + ql[1] = _mm512_permutex2var_epi64(ql[0], permute2, ql[1]); + ql[0] = tmp; auto h128 = _mm_loadu_si128((const __m128i *)x[i].qh); auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 1), h128); auto h512 = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1); @@ -401,12 +407,16 @@ struct DequantizerIQ2KL final : public BaseDequantizer for (int l = 0; l < 4; ++l) qs[l] = _mm512_shuffle_epi8(values[l], ql[k]); auto q1 = _mm512_mask_blend_epi8(mask[k], qs[0], qs[1]); auto q2 = _mm512_mask_blend_epi8(mask[k], qs[2], qs[3]); - auto q1l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q1)); - auto q1h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q1, 1)); - auto q2l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q2)); - auto q2h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q2, 1)); - bits.values[2*k+0] = _mm512_or_si512(q1l, _mm512_slli_epi16(q2l, 8)); - bits.values[2*k+1] = _mm512_or_si512(q1h, _mm512_slli_epi16(q2h, 8)); + auto t1 = _mm512_unpacklo_epi8(q1, q2); // 0...15, 32...47, 64...79, 96...111 + auto t2 = _mm512_unpackhi_epi8(q1, q2); // 16...31, 48...63, 80...95, 112...127 + bits.values[2*k+0] = _mm512_permutex2var_epi64(t1, permute1, t2); + bits.values[2*k+1] = _mm512_permutex2var_epi64(t1, permute2, t2); + //auto q1l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q1)); + //auto q1h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q1, 1)); + //auto q2l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q2)); + //auto q2h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q2, 1)); + //bits.values[2*k+0] = _mm512_or_si512(q1l, _mm512_slli_epi16(q2l, 8)); + //bits.values[2*k+1] = _mm512_or_si512(q1h, _mm512_slli_epi16(q2h, 8)); } } void load_values() { @@ -426,6 +436,8 @@ struct DequantizerIQ2KL final : public BaseDequantizer const __m512i m01 = _mm512_set1_epi8(0x01); const __m512i m10 = _mm512_set1_epi8(0x10); const __m512i m4 = _mm512_set1_epi8(0xf); + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); __m512i values[4]; const __m512i shuffles[4] = { _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),