diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp index 46f1820c..770fbf2c 100644 --- a/ggml/src/iqk/iqk_gemm_1bit.cpp +++ b/ggml/src/iqk/iqk_gemm_1bit.cpp @@ -1722,6 +1722,8 @@ void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int __m256i qx[8]; + auto mask = _mm256_setr_epi32(0x00000008, 0x00000008, 0x00000080, 0x00000080, 0x00080000, 0x00080000, 0x00800000, 0x00800000); + for (int ix = 0; ix < nrc_x; ix += 8) { for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx); for (int i = 0; i < nb; ++i) { @@ -1739,12 +1741,17 @@ void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | ((qh[1] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | ((qh[0] << 8) & 0x700)]); value = _mm256_slli_epi16(_mm256_add_epi8(value, _mm256_set1_epi8(1)), 3); - int64_t delta1 = qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707; - int64_t delta2 = qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707; - int64_t delta3 = qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707; - int64_t delta4 = qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707; - value = _mm256_sub_epi8(value, _mm256_set_epi64x(delta4, delta3, delta2, delta1)); - qx[ib32] = value; + + auto delta_mask = _mm256_cmpeq_epi32(_mm256_and_si256(_mm256_set1_epi32(qh[0] | qh[1] << 16), mask), mask); + auto delta = _mm256_add_epi8(_mm256_set1_epi8(7), _mm256_and_si256(delta_mask, _mm256_set1_epi8(2))); + qx[ib32] = _mm256_sub_epi8(value, delta); + + //int64_t delta1 = qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707; + //int64_t delta2 = qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707; + //int64_t delta3 = qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707; + //int64_t delta4 = qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707; + //value = _mm256_sub_epi8(value, _mm256_set_epi64x(delta4, delta3, delta2, delta1)); + //qx[ib32] = value; qs += 4; qh += 2; }