mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 03:41:53 +00:00
iq4_k: AVX512 implementation
For LLaMA-3.1-8B we get PP-512 = 182.6 t/s, TG-128 = 13.6 t/s, so almost the same as q4_K_S.
This commit is contained in:
@@ -571,8 +571,16 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
|
||||
Scales8K s8k;
|
||||
};
|
||||
|
||||
__m512i load_iq4nl_values() {
|
||||
static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
|
||||
auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
|
||||
auto val256 = MM256_SET_M128I(val128, val128);
|
||||
return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
|
||||
}
|
||||
|
||||
|
||||
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
|
||||
DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
|
||||
DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values()) {}
|
||||
template <typename Q8>
|
||||
inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
@@ -584,12 +592,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
|
||||
scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
|
||||
scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
|
||||
}
|
||||
static __m512i load_values() {
|
||||
static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
|
||||
auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
|
||||
auto val256 = MM256_SET_M128I(val128, val128);
|
||||
return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
|
||||
}
|
||||
inline void prepare(const uint8_t * q4) {
|
||||
bits.prepare64(q4);
|
||||
// We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111
|
||||
@@ -740,6 +742,70 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
|
||||
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values()) {}
|
||||
template <typename Q8>
|
||||
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
prepare(x[i].qs);
|
||||
auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h);
|
||||
auto extra128 = _mm_set1_epi16(x[i].extra);
|
||||
extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
|
||||
extra128 = _mm_and_si128(extra128, e4);
|
||||
extra128 = _mm_shuffle_epi8(extra128, eshuffle);
|
||||
auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)),
|
||||
_mm256_add_epi16(_mm256_set1_epi16(-128), _mm256_cvtepi8_epi16(extra128)));
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
|
||||
accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
|
||||
}
|
||||
scales16 = MM256_SET_M128I(scales8, scales8);
|
||||
scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
|
||||
scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
|
||||
}
|
||||
inline void prepare(const uint8_t * q4) {
|
||||
bits.prepare64(q4);
|
||||
// We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111
|
||||
// bits.valuse[1]: 16..31, 48...63, 80...95, 112..127
|
||||
// etc.
|
||||
auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
|
||||
bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
|
||||
bits.values[0] = _mm512_shuffle_epi8(values, tmp);
|
||||
tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
|
||||
bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
|
||||
bits.values[2] = _mm512_shuffle_epi8(values, tmp);
|
||||
}
|
||||
__m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
|
||||
uint64_t aux64;
|
||||
memcpy(&aux64, scales_l, 8);
|
||||
auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
|
||||
const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
|
||||
auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
|
||||
auto sch = _mm_shuffle_epi8(aux, hshuff);
|
||||
return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
|
||||
}
|
||||
//static __m256i load_shuffle(int i) {
|
||||
// static const uint64_t k_shuffles[8] = {0x0202020200000000, 0x0a0a0a0a08080808, 0x0303030301010101, 0x0b0b0b0b09090909,
|
||||
// 0x0606060604040404, 0x0e0e0e0e0c0c0c0c, 0x0707070705050505, 0x0f0f0f0f0d0d0d0d};
|
||||
// return _mm256_loadu_si256((const __m256i *)k_shuffles + i);
|
||||
//}
|
||||
|
||||
Q4Bits bits;
|
||||
const __m512i values;
|
||||
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
|
||||
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
|
||||
const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
|
||||
const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
|
||||
const __m128i maskl = _mm_set1_epi8(0xf);
|
||||
const __m128i maskh = _mm_set1_epi8(0x30);
|
||||
const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
|
||||
const __m128i m32 = _mm_set1_epi8(-32);
|
||||
const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
|
||||
const __m128i e4 = _mm_set1_epi8(4);
|
||||
const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
|
||||
|
||||
};
|
||||
|
||||
template <typename Q8>
|
||||
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
|
||||
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
|
||||
@@ -2783,6 +2849,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ4XS>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ4K>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ3S>(mm);
|
||||
|
||||
Reference in New Issue
Block a user