iqk_mul_mat: AVX2 implementation for iq2_s

We get 2.04X for PP-512 (107 t/s). TG again suffers
a small loss in performance (19.9 t/s vs 21.4 t/s @ 16 threads)
This commit is contained in:
Kawrakow
2024-05-29 17:27:36 +03:00
parent f31200bde1
commit 3c448906bf

View File

@@ -365,14 +365,22 @@ inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3)); scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
} }
//#if defined(__AVX512VNNI__) && defined(__AVX512VL__) inline __m256i get_scale_shuffle_16(int i) {
// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scale_1, dot1); static const uint8_t k_shuffle[128] = {
// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scale_2, dot2); 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
//#else 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
// const __m256i p1 = _mm256_madd_epi16(scale_1, dot1); 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
// const __m256i p2 = _mm256_madd_epi16(scale_2, dot2); 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p2)); };
//#endif return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
// Fills scales[0..3]: entry k is all_scales byte-shuffled with the control
// vector returned by get_scale_shuffle_16(k).
inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
    for (int k = 0; k < 4; ++k) {
        scales[k] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(k));
    }
}
template <typename Q8, typename Bits> template <typename Q8, typename Bits>
inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
@@ -903,23 +911,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
const __m256i mh = _mm256_set1_epi8(0x30); const __m256i mh = _mm256_set1_epi8(0x30);
}; };
// Returns the i-th 32-byte row of a fixed byte-shuffle table (the table has
// four rows, so i is expected to be 0..3; no range check is performed).
// Row i holds the byte pair (4i, 4i+1) repeated eight times, followed by the
// pair (4i+2, 4i+3) repeated eight times.
inline __m256i get_scale_shuffle_16(int i) {
    static const uint8_t k_shuffle[128] = {
         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,  2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,  6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
    };
    // One row is 32 bytes, i.e. exactly one __m256i.
    const uint8_t * row = k_shuffle + 32 * i;
    return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row));
}
// Fills scales[0..3]: entry k is all_scales byte-shuffled with the control
// vector returned by get_scale_shuffle_16(k).
inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
    for (int k = 0; k < 4; ++k) {
        scales[k] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(k));
    }
}
template <typename Dequantizer, int nrc_y> template <typename Dequantizer, int nrc_y>
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0); assert(n%QK_K == 0);
@@ -1061,12 +1052,16 @@ static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const Data
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
auto all_scales = deq.new_block(i); __m256i sumi[2], all_scales[Dequantizer::num_blocks/8];
__m256i sumi[2]; deq.new_block(i, all_scales);
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j, q8, q8_quants); deq.prepare(i, j, q8, q8_quants);
set_scales_8(all_scales, j, scales); if constexpr (Dequantizer::num_blocks == 8) {
set_scales_8(all_scales[0], j, scales);
} else {
set_scales_16(all_scales[j], scales);
}
multiply_add_1(j, deq.bits, scales, q8_quants, sumi); multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
} }
accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd); accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);
@@ -1092,13 +1087,16 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
auto all_scales = deq.new_block(i, q8, accd); __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];
deq.new_block(i, q8, accd, all_scales);
__m256i sumi[nrc_y];
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j); deq.prepare(i, j);
set_scales_8(all_scales, j, scales); if constexpr (Dequantizer::num_blocks == 8) {
set_scales_8(all_scales[0], j, scales);
} else {
set_scales_16(all_scales[j], scales);
}
multiply_add(deq.bits, scales, j, i, q8, sumi); multiply_add(deq.bits, scales, j, i, q8, sumi);
} }
for (int iy = 0; iy < nrc_y; ++iy) { for (int iy = 0; iy < nrc_y; ++iy) {
@@ -1121,67 +1119,6 @@ static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataIn
} else { } else {
mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x); mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);
} }
//const int nb = n / QK_K;
//Q8<nrc_y> q8(info);
//Dequantizer deq(vx, bx);
//__m256i scales[4];
//if constexpr (nrc_y == 1) {
// __m256i q8_quants[4];
// for (int ix = 0; ix < nrc_x; ++ix) {
// __m256 accd = _mm256_setzero_ps();
// deq.new_row(ix);
// for (int i = 0; i < nb; ++i) {
// auto all_scales = deq.new_block(i);
// __m256i sumi[2];
// for (int j = 0; j < QK_K/128; ++j) {
// deq.prepare(i, j, q8, q8_quants);
// set_scales_8(all_scales, j, scales);
// multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
// }
// accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);
// }
// info.store(ix, 0, hsum_float_8(accd));
// }
//} else {
// __m256 accd[nrc_y];
// for (int ix = 0; ix < nrc_x; ++ix) {
// for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
// deq.new_row(ix);
// for (int i = 0; i < nb; ++i) {
// auto all_scales = deq.new_block(i, q8, accd);
// __m256i sumi[nrc_y];
// for (int j = 0; j < QK_K/128; ++j) {
// deq.prepare(i, j);
// set_scales_8(all_scales, j, scales);
// multiply_add(deq.bits, scales, j, i, q8, sumi);
// }
// for (int iy = 0; iy < nrc_y; ++iy) {
// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
// }
// }
// for (int iy = 0; iy < nrc_y; ++iy) {
// info.store(ix, iy, hsum_float_8(accd[iy]));
// }
// }
//}
} }
struct SimpleBits { struct SimpleBits {
@@ -1208,6 +1145,8 @@ struct SignHelper {
struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
constexpr static int num_blocks = 8;
inline __m128i make_scales(int i, float& dd) const { inline __m128i make_scales(int i, float& dd) const {
dd = GGML_FP16_TO_FP32(x[i].d); dd = GGML_FP16_TO_FP32(x[i].d);
uint32_t aux32[2]; uint32_t aux32[2];
@@ -1218,15 +1157,15 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8)); auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8));
return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1)); return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1));
} }
inline __m256i new_block(int i) { inline void new_block(int i, __m256i * scales) {
auto scales16 = make_scales(i, d); auto scales16 = make_scales(i, d);
return MM256_SET_M128I(scales16, scales16); scales[0] = MM256_SET_M128I(scales16, scales16);
} }
template <typename Q8> template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
auto scales16 = make_scales(i, d); auto scales16 = make_scales(i, d);
scb.accum_mins(scales16, q8, i, -minv*d, accd); scb.accum_mins(scales16, q8, i, -minv*d, accd);
return MM256_SET_M128I(scales16, scales16); scales[0] = MM256_SET_M128I(scales16, scales16);
} }
union index_t { union index_t {
@@ -1309,6 +1248,8 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> { struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
constexpr static int num_blocks = 8;
inline __m128i prepare_scales(int i) { inline __m128i prepare_scales(int i) {
d = 0.25f * GGML_FP16_TO_FP32(x[i].d); d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
auto tmp = _mm256_loadu_si256((const __m256i *)(x[i].qs + QK_K/4)); auto tmp = _mm256_loadu_si256((const __m256i *)(x[i].qs + QK_K/4));
@@ -1317,15 +1258,15 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
return _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1)); return _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1));
} }
inline __m256i new_block(int i) { inline void new_block(int i, __m256i * scales) {
auto scales16 = prepare_scales(i); auto scales16 = prepare_scales(i);
return MM256_SET_M128I(scales16, scales16); scales[0] = MM256_SET_M128I(scales16, scales16);
} }
template <typename Q8> template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
auto scales16 = prepare_scales(i); auto scales16 = prepare_scales(i);
scb.accum_mins(scales16, q8, i, -minv*d, accd); scb.accum_mins(scales16, q8, i, -minv*d, accd);
return MM256_SET_M128I(scales16, scales16); scales[0] = MM256_SET_M128I(scales16, scales16);
} }
inline static __m256i make_quants(const uint8_t * qs) { inline static __m256i make_quants(const uint8_t * qs) {
@@ -1371,43 +1312,97 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
Scales8KBase scb; Scales8KBase scb;
const __m256i min_value = _mm256_set1_epi8(minv); const __m256i min_value = _mm256_set1_epi8(minv);
};
//inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
// const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
// const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
// scales[0] = MM256_SET_M128I(l_scales, l_scales);
// scales[1] = MM256_SET_M128I(h_scales, h_scales);
//}
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
// Forwards the data pointer vx and the size bx unchanged to BaseDequantizer.
DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
// Read by the mul_mat_qX_K_q8_K_IQ_* kernels: they allocate num_blocks/8 scale
// vectors per block and call set_scales_16 instead of set_scales_8 when this is not 8.
constexpr static int num_blocks = 16;
// Unpacks the 4-bit scales stored in x[i].scales and returns them as 16 16-bit lanes.
// Side effect: sets the member d to 0.125f * GGML_FP16_TO_FP32(x[i].d).
inline __m256i load_scales(int i) {
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
// Loads 8 bytes into the low half of the register; the upper half is zero.
auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
// Low 8 bytes keep the low nibble of each input byte; the high 8 bytes receive the
// high nibbles (shifted right by 4, then moved up by 8 bytes). The 0xf mask drops
// the bits that the 16-bit shift pulled in from the neighbouring byte.
auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf));
// Map every nibble s to 2*s + 1.
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
// Interleave: output byte 2k comes from byte k (low nibble of input byte k),
// output byte 2k+1 from byte k+8 (high nibble of input byte k).
auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800);
// Widen the 16 bytes to 16-bit lanes.
return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle));
}
// Splits the 256-bit vector `all` into its two 128-bit halves and writes each
// half, duplicated into both lanes, to scales[0] (low half) and scales[1] (high half).
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
    // Read both halves before storing anything.
    const __m128i lower = _mm256_castsi256_si128(all);
    const __m128i upper = _mm256_extractf128_si256(all, 1);
    scales[0] = MM256_SET_M128I(lower, lower);
    scales[1] = MM256_SET_M128I(upper, upper);
}
// Starts block i: load_scales(i) unpacks the block scales (and updates d),
// prepare_scales then writes them to scales[0] and scales[1].
inline void new_block(int i, __m256i * scales) {
    const __m256i all_scales = load_scales(i);
    prepare_scales(all_scales, scales);
}
// Variant of new_block that also updates the float accumulators.
// For every iy in [0, Q8::nrc_y):
//   accd[iy] += (-d * q8.scale(iy, i) * minv) * float(madd(all_scales, q8.load_bsums(iy, i)))
// where madd multiplies matching 16-bit lanes and adds adjacent pairs into 32-bit lanes.
// Afterwards scales[0] and scales[1] are filled via prepare_scales.
// NOTE(review): presumably this term cancels the +minv offset that make2_signed adds
// to the quant values - confirm against the q8 bsums definition.
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
auto all_scales = load_scales(i);
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, i);
auto prod = _mm256_madd_epi16(all_scales, bsums);
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(-d*q8.scale(iy, i)*minv), _mm256_cvtepi32_ps(prod), accd[iy]);
}
prepare_scales(all_scales, scales);
}
// Overlays a 256-bit vector with eight 32-bit scalars so that make2 can read
// the individual lanes of a computed index vector.
union index_t {
__m256i vec;
uint32_t val[8];
};
// Builds two vectors of grid values from 8 bytes at qs and 2 bytes at qh.
// Lane k (0..7) of the index vector is
//   qs[k] | ((qh[k/4] << idx_shift[k]) & idx_mask)
// and each index selects one 64-bit entry of iq2s_grid.
// values[0] holds the entries for indices 0..3, values[1] those for indices 4..7.
inline static void make2(const uint8_t * qs, const uint8_t * qh, const __m256i& idx_shift, const __m256i& idx_mask, __m256i * values) {
// Zero-extend the 8 qs bytes to 32-bit lanes.
auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
// qh[0] is broadcast to lanes 0..3, qh[1] to lanes 4..7.
auto idx_h = MM256_SET_M128I(_mm_set1_epi32(qh[1]), _mm_set1_epi32(qh[0]));
index_t idx;
idx.vec = _mm256_or_si256(idx_l, _mm256_and_si256(_mm256_sllv_epi32(idx_h, idx_shift), idx_mask));
// _mm256_set_epi64x takes its arguments from the highest lane down to the lowest.
values[0] = _mm256_set_epi64x(iq2s_grid[idx.val[3]], iq2s_grid[idx.val[2]], iq2s_grid[idx.val[1]], iq2s_grid[idx.val[0]]);
values[1] = _mm256_set_epi64x(iq2s_grid[idx.val[7]], iq2s_grid[idx.val[6]], iq2s_grid[idx.val[5]], iq2s_grid[idx.val[4]]);
}
// Calls make2, then for each of the two result vectors applies _mm256_sign_epi8 with the
// vector returned by sh.make_signs and adds min_value to every byte.
// The argument of make_signs is a 32-bit word assembled from two consecutive uint16_t
// entries of sidx: entries 0,1 for values[0] and entries 2,3 for values[1].
// NOTE(review): sidx[k] << 16 is evaluated after promotion to int; consider casting to
// uint32_t before shifting so the high bit never lands in the sign bit of an int.
inline static void make2_signed(const SignHelper& sh, const uint8_t * qs, const uint8_t * qh, const uint16_t * sidx,
const __m256i& idx_shift, const __m256i& idx_mask, const __m256i& min_value, __m256i * values) {
make2(qs, qh, idx_shift, idx_mask, values);
values[0] = _mm256_add_epi8(_mm256_sign_epi8(values[0], sh.make_signs(sidx[0] | (sidx[1] << 16))), min_value);
values[1] = _mm256_add_epi8(_mm256_sign_epi8(values[1], sh.make_signs(sidx[2] | (sidx[3] << 16))), min_value);
}
// Fills bits.values[0..3] with signed, offset quant values for group j of block i.
// Group j reads 16 index bytes from x[i].qs, 4 bytes from x[i].qh and 8 uint16_t sign
// words; the sign words start QK_K/8 bytes into x[i].qs.
inline void prepare(int i, int j) {
auto qs = x[i].qs + 16*j;
auto qh = x[i].qh + 4*j;
const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
// Each call consumes 8 qs bytes, 2 qh bytes and 4 sign words and writes two vectors.
make2_signed(sh, qs+0, qh+0, signs+0, idx_shift, idx_mask, min_value, bits.values+0);
make2_signed(sh, qs+8, qh+2, signs+4, idx_shift, idx_mask, min_value, bits.values+2);
}
// Variant of prepare that leaves the grid values unsigned and without the min_value
// offset in bits.values[0..3]. The signs are applied to the q8 side instead:
// q8_quants[k] = sign_epi8(q8.load_quants(0, i, 4*j+k), signs built from sign words 2k and 2k+1).
template <typename Q8>
inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
auto qs = x[i].qs + 16*j;
auto qh = x[i].qh + 4*j;
// Sign words start QK_K/8 bytes into x[i].qs; group j uses 8 of them.
const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
make2(qs+0, qh+0, idx_shift, idx_mask, bits.values+0);
make2(qs+8, qh+2, idx_shift, idx_mask, bits.values+2);
q8_quants[0] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+0), sh.make_signs(signs[0] | (signs[1] << 16)));
q8_quants[1] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+1), sh.make_signs(signs[2] | (signs[3] << 16)));
q8_quants[2] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+2), sh.make_signs(signs[4] | (signs[5] << 16)));
q8_quants[3] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+3), sh.make_signs(signs[6] | (signs[7] << 16)));
}
// Offset added to every byte by make2_signed (through min_value) and used as a factor
// in the accumulator correction of the templated new_block.
constexpr static int minv = 43;
// Destination of prepare(): bits.values[0..3].
SimpleBits bits;
// Supplies make_signs(), which turns a 32-bit sign word into a sign vector.
SignHelper sh;
// Per-lane left shifts for the qh byte. _mm256_set_epi32 lists lanes from high to low,
// so lanes 0..3 (and 4..7) shift by 8, 6, 4, 2.
const __m256i idx_shift = _mm256_set_epi32(2, 4, 6, 8, 2, 4, 6, 8);
// Keeps only bits 8 and 9 of the shifted qh value.
const __m256i idx_mask = _mm256_set1_epi32(0x300);
// minv broadcast to all 32 bytes.
const __m256i min_value = _mm256_set1_epi8(minv);
}; };
//struct DequantizerIQ3XXS_1 final : public BaseDequantizer<block_iq3_xxs> {
// DequantizerIQ3XXS_1(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
//
// inline __m256i new_block(int i) {
// d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
// auto tmp = _mm256_loadu_si256((const __m256i *)(x[i].qs + QK_K/4));
// auto scales32 = _mm256_srli_epi32(tmp, 28);
// scales32 = _mm256_or_si256(_mm256_slli_epi32(scales32, 1), _mm256_set1_epi32(1));
// auto scales16 = _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1));
// return MM256_SET_M128I(scales16, scales16);
// }
//
// inline static __m256i make1(const uint8_t * qs, const uint16_t * sidx, __m256i& q8_quants) {
// auto val = _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
// iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
// uint32_t aux32 = sidx[0] | (sidx[1] << 16);
// auto s = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
// keven_signs[(aux32 >> 7) & 127], keven_signs[aux32 & 127]);
// q8_quants = _mm256_sign_epi8(q8_quants, s);
// return val;
// }
//
// template <typename Q8>
// inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
// auto qs = x[i].qs + 32*j;
// const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
// q8_quants[0] = q8.load_quants(0, i, 4*j+0); bits.values[0] = make1(qs+ 0, signs+0, q8_quants[0]);
// q8_quants[1] = q8.load_quants(0, i, 4*j+1); bits.values[1] = make1(qs+ 8, signs+2, q8_quants[1]);
// q8_quants[2] = q8.load_quants(0, i, 4*j+2); bits.values[2] = make1(qs+16, signs+4, q8_quants[2]);
// q8_quants[3] = q8.load_quants(0, i, 4*j+3); bits.values[3] = make1(qs+24, signs+6, q8_quants[3]);
// }
//
// SimpleBits bits;
//
//};
// //
// ============================== Legacy quants // ============================== Legacy quants
// //
@@ -1782,7 +1777,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>; m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>; m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
} }
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS>) { else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
std::is_same_v<Dequantizer, DequantizerIQ2S>) {
m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>; m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>; m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>; m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
@@ -1870,6 +1866,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
assert (ne00 % QK_K == 0); assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ3XXS>(mm); MulMat::set_functions<DequantizerIQ3XXS>(mm);
break; break;
case GGML_TYPE_IQ2_S:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2S>(mm);
break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0); assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_Unpacker>(mm); MulMat::set_functions<Q4_0_Unpacker>(mm);