mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-21 06:59:21 +00:00
Zen4: faster PP for iq4_ks and iq5_ks
This commit is contained in:
@@ -1798,6 +1798,13 @@ struct Q4Bits {
|
||||
values[2] = _mm512_and_si512(q4bits, ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare64a(const uint8_t * q4) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
|
||||
values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
|
||||
values[k] = _mm512_and_si512(values[k], ml);
|
||||
}
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0xf);
|
||||
BlockPermuter perm;
|
||||
@@ -2377,6 +2384,29 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
|
||||
scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
|
||||
prepare(x[i].qs);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void compute_block(int i, const Q8& q8, __m512 * acc) {
|
||||
auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
|
||||
auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
|
||||
scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
|
||||
auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
|
||||
auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
|
||||
auto scales256 = MM256_SET_M128I(scales128, scales128);
|
||||
auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
|
||||
__m512i scales[4];
|
||||
for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
|
||||
prepare(x[i].qs);
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8s = q8.load_bsums(iy, i);
|
||||
auto prod = _mm256_madd_epi16(mins, q8s);
|
||||
auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
|
||||
}
|
||||
acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
|
||||
}
|
||||
}
|
||||
inline void prepare(const uint8_t * q4) {
|
||||
bits.prepare64(q4);
|
||||
// We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111
|
||||
@@ -2425,10 +2455,33 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
|
||||
scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
|
||||
prepare(x[i].qs, x[i].qh);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void compute_block(int i, const Q8& q8, __m512 * acc) {
|
||||
auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
|
||||
auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
|
||||
scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
|
||||
auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
|
||||
auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
|
||||
auto scales256 = MM256_SET_M128I(scales128, scales128);
|
||||
auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
|
||||
__m512i scales[4];
|
||||
for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
|
||||
prepare(x[i].qs, x[i].qh);
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8s = q8.load_bsums(iy, i);
|
||||
auto prod = _mm256_madd_epi16(mins, q8s);
|
||||
auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
|
||||
}
|
||||
acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
|
||||
}
|
||||
}
|
||||
inline void prepare(const uint8_t * q4, const uint8_t * qh) {
|
||||
bits.prepare64(q4);
|
||||
bits.prepare64a(q4);
|
||||
auto h256 = _mm256_loadu_si256((const __m256i *)qh);
|
||||
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
|
||||
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1);
|
||||
auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
|
||||
auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
|
||||
bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
|
||||
@@ -2438,15 +2491,6 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
|
||||
m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
|
||||
bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
|
||||
bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
|
||||
// We now have in bits.valuse[0]: 0...31, 64...95
|
||||
// bits.valuse[1]: 32..63, 96..127
|
||||
// etc.
|
||||
auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
|
||||
bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
|
||||
bits.values[0] = tmp;
|
||||
tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
|
||||
bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
|
||||
bits.values[2] = tmp;
|
||||
}
|
||||
static void load_values(__m512i * values) {
|
||||
static const uint8_t kvalues_iq5nl[32] = {
|
||||
@@ -2465,9 +2509,7 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
|
||||
Scales8KBase s8k;
|
||||
__m512i values[2];
|
||||
const __m512i hmask1 = _mm512_set1_epi8(1);
|
||||
const __m512i hmask2 = _mm512_set1_epi8(2);
|
||||
const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
|
||||
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
|
||||
const __m512i hmask2 = _mm512_set1_epi8(4);
|
||||
const __m128i m127 = _mm_set1_epi16(-127);
|
||||
const __m128i m128 = _mm_set1_epi16(-128);
|
||||
const __m128i mask = _mm_set1_epi16(254);
|
||||
@@ -2651,6 +2693,34 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y> q8(info);
|
||||
|
||||
Dequantizer deq(vx, bx);
|
||||
|
||||
__m512 accd[nrc_y];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
deq.compute_block(i, q8, accd);
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, _mm512_reduce_add_ps(accd[iy]));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer>
|
||||
static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -9713,8 +9783,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ3K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4KS>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ5KS>||
|
||||
//std::is_same_v<Dequantizer, DequantizerIQ4KS>||
|
||||
//std::is_same_v<Dequantizer, DequantizerIQ5KS>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4KSS>) {
|
||||
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
|
||||
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
|
||||
@@ -9724,6 +9794,16 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
|
||||
m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
|
||||
m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
|
||||
} else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
|
||||
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 1>;
|
||||
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 2>;
|
||||
m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 3>;
|
||||
m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 4>;
|
||||
m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 5>;
|
||||
m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 6>;
|
||||
m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 7>;
|
||||
m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 8>;
|
||||
} else {
|
||||
m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
|
||||
m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
|
||||
|
||||
Reference in New Issue
Block a user