iq5_ks: Zen4

This commit is contained in:
Iwan Kawrakow
2025-05-15 12:34:20 +03:00
parent f0355f2522
commit 65b9d3302e

View File

@@ -2383,6 +2383,79 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
};
};
struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
s8k.accum_mins(scales_s, q8, i, d, accm);
auto scales256 = MM256_SET_M128I(scales128, scales128);
auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
prepare(x[i].qs, x[i].qh);
}
inline void prepare(const uint8_t * q4, const uint8_t * qh) {
bits.prepare64(q4);
auto h256 = _mm256_loadu_si256((const __m256i *)qh);
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]);
hbits = _mm512_srli_epi16(hbits, 4);
m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
// We now have in bits.valuse[0]: 0...31, 64...95
// bits.valuse[1]: 32..63, 96..127
// etc.
auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
bits.values[0] = tmp;
tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
bits.values[2] = tmp;
}
static void load_values(__m512i * values) {
static const uint8_t kvalues_iq5nl[32] = {
2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
};
auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
auto values256_1 = MM256_SET_M128I(values128_1, values128_1);
auto values256_2 = MM256_SET_M128I(values128_2, values128_2);
values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1);
values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1);
}
Q4Bits bits;
Scales8KBase s8k;
__m512i values[2];
const __m512i hmask1 = _mm512_set1_epi8(1);
const __m512i hmask2 = _mm512_set1_epi8(2);
const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
const __m128i m127 = _mm_set1_epi16(-127);
const __m128i m128 = _mm_set1_epi16(-128);
const __m128i mask = _mm_set1_epi16(254);
const __m128i m1 = _mm_set1_epi16(1);
const __m128i m2 = _mm_set1_epi16(2);
const __m512i shuffles[4] = {
_mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
};
};
struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
template <typename Q8>
@@ -9455,6 +9528,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
std::is_same_v<Dequantizer, DequantizerIQ3K> ||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
std::is_same_v<Dequantizer, DequantizerIQ4KS>||
std::is_same_v<Dequantizer, DequantizerIQ5KS>||
std::is_same_v<Dequantizer, DequantizerIQ4KSS>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
@@ -9620,6 +9694,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4KS>(mm);
break;
case GGML_TYPE_IQ5_KS:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ5KS>(mm);
break;
case GGML_TYPE_IQ4_KSS:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4KSS>(mm);