iqk_mul_mat: AVX2 implementation for iq2_xs

We get 2.19X for PP-512 (118.9 t/s). TG is mostly OK
(slightly better @ 4 threads, slightly worse @ 16 threads).
This commit is contained in:
Kawrakow
2024-05-29 19:05:20 +03:00
parent 3c448906bf
commit be132341f5

View File

@@ -1313,12 +1313,7 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
const __m256i min_value = _mm256_set1_epi8(minv);
};
//inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
// const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
// const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
// scales[0] = MM256_SET_M128I(l_scales, l_scales);
// scales[1] = MM256_SET_M128I(h_scales, h_scales);
//}
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
@@ -1327,11 +1322,18 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
inline __m256i load_scales(int i) {
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf));
auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800);
return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle));
return _mm256_cvtepi8_epi16(scales8);
}
//inline __m256i load_scales(int i) {
// d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
// auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
// auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf));
// auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
// auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800);
// return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle));
//}
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
auto scales_l = _mm256_castsi256_si128(all);
auto scales_h = _mm256_extractf128_si256(all, 1);
@@ -1403,6 +1405,120 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
};
struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
constexpr static int num_blocks = 16;
inline __m256i load_scales(int i) {
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
return _mm256_cvtepi8_epi16(scales8);
}
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
auto scales_l = _mm256_castsi256_si128(all);
auto scales_h = _mm256_extractf128_si256(all, 1);
scales[0] = MM256_SET_M128I(scales_l, scales_l);
scales[1] = MM256_SET_M128I(scales_h, scales_h);
}
inline void new_block(int i, __m256i * scales) {
prepare_scales(load_scales(i), scales);
}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
auto all_scales = load_scales(i);
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, i);
auto prod = _mm256_madd_epi16(all_scales, bsums);
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(-d*q8.scale(iy, i)*minv), _mm256_cvtepi32_ps(prod), accd[iy]);
}
prepare_scales(all_scales, scales);
}
struct Helper {
const __m256i mone = _mm256_set1_epi8(1);
const __m256i mask = _mm256_set1_epi64x(0x8040201008040201);
//const __m256i bhelper = _mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080);
const __m256i bhelper = load_bhelper();
const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
static __m256i load_bhelper() {
static const uint8_t k_bit_helper[32] = {
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
};
return _mm256_loadu_si256((const __m256i*)k_bit_helper);
}
};
union index_t {
__m256i vec;
uint16_t val[8];
};
inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) {
index_t idx;
idx.vec = _mm256_and_si256(data, mask);
values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]);
values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]);
values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]);
values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]);
}
inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask,
const __m256i& mone, __m256i& value) {
auto signs = _mm256_shuffle_epi8(sign_bits, shuffle);
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask);
value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone));
}
inline static void sign_values(const __m256i& data, const Helper& helper, __m256i * values) {
auto psb1 = _mm256_srli_epi16(data, 9);
auto psb2 = _mm256_srli_epi16(data, 13);
auto psbc = _mm256_xor_si256(psb1, psb2);
auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc);
auto full = _mm256_or_si256(psb1, oddb);
auto full_l = _mm256_castsi256_si128(full);
auto full_h = _mm256_extractf128_si256(full, 1);
auto full_1 = MM256_SET_M128I(full_l, full_l);
auto full_2 = MM256_SET_M128I(full_h, full_h);
sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]);
sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]);
sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]);
sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]);
}
inline static void make4_signed(const Helper& helper, const uint16_t * qs, const __m256i& m511,
const __m256i& min_value, __m256i * values) {
auto q2 = _mm256_loadu_si256((const __m256i *)qs);
make4(q2, m511, values);
sign_values(q2, helper, values);
for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
}
inline static void make4(const Helper& helper, const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) {
auto q2 = _mm256_loadu_si256((const __m256i *)qs);
make4(q2, m511, values);
sign_values(q2, helper, q8);
}
inline void prepare(int i, int j) {
make4_signed(helper, x[i].qs + 16*j, idx_mask, min_value, bits.values);
}
template <typename Q8>
inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
make4(helper, x[i].qs + 16*j, idx_mask, bits.values, q8_quants);
}
constexpr static int minv = 43;
SimpleBits bits;
Helper helper;
const __m256i idx_mask = _mm256_set1_epi16(511);
const __m256i min_value = _mm256_set1_epi8(minv);
};
//
// ============================== Legacy quants
//
@@ -1778,7 +1894,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
}
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
std::is_same_v<Dequantizer, DequantizerIQ2S>) {
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS>) {
m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
@@ -1870,6 +1986,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2S>(mm);
break;
case GGML_TYPE_IQ2_XS:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2XS>(mm);
break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_Unpacker>(mm);