diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 0a17fe83..d4547d4a 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1353,6 +1353,9 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& return vpaddq_s32(dot.val[0], dot.val[2]); }; + //int32x4x2_t shifts = {int32x4_t{-8, -11, -14, -17}, int32x4_t{-20, -23, -26, -29}}; + int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}}; + float32x4x2_t scales; for (int ix = 0; ix < nrc_x; ++ix) { @@ -1376,14 +1379,33 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); for (int ib = 0; ib < 4; ++ib) { - for (int j = 0; j < 4; ++j) { - const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - } + auto vql1 = vmovl_u8(vld1_u8(ql+8*ib)); + auto vql2 = vmovl_u8(vld1_u8(ql+8*ib+32)); + auto vqh = vmovl_u8(vld1_u8(qh+8*ib)); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 8))); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 4))); + auto sh1_u32 = vdupq_n_u32(shb[ib+0]); + auto sh2_u32 = vdupq_n_u32(shb[ib+4]); + auto sh1 = vcombine_u16(vmovn_u32(vshlq_u32(sh1_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh1_u32, shifts.val[1]))); + auto sh2 = vcombine_u16(vmovn_u32(vshlq_u32(sh2_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh2_u32, shifts.val[1]))); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0x7000), sh1)); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0x7000), sh2)); + auto oh1 = vdupq_n_u32(o_helper.val[ib+0]); + auto oh2 = vdupq_n_u32(o_helper.val[ib+4]); + vst1q_u32(values +0, vaddq_u32(vmovl_u16(vget_low_u16 (vql1)), oh1)); + vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1)); + vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2)); + vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2)); + //auto sh1 = vshlq_u32(vdupq_n_u32(shb[ib+0]), shifts); + //auto sh2 = vshlq_u32(vdupq_n_u32(shb[ib+4]), shifts); + //for (int j = 0; j < 4; ++j) { + // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + // values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + // values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + // values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + // values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + //} xv[ib+0] = trellis.next32(values+0); xv[ib+4] = trellis.next32(values+8); }