mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-20 05:04:11 +00:00
iqk_mul_mat: AVX2 implementation for iq2_xxs
2.09X for PP-512 (104.7 t/s), worse than mainline for TG. I think it needs more work.
This commit is contained in:
@@ -1519,6 +1519,87 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
|
||||
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
|
||||
|
||||
constexpr static int num_blocks = 8;
|
||||
|
||||
inline __m128i load_scales(int i) {
|
||||
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
|
||||
auto scales = _mm_set_epi16(x[i].qs[31] >> 12, x[i].qs[27] >> 12, x[i].qs[23] >> 12, x[i].qs[19] >> 12,
|
||||
x[i].qs[15] >> 12, x[i].qs[11] >> 12, x[i].qs[ 7] >> 12, x[i].qs[ 3] >> 12);
|
||||
return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));
|
||||
//auto tmp1 = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
||||
//auto tmp2 = _mm256_loadu_si256((const __m256i *)(x[i].qs+16));
|
||||
//auto idx = _mm256_unpacklo_epi32(tmp1, tmp2);
|
||||
//auto sas = _mm256_unpackhi_epi32(tmp1, tmp2);
|
||||
}
|
||||
|
||||
inline void new_block(int i, __m256i * scales) {
|
||||
auto sc16 = load_scales(i);
|
||||
scales[0] = MM256_SET_M128I(sc16, sc16);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
|
||||
auto sc16 = load_scales(i);
|
||||
scb.accum_mins(sc16, q8, i, -minv*d, accd);
|
||||
scales[0] = MM256_SET_M128I(sc16, sc16);
|
||||
}
|
||||
|
||||
inline static void make4(const uint32_t * aux32, __m256i * values) {
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);
|
||||
values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);
|
||||
values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);
|
||||
values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);
|
||||
}
|
||||
inline static void sign_value(uint32_t aux32, __m256i& value) {
|
||||
auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
|
||||
keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
|
||||
value = _mm256_sign_epi8(value, signs);
|
||||
}
|
||||
inline static void sign_values(const uint32_t * aux32, __m256i * values) {
|
||||
sign_value(aux32[1], values[0]);
|
||||
sign_value(aux32[3], values[1]);
|
||||
sign_value(aux32[5], values[2]);
|
||||
sign_value(aux32[7], values[3]);
|
||||
}
|
||||
|
||||
union Data {
|
||||
__m256i vec;
|
||||
uint32_t val[8];
|
||||
};
|
||||
inline static void make4_signed(const uint16_t * qs, const __m256i& min_value, __m256i * values) {
|
||||
Data data;
|
||||
data.vec = _mm256_loadu_si256((const __m256i *)qs);
|
||||
make4(data.val, values);
|
||||
sign_values(data.val, values);
|
||||
for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
|
||||
}
|
||||
inline static void make4(const uint16_t * qs, __m256i * values, __m256i * q8) {
|
||||
Data data;
|
||||
data.vec = _mm256_loadu_si256((const __m256i *)qs);
|
||||
make4(data.val, values);
|
||||
sign_values(data.val, q8);
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
make4_signed(x[i].qs + 16*j, min_value, bits.values);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
|
||||
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
|
||||
make4(x[i].qs + 16*j, bits.values, q8_quants);
|
||||
}
|
||||
|
||||
constexpr static int minv = 43;
|
||||
|
||||
SimpleBits bits;
|
||||
Scales8KBase scb;
|
||||
const __m256i min_value = _mm256_set1_epi8(minv);
|
||||
|
||||
};
|
||||
|
||||
//
|
||||
// ============================== Legacy quants
|
||||
//
|
||||
@@ -1894,7 +1975,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
|
||||
}
|
||||
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS>) {
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {
|
||||
m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
|
||||
m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
|
||||
m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
|
||||
@@ -1990,6 +2072,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ2XS>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ2XXS>(mm);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
assert (ne00 % QK4_0 == 0);
|
||||
MulMat::set_functions<Q4_0_Unpacker>(mm);
|
||||
|
||||
Reference in New Issue
Block a user