iqk_mul_mat: AVX2 implementation for iq2_xxs

2.09X for PP-512 (104.7 t/s), worse than mainline for TG.
I think it needs more work.
This commit is contained in:
Kawrakow
2024-05-29 19:58:02 +03:00
parent be132341f5
commit 41391ff4b0

View File

@@ -1519,6 +1519,87 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
};
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
constexpr static int num_blocks = 8;
inline __m128i load_scales(int i) {
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
auto scales = _mm_set_epi16(x[i].qs[31] >> 12, x[i].qs[27] >> 12, x[i].qs[23] >> 12, x[i].qs[19] >> 12,
x[i].qs[15] >> 12, x[i].qs[11] >> 12, x[i].qs[ 7] >> 12, x[i].qs[ 3] >> 12);
return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));
//auto tmp1 = _mm256_loadu_si256((const __m256i *)x[i].qs);
//auto tmp2 = _mm256_loadu_si256((const __m256i *)(x[i].qs+16));
//auto idx = _mm256_unpacklo_epi32(tmp1, tmp2);
//auto sas = _mm256_unpackhi_epi32(tmp1, tmp2);
}
inline void new_block(int i, __m256i * scales) {
auto sc16 = load_scales(i);
scales[0] = MM256_SET_M128I(sc16, sc16);
}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
auto sc16 = load_scales(i);
scb.accum_mins(sc16, q8, i, -minv*d, accd);
scales[0] = MM256_SET_M128I(sc16, sc16);
}
inline static void make4(const uint32_t * aux32, __m256i * values) {
const uint8_t * aux8 = (const uint8_t *)aux32;
values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);
values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);
values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);
values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);
}
inline static void sign_value(uint32_t aux32, __m256i& value) {
auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
value = _mm256_sign_epi8(value, signs);
}
inline static void sign_values(const uint32_t * aux32, __m256i * values) {
sign_value(aux32[1], values[0]);
sign_value(aux32[3], values[1]);
sign_value(aux32[5], values[2]);
sign_value(aux32[7], values[3]);
}
union Data {
__m256i vec;
uint32_t val[8];
};
inline static void make4_signed(const uint16_t * qs, const __m256i& min_value, __m256i * values) {
Data data;
data.vec = _mm256_loadu_si256((const __m256i *)qs);
make4(data.val, values);
sign_values(data.val, values);
for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
}
inline static void make4(const uint16_t * qs, __m256i * values, __m256i * q8) {
Data data;
data.vec = _mm256_loadu_si256((const __m256i *)qs);
make4(data.val, values);
sign_values(data.val, q8);
}
inline void prepare(int i, int j) {
make4_signed(x[i].qs + 16*j, min_value, bits.values);
}
template <typename Q8>
inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
make4(x[i].qs + 16*j, bits.values, q8_quants);
}
constexpr static int minv = 43;
SimpleBits bits;
Scales8KBase scb;
const __m256i min_value = _mm256_set1_epi8(minv);
};
//
// ============================== Legacy quants
//
@@ -1894,7 +1975,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
}
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS>) {
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS> ||
std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {
m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
@@ -1990,6 +2072,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2XS>(mm);
break;
case GGML_TYPE_IQ2_XXS:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2XXS>(mm);
break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_Unpacker>(mm);