iq2_kl: WIP NEON

The compiler started crashing!!!
This commit is contained in:
Iwan Kawrakow
2025-07-11 15:50:47 +02:00
parent 4a3b5e3119
commit b1956cd122
2 changed files with 73 additions and 0 deletions

View File

@@ -3718,6 +3718,75 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
};
struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true> {
DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shifts(load_shift()) { load_values(values); }
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4);
auto scl = vand_u8(vdup_n_u8(0xf), vreinterpret_u8_u32(uint32x2_t{aux32, aux32 >> 4}));
auto sch = vandq_u16(vshlq_u16(vdupq_n_u16(x[i].scales_h), shifts), vdupq_n_u16(0x30));
auto scales16 = vsubq_s16(vreinterpretq_s16_u16(vorrq_u16(sch, vmovl_u8(scl))), vdupq_n_s16(32));
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
return scales;
}
inline void prepare(int i, int j) {
hbits = j == 0 ? vld1q_u8(x[i].qh) : vshrq_n_u8(hbits, 4);
auto lbits = vld1q_u8_x2(x[i].qs+32*j);
bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf));
bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4);
bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf));
bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4)));
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3)));
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2)));
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1)));
auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]);
auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]);
bits.b1.val[0] = vzip1q_s8(t1, t2);
bits.b1.val[1] = vzip2q_s8(t1, t2);
t1 = vqtbl2q_s8(values[0], bits.b1.val[2]);
t2 = vqtbl2q_s8(values[1], bits.b1.val[2]);
bits.b1.val[2] = vzip1q_s8(t1, t2);
bits.b1.val[3] = vzip2q_s8(t1, t2);
t1 = vqtbl2q_s8(values[0], bits.b2.val[0]);
t2 = vqtbl2q_s8(values[1], bits.b2.val[0]);
bits.b2.val[0] = vzip1q_s8(t1, t2);
bits.b2.val[1] = vzip2q_s8(t1, t2);
t1 = vqtbl2q_s8(values[0], bits.b2.val[2]);
t2 = vqtbl2q_s8(values[1], bits.b2.val[2]);
bits.b2.val[2] = vzip1q_s8(t1, t2);
bits.b2.val[3] = vzip2q_s8(t1, t2);
hbits = vshrq_n_u8(hbits, 4);
}
static inline int16x8_t load_shift() {
static const int16_t k_shift[8] = {4, 2, 0, -2, -4, -6, -8, -10};
return vld1q_s16(k_shift);
}
static inline void load_values(int8x16x2_t * values) {
static const int8_t k_values[64] = {
-63, -63, -40, -40, -40, -40, -23, -23, -23, -23, -23, -10, -10, -10, -10, 1, 1, 1, 1, 1, 13, 13, 13, 13, 13, 28, 28, 28, 28, 28, 47, 47,
-23, 13, -63, -10, 13, 47, -40, -23, 1, 13, 28, -63, 1, 13, 47, -23, -10, 1, 13, 28, -40, -23, -10, 1, 13, -63, -23, 1, 28, 47, -23, 13,
};
values[0] = vld1q_s8_x2(k_values+ 0);
values[1] = vld1q_s8_x2(k_values+32);
}
Q2bits bits;
//struct { uint8x16x4_t b1, b2; } bits;
uint8x16_t hbits;
const int16x8_t shifts;
const uint8x16_t m10 = vdupq_n_u8(0x10);
int8x16x2_t values[2];
};
template <int nrc_y>
void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -4959,6 +5028,9 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
case GGML_TYPE_IQ2_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2K, kernels);
break;
case GGML_TYPE_IQ2_KL:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2KL, kernels);
break;
case GGML_TYPE_IQ3_KS:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3KS, kernels);
break;

View File

@@ -912,6 +912,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_KSS: