iq2_kl: NEON

Had to work around a compiler crash when using vzip2q_u8 using
vqtbl2q_u8.
This commit is contained in:
Iwan Kawrakow
2025-07-11 16:32:59 +02:00
parent b1956cd122
commit dd1c2a14d7

View File

@@ -3719,7 +3719,7 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
};
struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true> {
DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shifts(load_shift()) { load_values(values); }
DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shuff(load_shuffle()), shifts(load_shift()) { load_values(values); }
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
@@ -3733,37 +3733,53 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
return scales;
}
inline void process_pair(uint8x16_t x, uint8x16_t * val) const {
uint8x16x2_t aux{ vqtbl2q_s8(values[0], x), vqtbl2q_s8(values[1], x) };
val[0] = vqtbl2q_u8(aux, shuff.val[0]);
val[1] = vqtbl2q_u8(aux, shuff.val[1]);
}
inline void prepare(int i, int j) {
hbits = j == 0 ? vld1q_u8(x[i].qh) : vshrq_n_u8(hbits, 4);
auto lbits = vld1q_u8_x2(x[i].qs+32*j);
bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf));
bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4);
bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf));
bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4)));
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3)));
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2)));
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1)));
auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]);
auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]);
bits.b1.val[0] = vzip1q_s8(t1, t2);
bits.b1.val[1] = vzip2q_s8(t1, t2);
t1 = vqtbl2q_s8(values[0], bits.b1.val[2]);
t2 = vqtbl2q_s8(values[1], bits.b1.val[2]);
bits.b1.val[2] = vzip1q_s8(t1, t2);
bits.b1.val[3] = vzip2q_s8(t1, t2);
uint8x16x4_t aux;
aux.val[0] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 4)), vandq_u8(lbits.val[0], vdupq_n_u8(0xf)));
aux.val[1] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 3)), vshrq_n_u8(lbits.val[0], 4));
aux.val[2] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 2)), vandq_u8(lbits.val[1], vdupq_n_u8(0xf)));
aux.val[3] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 1)), vshrq_n_u8(lbits.val[1], 4));
t1 = vqtbl2q_s8(values[0], bits.b2.val[0]);
t2 = vqtbl2q_s8(values[1], bits.b2.val[0]);
bits.b2.val[0] = vzip1q_s8(t1, t2);
bits.b2.val[1] = vzip2q_s8(t1, t2);
t1 = vqtbl2q_s8(values[0], bits.b2.val[2]);
t2 = vqtbl2q_s8(values[1], bits.b2.val[2]);
bits.b2.val[2] = vzip1q_s8(t1, t2);
bits.b2.val[3] = vzip2q_s8(t1, t2);
process_pair(aux.val[0], bits.b1.val+0);
process_pair(aux.val[1], bits.b1.val+2);
process_pair(aux.val[2], bits.b2.val+0);
process_pair(aux.val[3], bits.b2.val+2);
hbits = vshrq_n_u8(hbits, 4);
// The compiler crashes the moment I try to use vzip2q_u8!!!
//bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf));
//bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4);
//bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf));
//bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4);
//bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4)));
//bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3)));
//bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2)));
//bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1)));
//auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]);
//auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]);
//bits.b1.val[0] = vzip1q_s8(t1, t2);
////bits.b1.val[1] = vzip2q_u8(t1, t2);
//t1 = vqtbl2q_s8(values[0], bits.b1.val[2]);
//t2 = vqtbl2q_s8(values[1], bits.b1.val[2]);
//bits.b1.val[2] = vzip1q_s8(t1, t2);
////bits.b1.val[3] = vzip2q_s8(t1, t2);
//t1 = vqtbl2q_s8(values[0], bits.b2.val[0]);
//t2 = vqtbl2q_s8(values[1], bits.b2.val[0]);
//bits.b2.val[0] = vzip1q_s8(t1, t2);
////bits.b2.val[1] = vzip2q_s8(t1, t2);
//t1 = vqtbl2q_s8(values[0], bits.b2.val[2]);
//t2 = vqtbl2q_s8(values[1], bits.b2.val[2]);
//bits.b2.val[2] = vzip1q_s8(t1, t2);
////bits.b2.val[3] = vzip2q_s8(t1, t2);
}
static inline int16x8_t load_shift() {
static const int16_t k_shift[8] = {4, 2, 0, -2, -4, -6, -8, -10};
@@ -3777,10 +3793,17 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
values[0] = vld1q_s8_x2(k_values+ 0);
values[1] = vld1q_s8_x2(k_values+32);
}
static uint8x16x2_t load_shuffle() {
static const uint8_t k_shuff[32] = {
0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31
};
return vld1q_u8_x2(k_shuff);
}
Q2bits bits;
//struct { uint8x16x4_t b1, b2; } bits;
struct { uint8x16x4_t b1, b2; } bits;
uint8x16_t hbits;
const uint8x16x2_t shuff;
const int16x8_t shifts;
const uint8x16_t m10 = vdupq_n_u8(0x10);
int8x16x2_t values[2];