diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index defe8d10..74fdd9ce 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7387,63 +7387,86 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da auto m3 = vdupq_n_u8(0x3); int nb = n / QK_IQ1BN; if constexpr (nrc_y == 1) { + //uint8x16x4_t shuff = { + // vreinterpretq_u8_u32(vdupq_n_u32(0x03020100)), + // vreinterpretq_u8_u32(vdupq_n_u32(0x07060504)), + // vreinterpretq_u8_u32(vdupq_n_u32(0x0b0a0908)), + // vreinterpretq_u8_u32(vdupq_n_u32(0x0f0e0d0c)), + //}; auto mc = vdupq_n_u8(0xc); int32x4_t acc[8]; - uint8x16x4_t bits2; - //uint8x16x2_t bits2; for (int ix = 0; ix < nrc_x; ix += 4) { + for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0); //acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0); //acc[4] = acc[5] = acc[6] = acc[7] = vdupq_n_s32(0); const float * dptr = (const float *)((const char *)vx + ix*bx); auto dl = vld1q_f32(dptr); const uint8_t * iq2 = (const uint8_t *)(dptr + 4); - auto bits1 = vld1q_u8_x4(iq2); - for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4); - auto y = q8.load_quants(0, 0); - for (int k = 0; k < 4; ++k) { - acc[2*k+0] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], m3), y.val[k], 0); - acc[2*k+1] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], mc), y.val[k], 1); - acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2); - acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3); + for (int ib = 0; ib < nb; ++ib) { + auto y = q8.load_quants(0, ib); + auto bits1 = vld1q_u8(iq2 + 64*ib); + auto bits2 = vshrq_n_u8(bits1, 4); + acc[0] = vdotq_laneq_s32(acc[0], vandq_u8(bits1, m3), y.val[0], 0); + acc[1] = vdotq_laneq_s32(acc[1], vandq_u8(bits1, mc), y.val[0], 1); + acc[0] = vdotq_laneq_s32(acc[0], vandq_u8(bits2, m3), y.val[0], 2); + acc[1] = vdotq_laneq_s32(acc[1], vandq_u8(bits2, mc), y.val[0], 3); + //acc[0] = vdotq_s32(acc[0], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[0], shuff.val[0])); + //acc[1] = vdotq_s32(acc[1], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[0], shuff.val[1])); + //acc[0] = vdotq_s32(acc[0], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[0], shuff.val[2])); + //acc[1] = vdotq_s32(acc[1], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[0], shuff.val[3])); + bits1 = vld1q_u8(iq2 + 64*ib + 16); + bits2 = vshrq_n_u8(bits1, 4); + acc[2] = vdotq_laneq_s32(acc[2], vandq_u8(bits1, m3), y.val[1], 0); + acc[3] = vdotq_laneq_s32(acc[3], vandq_u8(bits1, mc), y.val[1], 1); + acc[2] = vdotq_laneq_s32(acc[2], vandq_u8(bits2, m3), y.val[1], 2); + acc[3] = vdotq_laneq_s32(acc[3], vandq_u8(bits2, mc), y.val[1], 3); + //acc[2] = vdotq_s32(acc[2], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[1], shuff.val[0])); + //acc[3] = vdotq_s32(acc[3], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[1], shuff.val[1])); + //acc[2] = vdotq_s32(acc[2], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[1], shuff.val[2])); + //acc[3] = vdotq_s32(acc[3], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[1], shuff.val[3])); + bits1 = vld1q_u8(iq2 + 64*ib + 32); + bits2 = vshrq_n_u8(bits1, 4); + acc[4] = vdotq_laneq_s32(acc[4], vandq_u8(bits1, m3), y.val[2], 0); + acc[5] = vdotq_laneq_s32(acc[5], vandq_u8(bits1, mc), y.val[2], 1); + acc[4] = vdotq_laneq_s32(acc[4], vandq_u8(bits2, m3), y.val[2], 2); + acc[5] = vdotq_laneq_s32(acc[5], vandq_u8(bits2, mc), y.val[2], 3); + //acc[4] = vdotq_s32(acc[4], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[2], shuff.val[0])); + //acc[5] = vdotq_s32(acc[5], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[2], shuff.val[1])); + //acc[4] = vdotq_s32(acc[4], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[2], shuff.val[2])); + //acc[5] = vdotq_s32(acc[5], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[2], shuff.val[3])); + bits1 = vld1q_u8(iq2 + 64*ib + 48); + bits2 = vshrq_n_u8(bits1, 4); + acc[6] = vdotq_laneq_s32(acc[6], vandq_u8(bits1, m3), y.val[3], 0); + acc[7] = vdotq_laneq_s32(acc[7], vandq_u8(bits1, mc), y.val[3], 1); + acc[6] = vdotq_laneq_s32(acc[6], vandq_u8(bits2, m3), y.val[3], 2); + acc[7] = vdotq_laneq_s32(acc[7], vandq_u8(bits2, mc), y.val[3], 3); + //acc[6] = vdotq_s32(acc[6], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[3], shuff.val[0])); + //acc[7] = vdotq_s32(acc[7], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[3], shuff.val[1])); + //acc[6] = vdotq_s32(acc[6], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[3], shuff.val[2])); + //acc[7] = vdotq_s32(acc[7], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[3], shuff.val[3])); } - for (int ib = 1; ib < nb; ++ib) { - bits1 = vld1q_u8_x4(iq2 + 64*ib); - for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4); - y = q8.load_quants(0, ib); - for (int k = 0; k < 4; ++k) { - acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0); - acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1); - acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2); - acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3); - } - //auto bits1 = vld1q_u8_x2(iq2 + 64*ib); - //for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4); - //auto y = q8.load_quants_32(0, 2*ib+0); - //for (int k = 0; k < 2; ++k) { - // acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0); - // acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1); - // acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2); - // acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3); - //} - //bits1 = vld1q_u8_x2(iq2 + 64*ib + 32); - //for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4); - //y = q8.load_quants_32(0, 2*ib+1); - //for (int k = 0; k < 2; ++k) { - // acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits1.val[k], m3), y.val[k], 0); - // acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits1.val[k], mc), y.val[k], 1); - // acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits2.val[k], m3), y.val[k], 2); - // acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits2.val[k], mc), y.val[k], 3); - //} - } - auto dy = q8.scale(0); - float32x4_t s = vfmaq_f32(vcvtq_f32_s32(acc[0]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[1])); - float32x4_t sumf = vmulq_f32(s, vmulq_laneq_f32(dl, dy, 0)); - s = vfmaq_f32(vcvtq_f32_s32(acc[2]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[3])); - sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 1)); - s = vfmaq_f32(vcvtq_f32_s32(acc[4]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[5])); - sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 2)); - s = vfmaq_f32(vcvtq_f32_s32(acc[6]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[7])); - sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 3)); + //auto dy = q8.scale(0); + //auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), vmulq_laneq_f32(dl, dy, 0)); + //auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), vmulq_laneq_f32(dl, dy, 0)); + //sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), vmulq_laneq_f32(dl, dy, 1)); + //sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), vmulq_laneq_f32(dl, dy, 1)); + //sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), vmulq_laneq_f32(dl, dy, 2)); + //sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), vmulq_laneq_f32(dl, dy, 2)); + //sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), vmulq_laneq_f32(dl, dy, 3)); + //sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), vmulq_laneq_f32(dl, dy, 3)); + auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0))); + auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy); + auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy); + dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1))); + sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy); + sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy); + dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2))); + sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy); + sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy); + dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3))); + sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy); + sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy); + auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2); sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0))); info.store(ix, 0, sumf); }