diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 554de583..defe8d10 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -7386,63 +7386,126 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
     Q8_16 q8(info);
     auto m3 = vdupq_n_u8(0x3);
     int nb = n / QK_IQ1BN;
-    int32x4_t acc[4*nrc_y] = {};
-    uint8x16_t qx[8];
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        const float * dptr = (const float *)((const char *)vx + ix*bx);
-        auto dl = vld1q_f32(dptr);
-        const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
-        for (int ib = 0; ib < nb; ++ib) {
-            auto bits = vld1q_u8_x2(iq2 + 64*ib);
-            qx[0] = vandq_u8(bits.val[0], m3);
-            qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
-            qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
-            qx[3] = vshrq_n_u8(bits.val[0], 6);
-            qx[4] = vandq_u8(bits.val[1], m3);
-            qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
-            qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
-            qx[7] = vshrq_n_u8(bits.val[1], 6);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto y = q8.load_quants_32(iy, 2*ib+0);
-                acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
-                acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
-                acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
-                acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
-                acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
-                acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
-                acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
-                acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
+    if constexpr (nrc_y == 1) {
+        auto mc = vdupq_n_u8(0xc);
+        int32x4_t acc[8];
+        uint8x16x4_t bits2;
+        //uint8x16x2_t bits2;
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            //acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
+            //acc[4] = acc[5] = acc[6] = acc[7] = vdupq_n_s32(0);
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = vld1q_f32(dptr);
+            const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+            auto bits1 = vld1q_u8_x4(iq2);
+            for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
+            auto y = q8.load_quants(0, 0);
+            for (int k = 0; k < 4; ++k) {
+                acc[2*k+0] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], m3), y.val[k], 0);
+                acc[2*k+1] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], mc), y.val[k], 1);
+                acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
+                acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
             }
-            bits = vld1q_u8_x2(iq2 + 64*ib + 32);
-            qx[0] = vandq_u8(bits.val[0], m3);
-            qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
-            qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
-            qx[3] = vshrq_n_u8(bits.val[0], 6);
-            qx[4] = vandq_u8(bits.val[1], m3);
-            qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
-            qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
-            qx[7] = vshrq_n_u8(bits.val[1], 6);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto y = q8.load_quants_32(iy, 2*ib+1);
-                acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
-                acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
-                acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
-                acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
-                acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
-                acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
-                acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
-                acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
+            for (int ib = 1; ib < nb; ++ib) {
+                bits1 = vld1q_u8_x4(iq2 + 64*ib);
+                for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
+                y = q8.load_quants(0, ib);
+                for (int k = 0; k < 4; ++k) {
+                    acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0);
+                    acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1);
+                    acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
+                    acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
+                }
+                //auto bits1 = vld1q_u8_x2(iq2 + 64*ib);
+                //for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
+                //auto y = q8.load_quants_32(0, 2*ib+0);
+                //for (int k = 0; k < 2; ++k) {
+                //    acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0);
+                //    acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1);
+                //    acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
+                //    acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
+                //}
+                //bits1 = vld1q_u8_x2(iq2 + 64*ib + 32);
+                //for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
+                //y = q8.load_quants_32(0, 2*ib+1);
+                //for (int k = 0; k < 2; ++k) {
+                //    acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits1.val[k], m3), y.val[k], 0);
+                //    acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits1.val[k], mc), y.val[k], 1);
+                //    acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits2.val[k], m3), y.val[k], 2);
+                //    acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits2.val[k], mc), y.val[k], 3);
+                //}
             }
+            auto dy = q8.scale(0);
+            float32x4_t s = vfmaq_f32(vcvtq_f32_s32(acc[0]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[1]));
+            float32x4_t sumf = vmulq_f32(s, vmulq_laneq_f32(dl, dy, 0));
+            s = vfmaq_f32(vcvtq_f32_s32(acc[2]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[3]));
+            sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 1));
+            s = vfmaq_f32(vcvtq_f32_s32(acc[4]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[5]));
+            sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 2));
+            s = vfmaq_f32(vcvtq_f32_s32(acc[6]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[7]));
+            sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 3));
+            sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
+            info.store(ix, 0, sumf);
         }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto dy = q8.scale(iy);
-            float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
-            sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
-            sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
-            sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
-            sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
-            info.store(ix, iy, sumf);
-            acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_f32(0.f);
+    } else {
+        int32x4_t acc[4*nrc_y] = {};
+        uint8x16_t qx[8];
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = vld1q_f32(dptr);
+            const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = vld1q_u8_x2(iq2 + 64*ib);
+                qx[0] = vandq_u8(bits.val[0], m3);
+                qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+                qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+                qx[3] = vshrq_n_u8(bits.val[0], 6);
+                qx[4] = vandq_u8(bits.val[1], m3);
+                qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+                qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+                qx[7] = vshrq_n_u8(bits.val[1], 6);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants_32(iy, 2*ib+0);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
+                }
+                bits = vld1q_u8_x2(iq2 + 64*ib + 32);
+                qx[0] = vandq_u8(bits.val[0], m3);
+                qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+                qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+                qx[3] = vshrq_n_u8(bits.val[0], 6);
+                qx[4] = vandq_u8(bits.val[1], m3);
+                qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+                qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+                qx[7] = vshrq_n_u8(bits.val[1], 6);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants_32(iy, 2*ib+1);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
+                sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
+                info.store(ix, iy, sumf);
+                acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
+            }
         }
     }
 }
@@ -7998,9 +8061,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
             m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
             m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
             m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
-            m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
-            m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
-            m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+            //m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+            //m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+            //m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
             expected_Btype = GGML_TYPE_Q8_K16;
             break;
         case GGML_TYPE_Q4_0:
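Note on the nrc_y == 1 fast path: the generic branch extracts each of the four 2-bit groups in a byte with shifts (three vshrq_n_u8 plus masks per 16-byte vector), while the new GEMV branch uses a single shift by 4 and masks the groups in place with 0x3 and 0xc. The 0xc-masked groups contribute dot products that are 4x too large, which is compensated exactly once at the end by the vdupq_n_f32(0.25f) factor in the vfmaq_f32 reduction. Below is a minimal standalone scalar sketch of that identity; the byte value and activations are made up for illustration and are not taken from the patch.

// Scalar illustration of the mask-instead-of-shift trick in the nrc_y == 1
// branch: acc0 collects groups masked with 0x3 (already the right magnitude),
// acc1 collects groups masked with 0xc (4x too large, fixed by 0.25f at the end).
#include <cstdint>
#include <cstdio>

int main() {
    uint8_t byte = 0x9e;               // packs q0=2, q1=3, q2=1, q3=2 (2 bits each)
    int8_t  y[4] = {5, -3, 7, 2};      // made-up activations

    // Reference: shift every 2-bit group down to bits 0-1, then mask.
    int ref = 0;
    for (int i = 0; i < 4; ++i) ref += ((byte >> 2*i) & 0x3) * y[i];

    // Patch's scheme: one shift by 4, masks 0x3/0xc, deferred 0.25 scale.
    uint8_t bits2 = byte >> 4;
    int acc0 = (byte  & 0x3) * y[0] + (bits2 & 0x3) * y[2];
    int acc1 = (byte  & 0xc) * y[1] + (bits2 & 0xc) * y[3];
    float fixed = acc0 + 0.25f * acc1;

    printf("ref = %d, masked = %.1f\n", ref, fixed);  // prints 12 and 12.0
    return 0;
}

In the NEON code the same split shows up as the acc[2*k+0]/acc[2*k+1] accumulator pairs, combined by s = vfmaq_f32(vcvtq_f32_s32(acc[2*k+0]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[2*k+1])) before the per-row scales are applied.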