diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 57d914ad..8d49248b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4026,6 +4026,30 @@ struct IndexHelperIQ3S { }; #endif +//void dequantize_row_iq3_s_r4(const block_iq3_s_r4 * x, float * y, int64_t k) { +// auto n_per_row = k/4; +// float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; +// int nblock = n_per_row/QK_K; +// for (int ibl = 0; ibl < nblock; ++ibl) { +// for (int k = 0; k < 4; ++k) { +// const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); +// for (int ib = 0; ib < QK_K/32; ++ib) { +// int l = 4*ib + k; +// float dl = d * (1 + 2*((x[ibl].scales[l%16] >> 4*(l/16)) & 0xf)); +// for (int i = 0; i < 4; ++i) { +// auto grid1 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+0] + ((x[ibl].qh[4*ib+k] << (8-i)) & 0x100)); +// auto grid2 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+4] + ((x[ibl].qh[4*ib+k] << (4-i)) & 0x100)); +// for (int j = 0; j < 4; ++j) { +// y4[k][QK_K*ibl+32*ib+4*i+ 0+j] = dl * grid1[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+0)) ? -1 : 1); +// y4[k][QK_K*ibl+32*ib+4*i+16+j] = dl * grid2[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+4)) ? 
-1 : 1); +// } +// } +// } +// } +// } +//} + + template static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -4036,11 +4060,13 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); auto m4 = _mm256_set1_epi8(4); #endif + auto smask = _mm256_set1_epi8(1); union { __m256i vec; uint32_t val[8]; } helper; + union { __m128i vec; uint16_t val[8]; } hidx; __m256 acc[nrc_y] = {}; __m256i isum[nrc_y] = {}; - IndexHelperIQ3S ih; __m256i qx[4]; + __mmask32 mask[4]; for (int ix = 0; ix < nrc_x; ix += 4) { auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx); for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 @@ -4048,40 +4074,45 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI auto d4 = _mm256_set_m128(dl, dl); auto qs = iq3[ibl].qs; auto qh = iq3[ibl].qh; - auto sb1 = _mm_loadu_si128((const __m128i *)iq3[ibl].scales); - auto sb2 = _mm_srli_epi16(sb1, 4); - auto scales8 = MM256_SET_M128I(_mm_unpackhi_epi32(sb1, sb2), _mm_unpacklo_epi32(sb1, sb2)); + auto scale_bits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales); + auto scales8 = MM256_SET_M128I(_mm_srli_epi16(scale_bits, 4), scale_bits); helper.vec = _mm256_or_si256(_mm256_slli_epi16(_mm256_and_si256(scales8, _mm256_set1_epi8(0xf)), 1), _mm256_set1_epi8(1)); for (int ib = 0; ib < QK_K/32; ++ib) { - ih.make2(qs+ 0, qh+0, qx+0); - ih.make2(qs+16, qh+2, qx+2); + auto qh32 = (const uint32_t *)qh; + auto idx_h = _mm_sllv_epi64(_mm_cvtepu8_epi16(_mm_set1_epi32(qh32[0])), _mm_set_epi64x(4, 8)); + for (int i = 0; i < 4; ++i) { + auto idx_l = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)(qs + 8*i))); + hidx.vec = _mm_or_si128(idx_l, _mm_and_si128(idx_h, _mm_set1_epi16(0x100))); idx_h = _mm_srli_epi16(idx_h, 1); + qx[i] = 
_mm256_set_epi32(iq3s_grid[hidx.val[7]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[4]], + iq3s_grid[hidx.val[3]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[0]]); + } qs += 32; qh += 4; - auto sc16 = _mm_cvtepi8_epi16(_mm_set1_epi32(helper.val[ib])); - auto scales = MM256_SET_M128I(_mm_unpackhi_epi16(sc16, sc16), _mm_unpacklo_epi16(sc16, sc16)); + auto scales = _mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib])); + auto signs128 = _mm_loadu_si128((const __m128i*)iq3[ibl].signs + ib); + auto signs = MM256_SET_M128I(_mm_srli_epi16(signs128, 4), signs128); #ifdef HAVE_FANCY_SIMD - auto mask = (const __mmask32 *)(iq3[ibl].signs + 16*ib); + mask[0] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1); + mask[1] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1); + mask[2] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1); + mask[3] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); for (int iy = 0; iy < nrc_y; ++iy) { auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); - auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); - auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); - auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); - auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); - auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi2), _mm256_unpackhi_epi64(sumi1, sumi2)); // 0,0, 1,1, 0,0, 1,1 - auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi3, sumi4), _mm256_unpackhi_epi64(sumi3, sumi4)); // 2,2, 3,3, 2,2, 3,3 - //auto x1234 = 
_mm256_packs_epi32(x12, x34); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 - isum[iy] = _mm256_dpwssd_epi32(isum[iy], scales, _mm256_packs_epi32(s12, s34)); + auto sumi = _mm256_setzero_si256(); + auto ys = _mm256_shuffle_epi32(y, 0x00); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_mask_sub_epi8(ys, mask[0], _mm256_setzero_si256(), ys)); + ys = _mm256_shuffle_epi32(y, 0x55); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_mask_sub_epi8(ys, mask[1], _mm256_setzero_si256(), ys)); + ys = _mm256_shuffle_epi32(y, 0xaa); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_mask_sub_epi8(ys, mask[2], _mm256_setzero_si256(), ys)); + ys = _mm256_shuffle_epi32(y, 0xff); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_mask_sub_epi8(ys, mask[3], _mm256_setzero_si256(), ys)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales)); } #else - auto signs128 = _mm_loadu_si128((const __m128i*)iq3[ibl].signs + ib); - auto signs = MM256_SET_M128I(signs128, signs128); - auto shuffle = sign_shuffle; - auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); - shuffle = _mm256_add_epi8(shuffle, m4); - auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); - shuffle = _mm256_add_epi8(shuffle, m4); - auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); - shuffle = _mm256_add_epi8(shuffle, m4); - auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1); + auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1); + auto s3 = 
_mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1); + auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); for (int iy = 0; iy < nrc_y; ++iy) { auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)); // 16x0 diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 32aca573..5d7ef7b1 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -5716,18 +5716,27 @@ static void repack_iq3_s(int nrows, int n_per_row, const block_iq3_s * x, block_ for (int row = 0; row < nrows; row += 4) { for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].scales, 0, QK_K/16); + std::memset(y[ibl].signs, 0, QK_K/2); + std::memset(y[ibl].qh, 0, QK_K/8); for (int k = 0; k < 4; ++k) { y[ibl].d[k] = x4[k][ibl].d; for (int ib = 0; ib < QK_K/64; ++ib) { - y[ibl].scales[4*ib+k] = x4[k][ibl].scales[ib]; + int j = 8*ib + k; + y[ibl].scales[(j+0)%16] |= ((x4[k][ibl].scales[ib] & 0xf) << 4*((j+0)/16)); + y[ibl].scales[(j+4)%16] |= ((x4[k][ibl].scales[ib] >> 4) << 4*((j+4)/16)); } for (int ib = 0; ib < QK_K/32; ++ib) { - y[ibl].qh[4*ib+k] = x4[k][ibl].qh[ib]; - for (int i = 0; i < 8; ++i) { - y[ibl].qs[32*ib+8*k+i] = x4[k][ibl].qs[8*ib+i]; + y[ibl].qh[4*ib+k] = x4[k][ibl].qh[ib]; // leave it like this? 
+ for (int i = 0; i < 4; ++i) { + y[ibl].qs[32*ib+k+8*i+0] = x4[k][ibl].qs[8*ib+i+0]; + y[ibl].qs[32*ib+k+8*i+4] = x4[k][ibl].qs[8*ib+i+4]; } for (int i = 0; i < 4; ++i) { - y[ibl].signs[16*ib+4*k+i] = x4[k][ibl].signs[4*ib+i]; + y[ibl].signs[16*ib+4*k+i] = (((x4[k][ibl].signs[4*ib+0] >> i) & 1) << 0) | (((x4[k][ibl].signs[4*ib+0] >> (4+i)) & 1) << 1) | + (((x4[k][ibl].signs[4*ib+1] >> i) & 1) << 2) | (((x4[k][ibl].signs[4*ib+1] >> (4+i)) & 1) << 3) | + (((x4[k][ibl].signs[4*ib+2] >> i) & 1) << 4) | (((x4[k][ibl].signs[4*ib+2] >> (4+i)) & 1) << 5) | + (((x4[k][ibl].signs[4*ib+3] >> i) & 1) << 6) | (((x4[k][ibl].signs[4*ib+3] >> (4+i)) & 1) << 7); } } } @@ -5759,26 +5768,15 @@ void dequantize_row_iq3_s_r4(const block_iq3_s_r4 * x, float * y, int64_t k) { for (int ibl = 0; ibl < nblock; ++ibl) { for (int k = 0; k < 4; ++k) { const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); - const uint8_t * qs = x[ibl].qs; - const uint8_t * qh = x[ibl].qh; - const uint8_t * signs = x[ibl].signs; - for (int ib = 0; ib < QK_K/64; ++ib) { - const float db1 = d * (1 + 2*(x[ibl].scales[4*ib+k] & 0xf)); - const float db2 = d * (1 + 2*(x[ibl].scales[4*ib+k] >> 4)); + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = 4*ib + k; + float dl = d * (1 + 2*((x[ibl].scales[l%16] >> 4*(l/16)) & 0xf)); for (int i = 0; i < 4; ++i) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[64*ib+8*k+2*i+0] | ((qh[8*ib+k+0] << (8-2*i)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[64*ib+8*k+2*i+1] | ((qh[8*ib+k+0] << (7-2*i)) & 256))); + auto grid1 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+0] + ((x[ibl].qh[4*ib+k] << (8-i)) & 0x100)); + auto grid2 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+4] + ((x[ibl].qh[4*ib+k] << (4-i)) & 0x100)); for (int j = 0; j < 4; ++j) { - y4[k][QK_K*ibl+64*ib+j+8*i+0] = db1 * grid1[j] * (signs[32*ib+4*k+i] & kmask_iq2xs[j+0] ? 
-1.f : 1.f); - y4[k][QK_K*ibl+64*ib+j+8*i+4] = db1 * grid2[j] * (signs[32*ib+4*k+i] & kmask_iq2xs[j+4] ? -1.f : 1.f); - } - } - for (int i = 0; i < 4; ++i) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[64*ib+8*k+2*i+32] | ((qh[8*ib+k+4] << (8-2*i)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[64*ib+8*k+2*i+33] | ((qh[8*ib+k+4] << (7-2*i)) & 256))); - for (int j = 0; j < 4; ++j) { - y4[k][QK_K*ibl+64*ib+j+8*i+32] = db2 * grid1[j] * (signs[32*ib+4*k+i+16] & kmask_iq2xs[j+0] ? -1.f : 1.f); - y4[k][QK_K*ibl+64*ib+j+8*i+36] = db2 * grid2[j] * (signs[32*ib+4*k+i+16] & kmask_iq2xs[j+4] ? -1.f : 1.f); + y4[k][QK_K*ibl+32*ib+4*i+ 0+j] = dl * grid1[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+0)) ? -1 : 1); + y4[k][QK_K*ibl+32*ib+4*i+16+j] = dl * grid2[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+4)) ? -1 : 1); } } }