mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq2_kl: Zen4 GEMM/GEMV
Not particularly fast. I may need to think about rearranging the bits.
This commit is contained in:
@@ -337,6 +337,105 @@ struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
|
||||
};
|
||||
};
|
||||
|
||||
struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true> {
|
||||
DequantizerIQ2KL(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
|
||||
inline __m128i make_scales(int i) const {
|
||||
//uint16_t aux[8];
|
||||
//auto h = x[i].scales_h;
|
||||
//for (int k = 0; k < 4; ++k) { aux[k+0] = (x[i].scales_l[k] & 0xf) | ((h << 4) & 0x30); h >>= 2; }
|
||||
//for (int k = 0; k < 4; ++k) { aux[k+4] = (x[i].scales_l[k] >> 4) | ((h << 4) & 0x30); h >>= 2; }
|
||||
//return _mm_sub_epi16(_mm_loadu_si128((const __m128i *)aux), _mm_set1_epi16(32));
|
||||
uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4);
|
||||
auto scl = _mm_cvtepu8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), _mm_set_epi32(0, 0, 4, 0)), _mm_set1_epi8(0xf)));
|
||||
// 0x000a000800060004
|
||||
auto sch = _mm_srlv_epi16(_mm_sllv_epi64(_mm_set1_epi16(x[i].scales_h), _mm_set_epi64x(0, 8)), _mm_set1_epi64x(0x000a000800060004));
|
||||
auto scales128 = _mm_sub_epi16(_mm_or_si128(scl, _mm_and_si128(sch, _mm_set1_epi16(0x30))), _mm_set1_epi16(32));
|
||||
return scales128;
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void compute_block(int i, const Q8& q8, __m512 * acc) {
|
||||
auto scales128 = make_scales(i);
|
||||
auto mins128 = _mm_mullo_epi16(scales128, _mm_set1_epi16(-64));
|
||||
auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
|
||||
auto scales256 = MM256_SET_M128I(scales128, scales128);
|
||||
auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
|
||||
__m512i scales[4];
|
||||
for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
|
||||
prepare(i);
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8s = q8.load_bsums(iy, i);
|
||||
auto prod = _mm256_madd_epi16(mins, q8s);
|
||||
auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
|
||||
}
|
||||
acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
|
||||
}
|
||||
}
|
||||
inline void prepare(int i) {
|
||||
__m512i ql[2], qs[4];
|
||||
__mmask64 mask[2];
|
||||
// TODO: optimize this
|
||||
for (int k = 0; k < 2; ++k) {
|
||||
auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0);
|
||||
auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1);
|
||||
auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1);
|
||||
auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2);
|
||||
ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1));
|
||||
}
|
||||
auto h128 = _mm_loadu_si128((const __m128i *)x[i].qh);
|
||||
auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 1), h128);
|
||||
auto h512 = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
|
||||
mask[0] = _mm512_cmpeq_epi8_mask(_mm512_and_si512(h512, m01), m01);
|
||||
mask[1] = _mm512_cmpeq_epi8_mask(_mm512_and_si512(h512, m10), m10);
|
||||
|
||||
for (int k = 0; k < 2; ++k) {
|
||||
// qs[0]: even quants when hbits is not set (so pair index is in 0...15)
|
||||
// qs[1]: even quants when hbits is set (so pair index is in 16...31)
|
||||
// qs[2]: odd quants when hbits is not set (so pair index is in 0...15)
|
||||
// qs[3]: odd quants when hbits is set (so pair index is in 16...31)
|
||||
// if we blend qs[0] and qs[1] with the hbit mask, we get the correct even quants -> q1
|
||||
// if we blend qs[2] and qs[3] with the hbit mask, we get the correct odd quants -> q2
|
||||
// If we convert q1 and q2 to int16_t, shift q2 left by 8 bits, and or them, we get the quants in the correct order
|
||||
for (int l = 0; l < 4; ++l) qs[l] = _mm512_shuffle_epi8(values[l], ql[k]);
|
||||
auto q1 = _mm512_mask_blend_epi8(mask[k], qs[0], qs[1]);
|
||||
auto q2 = _mm512_mask_blend_epi8(mask[k], qs[2], qs[3]);
|
||||
auto q1l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q1));
|
||||
auto q1h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q1, 1));
|
||||
auto q2l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q2));
|
||||
auto q2h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q2, 1));
|
||||
bits.values[2*k+0] = _mm512_or_si512(q1l, _mm512_slli_epi16(q2l, 8));
|
||||
bits.values[2*k+1] = _mm512_or_si512(q1h, _mm512_slli_epi16(q2h, 8));
|
||||
}
|
||||
}
|
||||
void load_values() {
|
||||
static const uint8_t k_values[64] = {
|
||||
1, 1, 24, 24, 24, 24, 41, 41, 41, 41, 41, 54, 54, 54, 54, 65, 65, 65, 65, 65, 77, 77, 77, 77, 77, 92, 92, 92, 92, 92, 111, 111,
|
||||
41, 77, 1, 54, 77, 111, 24, 41, 65, 77, 92, 1, 65, 77, 111, 41, 54, 65, 77, 92, 24, 41, 54, 65, 77, 1, 41, 65, 92, 111, 41, 77,
|
||||
};
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto v128 = _mm_loadu_si128((const __m128i *)k_values + k);
|
||||
auto v256 = MM256_SET_M128I(v128, v128);
|
||||
values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(v256), v256, 1);
|
||||
}
|
||||
}
|
||||
|
||||
struct { __m512i values[4]; } bits;
|
||||
Scales8KBase s8k;
|
||||
const __m512i m01 = _mm512_set1_epi8(0x01);
|
||||
const __m512i m10 = _mm512_set1_epi8(0x10);
|
||||
const __m512i m4 = _mm512_set1_epi8(0xf);
|
||||
__m512i values[4];
|
||||
const __m512i shuffles[4] = {
|
||||
_mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
|
||||
_mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
|
||||
_mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
|
||||
_mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
|
||||
DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
|
||||
template <typename Q8>
|
||||
@@ -1383,7 +1482,8 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
set_scales_8(all_scales, j, scales);
|
||||
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> || std::is_same_v<Dequantizer, DequantizerIQ3KS>) {
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> || std::is_same_v<Dequantizer, DequantizerIQ3KS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2KL) {
|
||||
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
|
||||
} else {
|
||||
multiply_add(deq.bits, scales, j, i, q8, sumi);
|
||||
@@ -2127,6 +2227,7 @@ static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data
|
||||
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2KL> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ3KS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
|
||||
@@ -2916,6 +3017,12 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
|
||||
case GGML_TYPE_IQ2_K:
|
||||
set_functions<DequantizerIQ2K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_KL:
|
||||
set_functions<DequantizerIQ2KL>(kernels);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
func16 = mul_mat_iqX_k_q8_K_AVX512_new<DequantizerIQ2KL, 16>;
|
||||
#endif
|
||||
break;
|
||||
case GGML_TYPE_IQ3_KS:
|
||||
set_functions<DequantizerIQ3KS>(kernels);
|
||||
break;
|
||||
|
||||
@@ -424,6 +424,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
|
||||
return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x);
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
//case GGML_TYPE_IQ2_KL:
|
||||
case GGML_TYPE_IQ3_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
@@ -827,14 +828,15 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
case GGML_TYPE_IQ3_XXS_R4:
|
||||
case GGML_TYPE_IQ3_S_R4:
|
||||
return iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16);
|
||||
case GGML_TYPE_IQ3_KS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KL:
|
||||
case GGML_TYPE_IQ3_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ5_K:
|
||||
case GGML_TYPE_IQ6_K:
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
|
||||
Reference in New Issue
Block a user