iq2_kl: Zen4 GEMM/GEMV

Not particularly fast. I may need to think about rearranging the bits.
This commit is contained in:
Iwan Kawrakow
2025-07-11 12:02:44 +03:00
parent 23e9033f7b
commit cd32c732f5
2 changed files with 115 additions and 6 deletions

View File

@@ -337,6 +337,105 @@ struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
};
};
// AVX-512 (Zen4) dequantizer for IQ2_KL blocks.
// Quants are stored as 4-bit pair indices in qs, extended to 5 bits (0...31) by one
// high bit per pair in qh; each index selects a pair of table values (see k_values in
// load_values()). Block scales are 6 bits: 4 low bits in scales_l, 2 high bits in scales_h.
struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true> {
    DequantizerIQ2KL(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
    // Returns the 8 sub-block scales of block i as int16 lanes:
    // (low nibble of scales_l | 2 bits from scales_h shifted into bits 4..5) - 32.
    inline __m128i make_scales(int i) const {
        // Scalar reference of what the SIMD code below computes:
        //uint16_t aux[8];
        //auto h = x[i].scales_h;
        //for (int k = 0; k < 4; ++k) { aux[k+0] = (x[i].scales_l[k] & 0xf) | ((h << 4) & 0x30); h >>= 2; }
        //for (int k = 0; k < 4; ++k) { aux[k+4] = (x[i].scales_l[k] >> 4) | ((h << 4) & 0x30); h >>= 2; }
        //return _mm_sub_epi16(_mm_loadu_si128((const __m128i *)aux), _mm_set1_epi16(32));
        uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4);
        // scl: low nibbles of scales_l in int16 lanes 0..3, high nibbles in lanes 4..7.
        auto scl = _mm_cvtepu8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), _mm_set_epi32(0, 0, 4, 0)), _mm_set1_epi8(0xf)));
        // sch: per-lane right shifts of 4,6,8,10 (the 0x000a000800060004 pattern, repeated in
        // both qwords) pick out the two high bits for each scale; the low qword is pre-shifted
        // left by 8 so that for scales 0..3 the net effect is a left shift (h<<4, h<<2, ...),
        // while the high qword handles scales 4..7 directly.
        auto sch = _mm_srlv_epi16(_mm_sllv_epi64(_mm_set1_epi16(x[i].scales_h), _mm_set_epi64x(0, 8)), _mm_set1_epi64x(0x000a000800060004));
        auto scales128 = _mm_sub_epi16(_mm_or_si128(scl, _mm_and_si128(sch, _mm_set1_epi16(0x30))), _mm_set1_epi16(32));
        return scales128;
    }
    // Accumulates block i against Q8::nrc_y q8 rows into acc[iy]:
    // acc[iy] += d*q8.scale * (sum_k (quants_k . q8_k) * scales_k  +  (-64*scales) . bsums).
    // NOTE(review): the -64*scales "mins" term presumably compensates for the table values
    // being stored unsigned (needed for _mm512_maddubs_epi16) — confirm against the quantizer.
    template <typename Q8>
    inline void compute_block(int i, const Q8& q8, __m512 * acc) {
        auto scales128 = make_scales(i);
        auto mins128 = _mm_mullo_epi16(scales128, _mm_set1_epi16(-64));
        auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
        auto scales256 = MM256_SET_M128I(scales128, scales128);
        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
        // Distribute the 8 scales so that scales[k] matches the quant layout of bits.values[k].
        __m512i scales[4];
        for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
        prepare(i);
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            // mins * block sums: single madd applies the constant offset correction.
            auto q8s = q8.load_bsums(iy, i);
            auto prod = _mm256_madd_epi16(mins, q8s);
            auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
            for (int k = 0; k < 4; ++k) {
                auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
                sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
            }
            acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
        }
    }
    // Expands the packed 4+1 bit pair indices of block i into dequantized bytes in bits.values[0..3].
    inline void prepare(int i) {
        __m512i ql[2], qs[4];
        __mmask64 mask[2];
        // TODO: optimize this
        // ql[k]: the 4-bit (low) part of the pair indices for the k-th half of the block.
        for (int k = 0; k < 2; ++k) {
            auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0);
            auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1);
            auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1);
            auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2);
            ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1));
        }
        // Spread the high bits (qh) so each byte position carries its hbit at a testable
        // position: lanes hold h, h>>1, h>>2, h>>3, then m01/m10 pick the right bit.
        auto h128 = _mm_loadu_si128((const __m128i *)x[i].qh);
        auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 1), h128);
        auto h512 = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
        mask[0] = _mm512_cmpeq_epi8_mask(_mm512_and_si512(h512, m01), m01);
        mask[1] = _mm512_cmpeq_epi8_mask(_mm512_and_si512(h512, m10), m10);
        for (int k = 0; k < 2; ++k) {
            // qs[0]: even quants when hbits is not set (so pair index is in 0...15)
            // qs[1]: even quants when hbits is set (so pair index is in 16...31)
            // qs[2]: odd quants when hbits is not set (so pair index is in 0...15)
            // qs[3]: odd quants when hbits is set (so pair index is in 16...31)
            // if we blend qs[0] and qs[1] with the hbit mask, we get the correct even quants -> q1
            // if we blend qs[2] and qs[3] with the hbit mask, we get the correct odd quants -> q2
            // If we convert q1 and q2 to int16_t, shift q2 left by 8 bits, and or them, we get the quants in the correct order
            for (int l = 0; l < 4; ++l) qs[l] = _mm512_shuffle_epi8(values[l], ql[k]);
            auto q1 = _mm512_mask_blend_epi8(mask[k], qs[0], qs[1]);
            auto q2 = _mm512_mask_blend_epi8(mask[k], qs[2], qs[3]);
            auto q1l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q1));
            auto q1h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q1, 1));
            auto q2l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q2));
            auto q2h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q2, 1));
            bits.values[2*k+0] = _mm512_or_si512(q1l, _mm512_slli_epi16(q2l, 8));
            bits.values[2*k+1] = _mm512_or_si512(q1h, _mm512_slli_epi16(q2h, 8));
        }
    }
    // Builds the four 16-entry lookup tables used by prepare(). Each 16-byte row of
    // k_values is broadcast to all four 128-bit lanes so that _mm512_shuffle_epi8
    // performs an independent 16-entry table lookup per lane. Per the comments in
    // prepare(): rows 0/1 hold the first element of pairs 0..15 / 16..31,
    // rows 2/3 the second element of those pairs.
    void load_values() {
        static const uint8_t k_values[64] = {
            1, 1, 24, 24, 24, 24, 41, 41, 41, 41, 41, 54, 54, 54, 54, 65, 65, 65, 65, 65, 77, 77, 77, 77, 77, 92, 92, 92, 92, 92, 111, 111,
            41, 77, 1, 54, 77, 111, 24, 41, 65, 77, 92, 1, 65, 77, 111, 41, 54, 65, 77, 92, 24, 41, 54, 65, 77, 1, 41, 65, 92, 111, 41, 77,
        };
        for (int k = 0; k < 4; ++k) {
            auto v128 = _mm_loadu_si128((const __m128i *)k_values + k);
            auto v256 = MM256_SET_M128I(v128, v128);
            values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(v256), v256, 1);
        }
    }
    struct { __m512i values[4]; } bits;           // dequantized quants of the current block, filled by prepare()
    Scales8KBase s8k;                             // provides the shuffles used to spread mins over bsums
    const __m512i m01 = _mm512_set1_epi8(0x01);   // hbit test mask for the first blend (see prepare())
    const __m512i m10 = _mm512_set1_epi8(0x10);   // hbit test mask for the second blend
    const __m512i m4 = _mm512_set1_epi8(0xf);     // low-nibble mask
    __m512i values[4];                            // lookup tables built by load_values()
    // shuffles[k] replicates the k-th pair of 16-bit scales across each 128-bit lane
    // (byte pairs 0x0100 = scale 0, 0x0302 = scale 1, etc.).
    const __m512i shuffles[4] = {
        _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
        _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
        _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
        _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
    };
};
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
template <typename Q8>
@@ -1383,7 +1482,8 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
set_scales_8(all_scales, j, scales);
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> || std::is_same_v<Dequantizer, DequantizerIQ3KS>) {
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> || std::is_same_v<Dequantizer, DequantizerIQ3KS> ||
std::is_same_v<Dequantizer, DequantizerIQ2KL>) {
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
} else {
multiply_add(deq.bits, scales, j, i, q8, sumi);
@@ -2127,6 +2227,7 @@ static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
#ifdef HAVE_FANCY_SIMD
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
std::is_same_v<Dequantizer, DequantizerIQ2KL> ||
std::is_same_v<Dequantizer, DequantizerIQ3KS> ||
std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
@@ -2916,6 +3017,12 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
case GGML_TYPE_IQ2_K:
set_functions<DequantizerIQ2K>(kernels);
break;
case GGML_TYPE_IQ2_KL:
set_functions<DequantizerIQ2KL>(kernels);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iqX_k_q8_K_AVX512_new<DequantizerIQ2KL, 16>;
#endif
break;
case GGML_TYPE_IQ3_KS:
set_functions<DequantizerIQ3KS>(kernels);
break;

View File

@@ -424,6 +424,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x);
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ2_K:
//case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_KSS:
@@ -827,14 +828,15 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4:
return iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ2_K_R4: