diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index e5bf4967..38e76e1e 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -397,6 +397,10 @@ struct Trellis1 { auto x16 = gen8(val1, val2); return { vcvt_f32_f16(vget_low_f16(x16)), vcvt_f32_f16(vget_high_f16(x16)) }; } + /* Overload that multiplies the gen8() result by `scale` in fp16 before widening the low/high halves to f32. */ inline float32x4x2_t gen8_f32(uint32_t val1, uint32_t val2, float16x8_t scale) const { + auto x16 = vmulq_f16(gen8(val1, val2), scale); + return { vcvt_f32_f16(vget_low_f16(x16)), vcvt_f32_f16(vget_high_f16(x16)) }; + } }; template @@ -640,10 +644,10 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn const uint8_t * qh = ql + kNumGroups; for (int ib = 0; ib < 4; ++ib) { - const float x_scale1 = (int)((shb[ib+0] & 0xff) >> 1) - 64; - const float x_scale2 = (int)((shb[ib+4] & 0xff) >> 1) - 64; - const float32x4_t scale1 = vdupq_n_f32(x_scale1); - const float32x4_t scale2 = vdupq_n_f32(x_scale2); + /* Signed on purpose: the decoded scale lies in [-64, 63] and is fed to vdupq_n_s16. */ const int16_t x_scale1 = (int16_t)((shb[ib+0] & 0xff) >> 1) - 64; + const int16_t x_scale2 = (int16_t)((shb[ib+4] & 0xff) >> 1) - 64; + const float16x8_t scale1 = vcvtq_f16_s16(vdupq_n_s16(x_scale1)); + const float16x8_t scale2 = vcvtq_f16_s16(vdupq_n_s16(x_scale2)); const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15); const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15); @@ -660,12 +664,8 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn sh1 >>= 6; sh2 >>= 6; - auto x1 = trellis.gen8_f32(val1, val3); - auto x2 = trellis.gen8_f32(val2, val4); - x1.val[0] = vmulq_f32(scale1, x1.val[0]); - x1.val[1] = vmulq_f32(scale1, x1.val[1]); - x2.val[0] = vmulq_f32(scale2, x2.val[0]); - x2.val[1] = vmulq_f32(scale2, x2.val[1]); + auto x1 = trellis.gen8_f32(val1, val3, scale1); + auto x2 = trellis.gen8_f32(val2, val4, scale2); for (int iy = 0; iy < nrc_y; ++iy) { auto y1 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*j);