diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
index 152b6247..782e48d8 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -1884,6 +1884,315 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
 };
 
+template <int nrc_y>
+static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    float32x4_t acc[nrc_y] = {};
+    int32x4_t isum[nrc_y] = {};
+    int8x16_t qx[8];
+    SignHelper sh;
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+            auto qs = iq2[ibl].qs;
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto sas = vld1q_u8(iq2[ibl].sas + 16*ib);
+                auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
+                auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
+                auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
+                signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
+                sh.init();
+                for (int i = 0; i < 8; ++i) {
+                    qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[qs[2*i+0]], iq2xxs_grid[qs[2*i+1]]});
+                    sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
+                    auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
+                    auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
+                    auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
+                    auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
+                    auto sumi12 = vpaddq_s32(sumi1, sumi2);
+                    auto sumi34 = vpaddq_s32(sumi3, sumi4);
+                    auto sumi = vpaddq_s32(sumi12, sumi34);
+                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+                }
+                qs += 16;
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+                isum[iy] = vdupq_n_s32(0);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
+            acc[iy] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    static const uint8_t k_shuff[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
+    auto shuff = vld1q_u8(k_shuff);
+    float32x4_t acc[nrc_y] = {};
+    int32x4_t isum[2*nrc_y] = {};
+    int8x16_t qx[8];
+    uint16x8x4_t scales16;
+    SignHelper sh;
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+            auto qs = iq2[ibl].qs;
+            for (int is = 0; is < 2; ++is) {
+                auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
+                auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
+                auto scales2 = vshrq_n_u8(scale_bits, 4);
+                scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
+                scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
+                auto s1 = vzip1q_u8(scales1, scales2);
+                auto s2 = vzip2q_u8(scales1, scales2);
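+                // scales were decoded as 2*s+1 above; widen them to 16 bits so that
+                // scales16.val[ib] supplies the per-sub-block scales of the four interleaved rows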
+                scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
+                scales16.val[1] = vmovl_u8(vget_high_u8(s1));
+                scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
+                scales16.val[3] = vmovl_u8(vget_high_u8(s2));
+                for (int ib = 0; ib < QK_K/64; ++ib) {
+                    auto v = vld1q_u8_x2((const uint8_t *)qs);
+                    auto signs128 = vandq_u8(vqtbl2q_u8(v, shuff), vdupq_n_u8(254));
+                    signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
+                    sh.init();
+                    for (int i = 0; i < 8; ++i) {
+                        qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[qs[2*i+0] & 511], iq2xs_grid[qs[2*i+1] & 511]});
+                        sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
+                    }
+                    auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
+                    auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
+                        auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
+                        auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
+                        auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
+                        auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
+                        auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
+                        auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
+                        isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
+                        isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
+                    }
+                    qs += 16;
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
+                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
+                isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
+            acc[iy] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    float32x4_t acc[nrc_y] = {};
+    int32x4_t isum[2*nrc_y] = {};
+    int8x16_t qx[8];
+    uint16x8x4_t scales16;
+    SignHelper sh;
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+            auto qs = iq2[ibl].qs;
+            auto qh = iq2[ibl].qh;
+            for (int is = 0; is < 2; ++is) {
+                auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
+                auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
+                auto scales2 = vshrq_n_u8(scale_bits, 4);
+                scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
+                scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
+                auto s1 = vzip1q_u8(scales1, scales2);
+                auto s2 = vzip2q_u8(scales1, scales2);
+                scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
+                scales16.val[1] = vmovl_u8(vget_high_u8(s1));
+                scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
+                scales16.val[3] = vmovl_u8(vget_high_u8(s2));
+                for (int ib = 0; ib < QK_K/64; ++ib) {
+                    auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib);
+                    sh.init();
+                    for (int i = 0; i < 4; ++i) {
+                        qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]});
+                        sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
+                        qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]});
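+                        // sh.init() above reset the helper; each apply_signs_1 call consumes the
+                        // next 16 sign bits of signs128, here flipping the second vector of the pair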
+                        sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
+                    }
+                    qs += 16; qh += 4;
+                    auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
+                    auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
+                        auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
+                        auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
+                        auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
+                        auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
+                        auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
+                        auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
+                        isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
+                        isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
+                    }
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
+                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
+                isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
+            acc[iy] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    float32x4_t acc[nrc_y] = {};
+    int32x4_t isum[nrc_y] = {};
+    int8x16_t qx[8];
+    SignHelper sh;
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)));
+            auto qs = iq3[ibl].qs;
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto sas = vld1q_u8(iq3[ibl].sas + 16*ib);
+                auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
+                auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
+                auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
+                signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
+                sh.init();
+                for (int i = 0; i < 8; ++i) {
+                    qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]});
+                    sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
+                    auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
+                    auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
+                    auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
+                    auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
+                    auto sumi12 = vpaddq_s32(sumi1, sumi2);
+                    auto sumi34 = vpaddq_s32(sumi3, sumi4);
+                    auto sumi = vpaddq_s32(sumi12, sumi34);
+                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+                }
+                qs += 32;
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+                isum[iy] = vdupq_n_s32(0);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, acc[iy]);
+            acc[iy] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    float32x4_t acc[nrc_y] = {};
+    int32x4_t isum[nrc_y] = {};
+    int8x16_t qx[8];
+    auto m1 = vdupq_n_u8(1);
+    auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03});
+    uint32_t stored_scales[8];
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
+            auto qs = iq3[ibl].qs;
+            auto qh = iq3[ibl].qh;
+            auto scale_bits = vld1q_u8(iq3[ibl].scales);
+            uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
+            scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1);
+            scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1);
+            vst1q_u8_x2((uint8_t *)stored_scales, scales8);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
+                if constexpr (nrc_y == 1) {
+                    auto qh32 = (const uint32_t *)qh;
+                    auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4}));
+                    union { uint16x8_t vec; uint16_t val[8]; } hidx;
+                    for (int i = 0; i < 4; ++i) {
+                        auto idx_l = vmovl_u8(vld1_u8(qs));
+                        hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1);
+                        qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]});
+                        auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
+                        qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
+                        qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]});
+                        signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
+                        qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
+                        signs128 = vshrq_n_u8(signs128, 1);
+                        qs += 8;
+                    }
+                } else {
+                    for (int i = 0; i < 4; ++i) {
+                        qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)],
+                                                                    iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]});
+                        auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
+                        qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
+
+                        qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)],
+                                                                    iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]});
+                        signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
+                        qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
+
+                        qs += 8;
+                        signs128 = vshrq_n_u8(signs128, 1);
+                    }
+                }
+                auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff));
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
+                    auto sumi = interleaved_dotq(qx, y);
+                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+                }
+                qh += 4;
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+                isum[iy] = vdupq_n_s32(0);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
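+            // write out the four accumulated rows; iq3_s needs no extra rescale at store time
+            // (contrast with the 0.125f factor in the iq2 kernels and the 0.25f folded into d4 for iq3_xxs)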
+            info.store(ix, iy, acc[iy]);
+            acc[iy] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
 }
 
 bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1910,30 +2219,26 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
-//        case GGML_TYPE_IQ2_XXS_R4:
-//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_r4_q8_k, kernels);
-//            func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
-//            break;
-//        case GGML_TYPE_IQ2_XS_R4:
-//            assert (ne00 % QK_K == 0);
-//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_r4_q8_k, kernels);
-//#ifndef HAVE_FANCY_SIMD
-//            // For some reason Zen4 does not like this particular function
-//            func16 = mul_mat_iq2_xs_r4_q8_k_16;
-//#endif
-//            break;
-//        case GGML_TYPE_IQ2_S_R4:
-//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_r4_q8_k, kernels);
-//            func16 = mul_mat_iq2_s_r4_q8_k_16;
-//            break;
-//        case GGML_TYPE_IQ3_XXS_R4:
-//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_xxs_r4_q8_k, kernels);
-//            func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
-//            break;
-//        case GGML_TYPE_IQ3_S_R4:
-//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_s_r4_q8_k, kernels);
-//            func16 = mul_mat_iq3_s_r4_q8_k<16>;
-//            break;
+        case GGML_TYPE_IQ2_XXS_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_r4_q8_k, kernels);
+            func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
+            break;
+        case GGML_TYPE_IQ2_XS_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_r4_q8_k, kernels);
+            func16 = mul_mat_iq2_xs_r4_q8_k<16>;
+            break;
+        case GGML_TYPE_IQ2_S_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_r4_q8_k, kernels);
+            func16 = mul_mat_iq2_s_r4_q8_k<16>;
+            break;
+        case GGML_TYPE_IQ3_XXS_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_xxs_r4_q8_k, kernels);
+            func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
+            break;
+        case GGML_TYPE_IQ3_S_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_s_r4_q8_k, kernels);
+            func16 = mul_mat_iq3_s_r4_q8_k<16>;
+            break;
         default:
             return false;
     }
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index ac890e11..2a3c850c 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -300,8 +300,6 @@ struct MulMat {
         }
 #endif
     }
-private:
-    template <typename Dequantizer> static void set_functions(MulMat& m);
 };
 }
@@ -674,9 +672,6 @@ static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 }
 #endif
 
-template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
-}
-
 bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
     (void)Ny;
@@ -765,478 +760,8 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
 
 namespace {
 
-template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
-
-    constexpr static int nrc_y = nrc;
-
-    Q8(const DataInfo& info) {
-        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
-    }
-
-    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
-    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
-    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
-    inline int16x8_t load_bsums8(int iy, int i) const {
-        auto q8s = vld1q_s16_x2(y[iy][i].bsums);
-        return vpaddq_s16(q8s.val[0], q8s.val[1]);
-    }
-    inline float scale(int iy, int i) const { return y[iy][i].d; }
-
-    const block_q8 * y[nrc_y];
-};
-
-template <typename Q8>
-inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
-    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-        auto q8s = q8.load_bsums(iy, i);
-        int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
-        int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
-        int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
-        int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
-        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
-        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
-    }
-}
-
-inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
-    int32x4x4_t scales = {
-        vmovl_s16(vget_low_s16 (scales16.val[0])),
-        vmovl_s16(vget_high_s16(scales16.val[0])),
-        vmovl_s16(vget_low_s16 (scales16.val[1])),
-        vmovl_s16(vget_high_s16(scales16.val[1])),
-    };
-    return scales;
-}
-
-// ============================= i-quants
-
-inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
-    int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))};
-    return make_wider(scales16);
-}
-
-struct Scale16Extra {
-    template <typename Q8>
-    static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val,
-            const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) {
-        uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra));
-        e8 = vceqq_u8(vandq_u8(e8, emask), emask);
-        e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff);
-        int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)),
-                               vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))};
-        accum_mins_16(extra16, q8, acc, i, d);
-        return make_wider_8(scales8);
-    }
-
-    constexpr static uint32x4_t emask  = {0x02020101, 0x08080404, 0x20201010, 0x80804040};
-    constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09};
-};
-
-// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because
-//       the ARM_NEON only has vdotq_s32 or vdotq_u32, where both operands need to
-//       be signed or unsigned. As the Q8_K quants are signed, we need to have the
-//       iq4_s quants also signed. We can only use unsigned values in k-quants
-//       because they are all within the valid int8_t range.
-
-struct SimpleBits {
-    uint8x16x4_t b1;
-    uint8x16x4_t b2;
-};
-
-inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {
-    int32x4x2_t scales;
-    scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1)));
-    scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1)));
-    return scales;
-}
-
-inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {
-    auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
-    auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));
-    b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
-    b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
-}
-
-template <int nrc> struct Q8_K64 {
-
-    constexpr static int nrc_y = nrc;
-
-    Q8_K64(const DataInfo& info) {
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto dptr = (const float *)info.src1_row(iy);
-            std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
-            y[iy] = (const int8_t *)(dptr + 8);
-        }
-    }
-
-    inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
-    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
-    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
-    inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
-
-    float d[8*nrc_y];
-    const int8_t * y[nrc_y];
-};
-
-template <int nrc> struct Q8_16 {
-
-    constexpr static int nrc_y = nrc;
-
-    Q8_16(const DataInfo& info) {
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto ptr = (const float *)info.src1_row(iy);
-            std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
-            y[iy] = (const int8_t *)(ptr + 5);
-        }
-    }
-
-    inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
-    inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
-    inline float scale(int iy, int k) const { return d[5*iy+k]; }
-    inline float sum_row(int iy) const { return d[5*iy + 4]; }
-    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
-
-    float d[5*nrc_y];
-    const int8_t * y[nrc_y];
-};
-
-template <int nrc_y>
-static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_K> q8(info);
-    int nbl = n / QK_K;
-    float32x4_t acc[nrc_y] = {};
-    int32x4_t isum[nrc_y] = {};
-    int8x16_t qx[8];
-    SignHelper sh;
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
-        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
-            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
-            auto qs = iq2[ibl].qs;
-            for (int ib = 0; ib < QK_K/32; ++ib) {
-                auto sas = vld1q_u8(iq2[ibl].sas + 16*ib);
-                auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
-                auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
-                auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
-                signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
-                sh.init();
-                for (int i = 0; i < 8; ++i) {
-                    qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[qs[2*i+0]], iq2xxs_grid[qs[2*i+1]]});
-                    sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
-                    auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
-                    auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
-                    auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
-                    auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
-                    auto sumi12 = vpaddq_s32(sumi1, sumi2);
-                    auto sumi34 = vpaddq_s32(sumi3, sumi4);
-                    auto sumi = vpaddq_s32(sumi12, sumi34);
-                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
-                }
-                qs += 16;
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
-                isum[iy] = vdupq_n_s32(0);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
-            acc[iy] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
-template <int nrc_y>
-static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_K> q8(info);
-    int nbl = n / QK_K;
-    static const uint8_t k_shuff[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
-    auto shuff = vld1q_u8(k_shuff);
-    float32x4_t acc[nrc_y] = {};
-    int32x4_t isum[2*nrc_y] = {};
-    int8x16_t qx[8];
-    uint16x8x4_t scales16;
-    SignHelper sh;
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
-        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
-            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
-            auto qs = iq2[ibl].qs;
-            for (int is = 0; is < 2; ++is) {
-                auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
-                auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
-                auto scales2 = vshrq_n_u8(scale_bits, 4);
-                scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
-                scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
-                auto s1 = vzip1q_u8(scales1, scales2);
-                auto s2 = vzip2q_u8(scales1, scales2);
-                scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
-                scales16.val[1] = vmovl_u8(vget_high_u8(s1));
-                scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
-                scales16.val[3] = vmovl_u8(vget_high_u8(s2));
-                for (int ib = 0; ib < QK_K/64; ++ib) {
-                    auto v = vld1q_u8_x2((const uint8_t *)qs);
-                    auto signs128 = vandq_u8(vqtbl2q_u8(v, shuff), vdupq_n_u8(254));
-                    signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
-                    sh.init();
-                    for (int i = 0; i < 8; ++i) {
-                        qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[qs[2*i+0] & 511], iq2xs_grid[qs[2*i+1] & 511]});
-                        sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
-                    }
-                    auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
-                    auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
-                        auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
-                        auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
-                        auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
-                        auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
-                        auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
-                        auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
-                        isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
-                        isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
-                    }
-                    qs += 16;
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
-                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
-                isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
-            acc[iy] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
-template <int nrc_y>
-static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_K> q8(info);
-    int nbl = n / QK_K;
-    float32x4_t acc[nrc_y] = {};
-    int32x4_t isum[2*nrc_y] = {};
-    int8x16_t qx[8];
-    uint16x8x4_t scales16;
-    SignHelper sh;
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
-        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
-            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
-            auto qs = iq2[ibl].qs;
-            auto qh = iq2[ibl].qh;
-            for (int is = 0; is < 2; ++is) {
-                auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
-                auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
-                auto scales2 = vshrq_n_u8(scale_bits, 4);
-                scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
-                scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
-                auto s1 = vzip1q_u8(scales1, scales2);
-                auto s2 = vzip2q_u8(scales1, scales2);
-                scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
-                scales16.val[1] = vmovl_u8(vget_high_u8(s1));
-                scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
-                scales16.val[3] = vmovl_u8(vget_high_u8(s2));
-                for (int ib = 0; ib < QK_K/64; ++ib) {
-                    auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib);
-                    sh.init();
-                    for (int i = 0; i < 4; ++i) {
-                        qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]});
-                        sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
-                        qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]});
-                        sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
-                    }
-                    qs += 16; qh += 4;
-                    auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
-                    auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
-                        auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
-                        auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
-                        auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
-                        auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
-                        auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
-                        auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
-                        isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
-                        isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
-                    }
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
-                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
-                isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
-            acc[iy] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
-template <int nrc_y>
-static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_K> q8(info);
-    int nbl = n / QK_K;
-    float32x4_t acc[nrc_y] = {};
-    int32x4_t isum[nrc_y] = {};
-    int8x16_t qx[8];
-    SignHelper sh;
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
-        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
-            auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)));
-            auto qs = iq3[ibl].qs;
-            for (int ib = 0; ib < QK_K/32; ++ib) {
-                auto sas = vld1q_u8(iq3[ibl].sas + 16*ib);
-                auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
-                auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
-                auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
-                signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
-                sh.init();
-                for (int i = 0; i < 8; ++i) {
-                    qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]});
-                    sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
-                    auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
-                    auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
-                    auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
-                    auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
-                    auto sumi12 = vpaddq_s32(sumi1, sumi2);
-                    auto sumi34 = vpaddq_s32(sumi3, sumi4);
-                    auto sumi = vpaddq_s32(sumi12, sumi34);
-                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
-                }
-                qs += 32;
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
-                isum[iy] = vdupq_n_s32(0);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, acc[iy]);
-            acc[iy] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
-template <int nrc_y>
-static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_K> q8(info);
-    int nbl = n / QK_K;
-    float32x4_t acc[nrc_y] = {};
-    int32x4_t isum[nrc_y] = {};
-    int8x16_t qx[8];
-    auto m1 = vdupq_n_u8(1);
-    auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03});
-    uint32_t stored_scales[8];
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
-        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
-            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
-            auto qs = iq3[ibl].qs;
-            auto qh = iq3[ibl].qh;
-            auto scale_bits = vld1q_u8(iq3[ibl].scales);
-            uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
-            scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1);
-            scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1);
-            vst1q_u8_x2((uint8_t *)stored_scales, scales8);
-            for (int ib = 0; ib < QK_K/32; ++ib) {
-                auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
-                if constexpr (nrc_y == 1) {
-                    auto qh32 = (const uint32_t *)qh;
-                    auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4}));
-                    union { uint16x8_t vec; uint16_t val[8]; } hidx;
-                    for (int i = 0; i < 4; ++i) {
-                        auto idx_l = vmovl_u8(vld1_u8(qs));
-                        hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1);
-                        qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]});
-                        auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
-                        qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
-                        qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]});
-                        signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
-                        qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
-                        signs128 = vshrq_n_u8(signs128, 1);
-                        qs += 8;
-                    }
-                } else {
-                    for (int i = 0; i < 4; ++i) {
-                        qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)],
-                                                                    iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]});
-                        auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
-                        qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
-
-                        qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)],
-                                                                    iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]});
-                        signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
-                        qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
-
-                        qs += 8;
-                        signs128 = vshrq_n_u8(signs128, 1);
-                    }
-                }
-                auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff));
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
-                    auto sumi = interleaved_dotq(qx, y);
-                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
-                }
-                qh += 4;
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
-                isum[iy] = vdupq_n_s32(0);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, acc[iy]);
-            acc[iy] = vdupq_n_f32(0.f);
-        }
-    }
-}
-
-#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
-            m.funcs[0] = func<Dequantizer, 1>;\
-            m.funcs[1] = func<Dequantizer, 2>;\
-            m.funcs[2] = func<Dequantizer, 3>;\
-            m.funcs[3] = func<Dequantizer, 4>;\
-            m.funcs[4] = func<Dequantizer, 5>;\
-            m.funcs[5] = func<Dequantizer, 6>;\
-            m.funcs[6] = func<Dequantizer, 7>;\
-            m.funcs[7] = func<Dequantizer, 8>;\
-
-#define SET_MUL_MAT_FUNCTIONS(m, func) \
-            m.funcs[0] = func<1>;\
-            m.funcs[1] = func<2>;\
-            m.funcs[2] = func<3>;\
-            m.funcs[3] = func<4>;\
-            m.funcs[4] = func<5>;\
-            m.funcs[5] = func<6>;\
-            m.funcs[6] = func<7>;\
-            m.funcs[7] = func<8>;\
-
-template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
-    SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_K_q8_K_T, Dequantizer);
-}
-
 bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
-    auto expected_Btype = GGML_TYPE_Q8_K;
-
     switch (typeA) {
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
@@ -1279,6 +804,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_XXS_R4:
+        case GGML_TYPE_IQ2_XS_R4:
+        case GGML_TYPE_IQ2_S_R4:
+        case GGML_TYPE_IQ3_XXS_R4:
+        case GGML_TYPE_IQ3_S_R4:
             return iqk_set_kernels_iquants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -1293,21 +823,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_Q8_0_R8:
        case GGML_TYPE_IQ4_NL_R4:
            return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
-        case GGML_TYPE_IQ2_XXS_R4:
-            SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xxs_r4_q8_k);
-            m.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
-            expected_Btype = GGML_TYPE_Q8_K;
-            break;
-        case GGML_TYPE_IQ2_XS_R4:
-            SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xs_r4_q8_k);
-            m.func16 = mul_mat_iq2_xs_r4_q8_k<16>;
-            expected_Btype = GGML_TYPE_Q8_K;
-            break;
-        case GGML_TYPE_IQ2_S_R4:
-            SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_s_r4_q8_k);
-            m.func16 = mul_mat_iq2_s_r4_q8_k<16>;
-            expected_Btype = GGML_TYPE_Q8_K;
-            break;
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_BN_R4:
@@ -1315,21 +830,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ1_M_R4:
             return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
-        case GGML_TYPE_IQ3_XXS_R4:
-            SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
-            m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
-            expected_Btype = GGML_TYPE_Q8_K;
-            break;
-        case GGML_TYPE_IQ3_S_R4:
-            SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_s_r4_q8_k);
-            m.func16 = mul_mat_iq3_s_r4_q8_k<16>;
-            expected_Btype = GGML_TYPE_Q8_K;
-            break;
         default:
             return false;
     }
-    return typeB == expected_Btype;
 }
 
 }