diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a8ce5464..57d914ad 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -10643,6 +10643,69 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat } } +template +static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + int8x16_t qx[8]; + SignHelper sh; + uint32_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)); + auto qs = iq3[ibl].qs; + auto qh = iq3[ibl].qh; + auto scale_bits = vld1q_u8(iq3[ibl].scales); + uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) }; + auto tmp = vzip1q_u32(scales8.val[0], scales8.val[1]); + scales8.val[1] = vzip2q_u32(scales8.val[0], scales8.val[1]); + scales8.val[0] = vorrq_u8(vshlq_n_u8(tmp, 1), vdupq_n_u8(1)); + scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), vdupq_n_u8(1)); + vst1q_u8_x2((uint8_t *)stored_scales, scales8); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib); + sh.init(); + for (int i = 0; i < 4; ++i) { + qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[8*i+0] | ((qh[i] << 8) & 0x100)], iq3s_grid[qs[8*i+1] | ((qh[i] << 7) & 0x100)], + iq3s_grid[qs[8*i+2] | ((qh[i] << 6) & 0x100)], iq3s_grid[qs[8*i+3] | ((qh[i] << 5) & 0x100)]}); + sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128); + qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[8*i+4] | ((qh[i] << 4) & 0x100)], iq3s_grid[qs[8*i+5] | ((qh[i] << 3) & 0x100)], + iq3s_grid[qs[8*i+6] | ((qh[i] << 2) & 0x100)], iq3s_grid[qs[8*i+7] | ((qh[i] << 1) & 0x100)]}); + sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128); + } + auto sc16 = vmovl_s8(vreinterpret_s8_u32(vdup_n_u32(stored_scales[ib]))); + auto scales = vmovl_s16(vget_low_s16(sc16)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib); + //auto sumi = interleaved_dotq(qx, y); + auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]); + auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]); + auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]); + auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]); + auto sumi12 = vpaddq_s32(sumi1, sumi2); + auto sumi34 = vpaddq_s32(sumi3, sumi4); + auto sumi = vpaddq_s32(sumi12, sumi34); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qs += 32; + qh += 4; + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + template inline void iq3_4_add_shift(int ibl, const Q8& q8, const int8x16x4_t& i8scales, uint8x16_t extra, int32x4_t * isum) { @@ -11960,6 +12023,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>; expected_Btype = GGML_TYPE_Q8_K; break; + case GGML_TYPE_IQ3_S_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_s_r4_q8_k); + m.func16 = mul_mat_iq3_s_r4_q8_k<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; case GGML_TYPE_Q2_K_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K;