mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-09 21:40:22 +00:00
iqk_mul_mat: AVX2 implementation for iq2_xs
We get 2.19X for PP-512 (118.9 t/s). TG is mostly OK (slightly better @ 4 threads, slightly worse @ 16 threads).
This commit is contained in:
140
iqk_mul_mat.cpp
140
iqk_mul_mat.cpp
@@ -1313,12 +1313,7 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
|
||||
const __m256i min_value = _mm256_set1_epi8(minv);
|
||||
|
||||
};
|
||||
//inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
|
||||
// const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
||||
// const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
||||
// scales[0] = MM256_SET_M128I(l_scales, l_scales);
|
||||
// scales[1] = MM256_SET_M128I(h_scales, h_scales);
|
||||
//}
|
||||
|
||||
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
|
||||
|
||||
@@ -1327,11 +1322,18 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
inline __m256i load_scales(int i) {
|
||||
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
|
||||
auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
|
||||
auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf));
|
||||
auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
|
||||
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
|
||||
auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800);
|
||||
return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle));
|
||||
return _mm256_cvtepi8_epi16(scales8);
|
||||
}
|
||||
//inline __m256i load_scales(int i) {
|
||||
// d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
|
||||
// auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
|
||||
// auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf));
|
||||
// auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
|
||||
// auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800);
|
||||
// return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle));
|
||||
//}
|
||||
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
|
||||
auto scales_l = _mm256_castsi256_si128(all);
|
||||
auto scales_h = _mm256_extractf128_si256(all, 1);
|
||||
@@ -1403,6 +1405,120 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
|
||||
DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
|
||||
|
||||
constexpr static int num_blocks = 16;
|
||||
|
||||
inline __m256i load_scales(int i) {
|
||||
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
|
||||
auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
|
||||
auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
|
||||
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
|
||||
return _mm256_cvtepi8_epi16(scales8);
|
||||
}
|
||||
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
|
||||
auto scales_l = _mm256_castsi256_si128(all);
|
||||
auto scales_h = _mm256_extractf128_si256(all, 1);
|
||||
scales[0] = MM256_SET_M128I(scales_l, scales_l);
|
||||
scales[1] = MM256_SET_M128I(scales_h, scales_h);
|
||||
}
|
||||
|
||||
inline void new_block(int i, __m256i * scales) {
|
||||
prepare_scales(load_scales(i), scales);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
|
||||
auto all_scales = load_scales(i);
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto bsums = q8.load_bsums(iy, i);
|
||||
auto prod = _mm256_madd_epi16(all_scales, bsums);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(-d*q8.scale(iy, i)*minv), _mm256_cvtepi32_ps(prod), accd[iy]);
|
||||
}
|
||||
prepare_scales(all_scales, scales);
|
||||
}
|
||||
|
||||
struct Helper {
|
||||
const __m256i mone = _mm256_set1_epi8(1);
|
||||
const __m256i mask = _mm256_set1_epi64x(0x8040201008040201);
|
||||
//const __m256i bhelper = _mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080);
|
||||
const __m256i bhelper = load_bhelper();
|
||||
const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
|
||||
const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
|
||||
static __m256i load_bhelper() {
|
||||
static const uint8_t k_bit_helper[32] = {
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
};
|
||||
return _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
||||
}
|
||||
};
|
||||
|
||||
union index_t {
|
||||
__m256i vec;
|
||||
uint16_t val[8];
|
||||
};
|
||||
|
||||
inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) {
|
||||
index_t idx;
|
||||
idx.vec = _mm256_and_si256(data, mask);
|
||||
values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]);
|
||||
values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]);
|
||||
values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]);
|
||||
values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]);
|
||||
}
|
||||
inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask,
|
||||
const __m256i& mone, __m256i& value) {
|
||||
auto signs = _mm256_shuffle_epi8(sign_bits, shuffle);
|
||||
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask);
|
||||
value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone));
|
||||
}
|
||||
inline static void sign_values(const __m256i& data, const Helper& helper, __m256i * values) {
|
||||
auto psb1 = _mm256_srli_epi16(data, 9);
|
||||
auto psb2 = _mm256_srli_epi16(data, 13);
|
||||
auto psbc = _mm256_xor_si256(psb1, psb2);
|
||||
auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc);
|
||||
auto full = _mm256_or_si256(psb1, oddb);
|
||||
auto full_l = _mm256_castsi256_si128(full);
|
||||
auto full_h = _mm256_extractf128_si256(full, 1);
|
||||
auto full_1 = MM256_SET_M128I(full_l, full_l);
|
||||
auto full_2 = MM256_SET_M128I(full_h, full_h);
|
||||
sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]);
|
||||
sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]);
|
||||
sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]);
|
||||
sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]);
|
||||
}
|
||||
inline static void make4_signed(const Helper& helper, const uint16_t * qs, const __m256i& m511,
|
||||
const __m256i& min_value, __m256i * values) {
|
||||
auto q2 = _mm256_loadu_si256((const __m256i *)qs);
|
||||
make4(q2, m511, values);
|
||||
sign_values(q2, helper, values);
|
||||
for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
|
||||
}
|
||||
inline static void make4(const Helper& helper, const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) {
|
||||
auto q2 = _mm256_loadu_si256((const __m256i *)qs);
|
||||
make4(q2, m511, values);
|
||||
sign_values(q2, helper, q8);
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
make4_signed(helper, x[i].qs + 16*j, idx_mask, min_value, bits.values);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
|
||||
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
|
||||
make4(helper, x[i].qs + 16*j, idx_mask, bits.values, q8_quants);
|
||||
}
|
||||
|
||||
constexpr static int minv = 43;
|
||||
|
||||
SimpleBits bits;
|
||||
Helper helper;
|
||||
const __m256i idx_mask = _mm256_set1_epi16(511);
|
||||
const __m256i min_value = _mm256_set1_epi8(minv);
|
||||
|
||||
};
|
||||
|
||||
//
|
||||
// ============================== Legacy quants
|
||||
//
|
||||
@@ -1778,7 +1894,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
|
||||
}
|
||||
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2S>) {
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS>) {
|
||||
m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
|
||||
m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
|
||||
m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
|
||||
@@ -1870,6 +1986,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ2S>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ2XS>(mm);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
assert (ne00 % QK4_0 == 0);
|
||||
MulMat::set_functions<Q4_0_Unpacker>(mm);
|
||||
|
||||
Reference in New Issue
Block a user