iq2_ks: WIP

This commit is contained in:
Iwan Kawrakow
2024-10-12 17:54:32 +03:00
parent 70e7b758f5
commit 15a8115fcf

View File

@@ -402,14 +402,20 @@ struct ScaleIQ4XS {
const __m128i m32 = _mm_set1_epi16(-32);
};
template <typename Block, bool per_row_scale = false>
template <typename Block, bool per_row_scale = false, bool is_f16 = false>
struct BaseDequantizer {
BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
inline void new_row(int ix) {
if constexpr (per_row_scale) {
const float * dptr = (const float *)((const char *)vx + bx*ix);
d = *dptr;
x = (const Block *)(dptr + 1);
if constexpr (is_f16) {
const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
d = GGML_FP16_TO_FP32(*dptr);
x = (const Block *)(dptr + 1);
} else {
const float * dptr = (const float *)((const char *)vx + bx*ix);
d = *dptr;
x = (const Block *)(dptr + 1);
}
} else {
x = (const Block *)((const char *)vx + bx*ix);
}
@@ -898,6 +904,58 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
const __m128i m8 = _mm_set1_epi8(-8);
};
// AVX-512 dequantizer for the IQ2_KS quant type: 2-bit quants mapped through a
// non-linear value table, with a single per-row fp16 super-scale
// (BaseDequantizer<..., per_row_scale=true, is_f16=true> reads it in new_row()).
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
    DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
    // Per super-block setup: dequantize the 2-bit values of block i, accumulate the
    // q8 "mins" correction into accm, and broadcast the 16-bit sub-block scales into
    // the two 512-bit scale registers the matmul kernel consumes.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
        prepare(x[i].qs);
        // High byte of x[i].extra carries the scale sign bits (consumed by make_scales);
        // the low byte is tested bit-per-sub-block via hmask below.
        auto scales128 = make_scales(x[i].scales, x[i].extra >> 8);
        // Where a low-byte extra bit is set the min offset becomes -32+5 = -27 instead
        // of -32 — note the two value sub-tables in load_values() also differ by +5,
        // so this presumably folds the sub-table selection into the mins; TODO confirm.
        auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5);
        auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
        s8k.accum_mins(scales_s, q8, i, d, accm);
        // Replicate the 8 sub-block scales across all four 128-bit lanes, then shuffle
        // them into the per-quant positions expected by the AVX512 dot-product kernel.
        auto scales256 = MM256_SET_M128I(scales128, scales128);
        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
        // NOTE(review): leftover scaffolding from an earlier 4-register scale layout
        // (this is a WIP commit) — remove once the 2-register path is confirmed.
        //scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
        //scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
        //scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
        //scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
    }
    // Unpack the 2-bit quants (via Q2Bits) and translate each 2-bit index into its
    // non-linear value with a byte shuffle against the replicated lookup table.
    inline void prepare(const uint8_t * q2) {
        bits.prepare(q2);
        bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
        bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
        bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
        bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
    }
    // Builds the 512-bit lookup table: a 16-byte pattern replicated into all four lanes.
    // It holds two 4-entry sub-tables ({1,19,33,49} at offsets 0..3 and {6,24,38,54} at
    // offsets 8..11, i.e. selected by bit 3 of the shuffle index); how/whether bit 3 gets
    // set on the unpacked indices is not visible here — presumably tied to `extra`
    // (the tables differ by the same +5 applied in new_block) — TODO confirm.
    static inline __m512i load_values() {
        static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
        auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
        auto val256 = MM256_SET_M128I(val128, val128);
        return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
    }
    // Decode 8 sub-block scales into 8 int16s: scales_l packs eight 4-bit magnitudes
    // (two nibbles per byte); bit k of scales_h subtracts 16 from scale k, giving a
    // signed range of [-16, 15].
    inline __m128i make_scales(const uint8_t * scales_l, const uint8_t scales_h) const {
        const uint16_t * scales = (const uint16_t *)scales_l;
        uint32_t aux32 = scales[0] | (scales[1] << 16);
        // Two copies of the 32 scale bits, one shifted right by 4 (see `shift`), so that
        // `shuffle` can interleave bytes from both copies to line up low/high nibbles,
        // which the 0xf mask then isolates.
        auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift);
        scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf));
        // hmask selects a distinct bit of scales_h per byte position; matches yield -16.
        auto sch = _mm_set1_epi8(scales_h);
        sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), hmask), m16);
        return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch));
    }
    Q2Bits bits;          // 2-bit quant unpacker (project helper)
    Scales8K s8k;         // shared scale-shuffle / mins-accumulation helper
    const __m512i values; // replicated non-linear value lookup table
    const __m128i m16 = _mm_set1_epi8(-16);  // sign offset applied to 4-bit scales
    const __m128i m5  = _mm_set1_epi8(5);    // min-offset adjustment selected by `extra`
    const __m128i m32 = _mm_set1_epi8(-32);  // base min offset
    const __m128i hmask = _mm_set1_epi64x(0x8040201008040201);   // byte k tests bit k
    const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400); // nibble interleave order
    const __m128i shift = _mm_set_epi32(0, 0, 4, 0);             // per-lane >> for high nibbles
};
struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {}
template <typename Q8>
@@ -1107,8 +1165,8 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
};
struct DequantizerIQ4XXS final : public BaseDequantizer<block_iq4_ks, true> {
DequantizerIQ4XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
@@ -1740,8 +1798,8 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing
};
struct DequantizerIQ4XXS final : public BaseDequantizer<block_iq4_ks, true> {
DequantizerIQ4XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
@@ -3751,7 +3809,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
std::is_same_v<Dequantizer, DequantizerIQ4K> ||
std::is_same_v<Dequantizer, DequantizerIQ3K> ||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
std::is_same_v<Dequantizer, DequantizerIQ4XXS>) {
std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
@@ -3913,12 +3971,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_IQ4_KS:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4XXS>(mm);
MulMat::set_functions<DequantizerIQ4KS>(mm);
break;
case GGML_TYPE_IQ2_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2K>(mm);
break;
case GGML_TYPE_IQ2_KS:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2KS>(mm);
break;
case GGML_TYPE_IQ3_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ3K>(mm);
@@ -4809,9 +4871,9 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
};
struct DequantizerIQ4XXS final : public BaseDequantizer<block_iq4_ks, true> {
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
DequantizerIQ4XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
@@ -6571,7 +6633,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerIQ4XS>(m);
break;
case GGML_TYPE_IQ4_KS:
MulMat::set_functions<DequantizerIQ4XXS>(m);
MulMat::set_functions<DequantizerIQ4KS>(m);
break;
case GGML_TYPE_IQ4_K:
MulMat::set_functions<DequantizerIQ4K>(m);