Also iq4_xs belongs to k-quants

This commit is contained in:
Iwan Kawrakow
2025-05-18 18:14:45 +03:00
parent f4ab917e9e
commit 312413694f
2 changed files with 61 additions and 63 deletions

View File

@@ -2096,6 +2096,63 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
};
// Dequantizer for block_iq4_xs data, built on ARM NEON intrinsics.
// Per block it produces eight 32-bit integer scales (new_block) and maps the
// quantized indices through a 16-entry int8 lookup table (prepare).
// NOTE(review): x, d, vx and bx are not declared in this struct; presumably they
// are members inherited from BaseDequantizer - confirm against its definition.
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
// Loads the 16-entry int8 lookup table into a single 128-bit NEON register.
static int8x16_t load_values() {
static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
return vld1q_s8(iq4nl_values);
}
// Forwards vx, bx and nrc to the base class and caches the lookup table in `values`.
DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}
// Compile-time traits read by the kernel template this dequantizer is plugged into.
// NOTE(review): 8 matches the number of scales returned by new_block(); the exact
// meaning of both traits is defined by the consuming kernel - confirm there.
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
// Points x at row ix, where consecutive rows are bx bytes apart starting at vx.
inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }
// Decodes the block scale and the eight sub-block scales of block i.
// q8 and acc are unused here; they are part of the signature only.
// Side effect: stores the block's fp16 scale, converted to float, in d.
// Returns the eight scales widened to int32: val[0] holds scales 0..3, val[1] holds 4..7.
// The byte-order comments below assume a little-endian target.
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
(void)q8;
(void)acc;
d = GGML_FP16_TO_FP32(x[i].d);
const uint16_t scales_h = x[i].scales_h;
const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;
// Gather the four scales_l bytes into one 32-bit word; the copy shifted right by 4
// brings each byte's high nibble down into the low-nibble position.
aux32[0] = scales_l[0] | (scales_l[1] << 16);
aux32[1] = aux32[0] >> 4;
// Masking with 0xf leaves the low 4 bits of each scale: low nibbles of bytes 0..3
// first, then their high nibbles.
// scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7
uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));
// aux32 is reused as scratch from here on: scl8 has already been loaded, so
// overwriting the buffer through aux16 is safe only in this statement order.
uint16_t * aux16 = (uint16_t *)aux32;
// Four shifted copies of scales_h, so that every 2-bit field lands on bits 4..5 of some byte.
aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;
// Masking with 0x30 keeps only those two bits per byte (the high bits of each scale).
// sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7
uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));
// Permute sch8 into scl8's order, OR the halves into 6-bit values and subtract 32.
int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));
// shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7
scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));
// Sign-extend int8 -> int16 -> int32 and split into two 4-lane vectors.
int16x8_t scales16 = vmovl_s8(scales8);
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
return scales;
}
// Prepares the quants of block i starting at byte offset 64*j of x[i].qs.
// Q4bits::prepare16 fills bits.b1 and bits.b2; every byte of those vectors is then
// replaced by values[byte] via a NEON table lookup.
// NOTE(review): presumably prepare16 leaves one 4-bit index (0..15) per byte - confirm in Q4bits.
inline void prepare(int i, int j) {
bits.prepare16(x[i].qs+64*j);
// Disabled alternative that would pick prepare16_v2 when nrc == 1; kept for reference.
//if (nrc == 1) {
//    bits.prepare16_v2(x[i].qs+64*j);
//} else {
//    bits.prepare16(x[i].qs+64*j);
//}
for (int k = 0; k < 4; ++k) {
bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));
bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));
}
}
// Holds the unpacked quants; after prepare() its bytes are table values reinterpreted as u8.
Q4bits bits;
// The 16-entry int8 lookup table, loaded once in the constructor.
const int8x16_t values;
// Scratch buffer used by new_block(), accessed both as 2 x u32 and as 4 x u16.
uint32_t aux32[2];
// Byte indices {0, 4, 1, 5, 2, 6, 3, 7} (little-endian) for vtbl1: applied once it maps the
// sch order onto the scl order, applied again it maps the scl order onto 0..7.
constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};
};
}
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
@@ -2128,9 +2185,9 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
case GGML_TYPE_Q6_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ6K, kernels)
break;
// case GGML_TYPE_IQ4_XS:
// set_functions<DequantizerIQ4XS>(kernels);
// break;
case GGML_TYPE_IQ4_XS:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4XS, kernels)
break;
// case GGML_TYPE_Q2_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
// break;

