diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 162b55e6..3b598718 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -2782,6 +2782,74 @@ void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+typedef struct {
+    ggml_half d[16];
+    int8_t qs[256];
+} block_q8_1_r8;
+
+template <int nrc_y>
+void mul_mat_q8_1_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_1_x4> q8(info);
+    int nb = n / QK8_0;
+    float32x4_t acc[2*nrc_y] = {};
+    int8x16_t qx[16];
+    float d8[8*nrc_y];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+0)));
+                vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+4)));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+                auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+                auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+                auto m16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d+8);
+                auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+                auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+                for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+                int32x4_t sumi1, sumi2;
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
+                    auto dy = vdupq_n_f32(d8[8*iy+k]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                    auto my = vdupq_n_f32(d8[8*iy+k+4]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+            auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+            auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+            auto m16 = vld1q_f16((const float16_t *)iq8[ib].d+8);
+            auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+            auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+            for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+            int32x4_t sumi1, sumi2;
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                auto my = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].s));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix+0, iy, acc[2*iy+0]);
+            info.store(ix+4, iy, acc[2*iy+1]);
+            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
 struct DeqQ40 {
     const int8x16_t m8 = vdupq_n_s8(-8);
     const uint8x16_t ml = vdupq_n_s8(0xf);
@@ -2791,6 +2859,14 @@ struct DeqQ40 {
     }
 };
 
+struct DeqQ41 {
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_q4_1& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vreinterpretq_s8_u8(vandq_u8(bits, ml)), vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)) };
+    }
+};
+
 struct DeqIQ4NL {
     const int8x16_t mt = load_values();
    const uint8x16_t ml = vdupq_n_s8(0xf);
@@ -2874,12 +2950,47 @@ void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc
     }
 }
 
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK8_0;
+
+    block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
+
+    const Block * x8[8];
+
+    uint32_t block[8];
+
+    Dequantizer deq;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k+0] = x8[k][i].d;
+                y[i].d[k+8] = x8[k][i].m;
+                vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k + 0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
 }
 
 bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     switch (type) {
         case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, DeqQ40>(n, vx, bx, vy, nrc_x); break;
-        // case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, DeqQ41>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, DeqQ41>(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, DeqQ50>(n, vx, bx, vy, nrc_x); break;
         // case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, DeqQ51>(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, DeqQ60>(n, vx, bx, vy, nrc_x); break;
@@ -2895,7 +3006,7 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array
...
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ ... @@
         case GGML_TYPE_Q4_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+        case GGML_TYPE_Q4_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
         case GGML_TYPE_Q5_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
@@ -918,6 +919,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_Q5_0_R4:
         case GGML_TYPE_Q6_0_R4:
         case GGML_TYPE_Q8_0_R8:
+        case GGML_TYPE_Q8_1:
         case GGML_TYPE_IQ4_NL_R4:
             return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ1_BN:
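On the kernel math: a `block_q8_1` activation block stores a scale `d` and `s = d * sum(qs)`, while the repacked Q4_1 rows keep their affine form `x ≈ d*q + m`. The block dot product therefore splits into `d_x*d_y*sum(qx*qy) + m_x*s_y`, which is exactly the pair of `vfmaq_f32` updates per accumulator in `mul_mat_q8_1_r8_q8_1`. A minimal scalar model of one 32-element block, using hypothetical `BlockW81Ref`/`BlockQ81Ref` types for illustration (not ggml's structs):

```cpp
#include <cstdint>

// Hypothetical reference types, for illustration only.
struct BlockQ81Ref {      // activation block
    float  d;             // scale
    float  s;             // d * sum(qs), precomputed at quantization time
    int8_t qs[32];
};

struct BlockW81Ref {      // repacked weight block (from Q4_1: x ~ d*q + m)
    float  d;             // scale
    float  m;             // offset (the Q4_1 min)
    int8_t qs[32];
};

// dot(x, y) = sum_i (x.d*qx[i] + x.m) * (y.d*qy[i])
//           = x.d*y.d * sum_i qx[i]*qy[i]  +  x.m * (y.d * sum_i qy[i])
//           = x.d*y.d * sumi               +  x.m * y.s
// The kernel does the same two steps vectorized:
//   acc += (scales * dy) * sumi   and   acc += m * my.
inline float block_dot_ref(const BlockW81Ref& x, const BlockQ81Ref& y) {
    int sumi = 0;
    for (int i = 0; i < 32; ++i) sumi += int(x.qs[i]) * int(y.qs[i]);
    return x.d * y.d * float(sumi) + x.m * y.s;
}
```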
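On the repacked layout: `block_q8_1_r8` interleaves eight rows of one block column, with `d[0..7]` holding the eight row scales, `d[8..15]` the eight offsets, and `qs` holding the 8×32 quants grouped in 32-bit words so that each 16-byte `vld1q_s8` in the GEMM picks up four rows' bytes for the same four activation positions, the shape the `qx_0_q8_0_dot` helper consumes. A scalar sketch of the interleave performed by `iqk_convert_qX_1_q8_1_r8` (illustrative helper, assuming the eight rows are already dequantized to int8):

```cpp
#include <cstdint>
#include <cstring>

// Scalar model of the qs interleave in iqk_convert_qX_1_q8_1_r8.
// rows: 8 rows x 32 int8 quants of one block; qs: the 256-byte r8 buffer.
// Viewed as 32-bit words, word 8*l + k (l = 0..3) holds bytes 4*l..4*l+3 of
// row k; words 32..63 repeat the pattern for bytes 16..31 of each row.
inline void interleave_r8_ref(const int8_t rows[8][32], int8_t qs[256]) {
    uint32_t block[8];
    auto dst = reinterpret_cast<uint32_t *>(qs);
    for (int k = 0; k < 8; ++k) {             // row index within the group of 8
        std::memcpy(block, rows[k], 32);      // 32 quants = 8 32-bit words
        for (int l = 0; l < 4; ++l) {
            dst[8*l + k +  0] = block[l + 0]; // first 16 bytes of the row
            dst[8*l + k + 32] = block[l + 4]; // last 16 bytes of the row
        }
    }
}
```

With this layout, `vld1q_s8(qs + 16*j)` for `j = 0` yields bytes 0..3 of rows 0..3, `j = 1` yields bytes 0..3 of rows 4..7, and so on, so one activation word can be broadcast against four rows at a time.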