diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3bfded73..77749466 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3666,6 +3666,65 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +// sum[ qy_i * ls_k * (qx_i - 1+/-delta_k)] +// = sum[qy_i * qx_i * ls_k] - 1/8*sum[qy_i * ls_k * (8+/-o_k)] +// = 1/8 * ( sum[qy_i * qx_i * 8*ls+k] - sum[qy_i * ls_k * (8+/-o_k)] ) + +template +static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + Q8 q8(info); + __m256i qx[8]; + __m256 acc[nrc_y] = {}; + auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000 + __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100); + for (int ix = 0; ix < nrc_x; ++ix) { + auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < n/QK_K; ++ibl) { + float d = GGML_FP16_TO_FP32(iq1s[ibl].d); + auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh); + auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7)); + scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1)); + auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask); + auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9)); + deltas128 = _mm_mullo_epi16(scales128, deltas128); + scales128 = _mm_slli_epi16(scales128, 3); + auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128); + auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128); + auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7 + auto scales = MM256_SET_M128I(scales128, scales128); + const uint8_t * qs = iq1s[ibl].qs; + const uint16_t * qh = iq1s[ibl].qh; + for (int ib = 0; ib < QK_K/32; ib += 2) { + qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto shuffle = shuffle0; + auto sumi = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + auto qy1 = q8.load_quants(iy, ibl, ib+0); + auto qy2 = q8.load_quants(iy, ibl, ib+1); + auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[ib+0], qy1); + auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[ib+1], qy2); + sumi = _mm256_dpwssd_epi32(sumi, _mm256_shuffle_epi8(scales, shuffle), _mm256_packs_epi32(dot1, dot2)); + shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)); + } + sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, 0.125f*hsum_float_8(acc[iy])); + acc[iy] = _mm256_setzero_ps(); + } + } +} + template static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -9473,6 +9532,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_q8_0_r8_q8_1<8>; expected_typeB = GGML_TYPE_Q8_1_X4; break; + case GGML_TYPE_IQ1_S: + mm.funcs[0] = mul_mat_iq1_s_q8_K<1>; + mm.funcs[1] = mul_mat_iq1_s_q8_K<2>; + mm.funcs[2] = mul_mat_iq1_s_q8_K<3>; + mm.funcs[3] = mul_mat_iq1_s_q8_K<4>; + mm.funcs[4] = mul_mat_iq1_s_q8_K<5>; + mm.funcs[5] = mul_mat_iq1_s_q8_K<6>; + mm.funcs[6] = mul_mat_iq1_s_q8_K<7>; + mm.funcs[7] = mul_mat_iq1_s_q8_K<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_iq1_s_q8_K<16>; +#endif + expected_typeB = GGML_TYPE_Q8_K; + break; case GGML_TYPE_IQ1_S_R4: assert (ne00 % QK4_NL == 0); mm.funcs[0] = mul_mat_iq1_s_r4_q8_1<1>;