View File

@@ -1128,63 +1128,6 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
};
// Dequantizer for block_iq4_xs data, built on ARM NEON intrinsics.
// Per block it produces eight 32-bit integer scales (new_block) and maps the
// quantized indices through a 16-entry int8 lookup table (prepare).
// NOTE(review): x, d, vx and bx are not declared in this struct; presumably they
// are members inherited from BaseDequantizer - confirm against its definition.
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
// Loads the 16-entry int8 lookup table into a single 128-bit NEON register.
static int8x16_t load_values() {
static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
return vld1q_s8(iq4nl_values);
}
// Forwards vx, bx and nrc to the base class and caches the lookup table in `values`.
DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}
// Compile-time traits read by the kernel template this dequantizer is plugged into.
// NOTE(review): 8 matches the number of scales returned by new_block(); the exact
// meaning of both traits is defined by the consuming kernel - confirm there.
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
// Points x at row ix, where consecutive rows are bx bytes apart starting at vx.
inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }
// Decodes the block scale and the eight sub-block scales of block i.
// q8 and acc are unused here; they are part of the signature only.
// Side effect: stores the block's fp16 scale, converted to float, in d.
// Returns the eight scales widened to int32: val[0] holds scales 0..3, val[1] holds 4..7.
// The byte-order comments below assume a little-endian target.
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
(void)q8;
(void)acc;
d = GGML_FP16_TO_FP32(x[i].d);
const uint16_t scales_h = x[i].scales_h;
const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;
// Gather the four scales_l bytes into one 32-bit word; the copy shifted right by 4
// brings each byte's high nibble down into the low-nibble position.
aux32[0] = scales_l[0] | (scales_l[1] << 16);
aux32[1] = aux32[0] >> 4;
// Masking with 0xf leaves the low 4 bits of each scale: low nibbles of bytes 0..3
// first, then their high nibbles.
// scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7
uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));
// aux32 is reused as scratch from here on: scl8 has already been loaded, so
// overwriting the buffer through aux16 is safe only in this statement order.
uint16_t * aux16 = (uint16_t *)aux32;
// Four shifted copies of scales_h, so that every 2-bit field lands on bits 4..5 of some byte.
aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;
// Masking with 0x30 keeps only those two bits per byte (the high bits of each scale).
// sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7
uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));
// Permute sch8 into scl8's order, OR the halves into 6-bit values and subtract 32.
int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));
// shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7
scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));
// Sign-extend int8 -> int16 -> int32 and split into two 4-lane vectors.
int16x8_t scales16 = vmovl_s8(scales8);
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
return scales;
}
// Prepares the quants of block i starting at byte offset 64*j of x[i].qs.
// Q4bits::prepare16 fills bits.b1 and bits.b2; every byte of those vectors is then
// replaced by values[byte] via a NEON table lookup.
// NOTE(review): presumably prepare16 leaves one 4-bit index (0..15) per byte - confirm in Q4bits.
inline void prepare(int i, int j) {
bits.prepare16(x[i].qs+64*j);
// Disabled alternative that would pick prepare16_v2 when nrc == 1; kept for reference.
//if (nrc == 1) {
//    bits.prepare16_v2(x[i].qs+64*j);
//} else {
//    bits.prepare16(x[i].qs+64*j);
//}
for (int k = 0; k < 4; ++k) {
bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));
bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));
}
}
// Holds the unpacked quants; after prepare() its bytes are table values reinterpreted as u8.
Q4bits bits;
// The 16-entry int8 lookup table, loaded once in the constructor.
const int8x16_t values;
// Scratch buffer used by new_block(), accessed both as 2 x u32 and as 4 x u16.
uint32_t aux32[2];
// Byte indices {0, 4, 1, 5, 2, 6, 3, 7} (little-endian) for vtbl1: applied once it maps the
// sch order onto the scl order, applied again it maps the scl order onto 0..7.
constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};
};
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
@@ -4324,10 +4267,8 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ4_XS:
MulMat::set_functions<DequantizerIQ4XS>(m);
break;
return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ4_KS:
MulMat::set_functions<DequantizerIQ4KS>(m);
break;