mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq2_kl: NEON
Had to work around a compiler crash when using vzip2q_u8 using vqtbl2q_u8.
This commit is contained in:
@@ -3719,7 +3719,7 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
|
||||
};
|
||||
|
||||
struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true> {
|
||||
DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shifts(load_shift()) { load_values(values); }
|
||||
DequantizerIQ2KL(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), shuff(load_shuffle()), shifts(load_shift()) { load_values(values); }
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
@@ -3733,37 +3733,53 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
|
||||
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
|
||||
return scales;
|
||||
}
|
||||
inline void process_pair(uint8x16_t x, uint8x16_t * val) const {
|
||||
uint8x16x2_t aux{ vqtbl2q_s8(values[0], x), vqtbl2q_s8(values[1], x) };
|
||||
val[0] = vqtbl2q_u8(aux, shuff.val[0]);
|
||||
val[1] = vqtbl2q_u8(aux, shuff.val[1]);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
hbits = j == 0 ? vld1q_u8(x[i].qh) : vshrq_n_u8(hbits, 4);
|
||||
auto lbits = vld1q_u8_x2(x[i].qs+32*j);
|
||||
bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf));
|
||||
bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4);
|
||||
bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf));
|
||||
bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4);
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4)));
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3)));
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2)));
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1)));
|
||||
|
||||
auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]);
|
||||
auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]);
|
||||
bits.b1.val[0] = vzip1q_s8(t1, t2);
|
||||
bits.b1.val[1] = vzip2q_s8(t1, t2);
|
||||
t1 = vqtbl2q_s8(values[0], bits.b1.val[2]);
|
||||
t2 = vqtbl2q_s8(values[1], bits.b1.val[2]);
|
||||
bits.b1.val[2] = vzip1q_s8(t1, t2);
|
||||
bits.b1.val[3] = vzip2q_s8(t1, t2);
|
||||
uint8x16x4_t aux;
|
||||
aux.val[0] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 4)), vandq_u8(lbits.val[0], vdupq_n_u8(0xf)));
|
||||
aux.val[1] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 3)), vshrq_n_u8(lbits.val[0], 4));
|
||||
aux.val[2] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 2)), vandq_u8(lbits.val[1], vdupq_n_u8(0xf)));
|
||||
aux.val[3] = vorrq_u8(vandq_u8(m10, vshlq_n_u8(hbits, 1)), vshrq_n_u8(lbits.val[1], 4));
|
||||
|
||||
t1 = vqtbl2q_s8(values[0], bits.b2.val[0]);
|
||||
t2 = vqtbl2q_s8(values[1], bits.b2.val[0]);
|
||||
bits.b2.val[0] = vzip1q_s8(t1, t2);
|
||||
bits.b2.val[1] = vzip2q_s8(t1, t2);
|
||||
t1 = vqtbl2q_s8(values[0], bits.b2.val[2]);
|
||||
t2 = vqtbl2q_s8(values[1], bits.b2.val[2]);
|
||||
bits.b2.val[2] = vzip1q_s8(t1, t2);
|
||||
bits.b2.val[3] = vzip2q_s8(t1, t2);
|
||||
process_pair(aux.val[0], bits.b1.val+0);
|
||||
process_pair(aux.val[1], bits.b1.val+2);
|
||||
process_pair(aux.val[2], bits.b2.val+0);
|
||||
process_pair(aux.val[3], bits.b2.val+2);
|
||||
|
||||
hbits = vshrq_n_u8(hbits, 4);
|
||||
// The compiler crashes the moment I try to use vzip2q_u8!!!
|
||||
//bits.b1.val[0] = vandq_u8(lbits.val[0], vdupq_n_u8(0xf));
|
||||
//bits.b1.val[2] = vshrq_n_u8(lbits.val[0], 4);
|
||||
//bits.b2.val[0] = vandq_u8(lbits.val[1], vdupq_n_u8(0xf));
|
||||
//bits.b2.val[2] = vshrq_n_u8(lbits.val[1], 4);
|
||||
//bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 4)));
|
||||
//bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 3)));
|
||||
//bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(m10, vshlq_n_u8(hbits, 2)));
|
||||
//bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(m10, vshlq_n_u8(hbits, 1)));
|
||||
|
||||
//auto t1 = vqtbl2q_s8(values[0], bits.b1.val[0]);
|
||||
//auto t2 = vqtbl2q_s8(values[1], bits.b1.val[0]);
|
||||
//bits.b1.val[0] = vzip1q_s8(t1, t2);
|
||||
////bits.b1.val[1] = vzip2q_u8(t1, t2);
|
||||
//t1 = vqtbl2q_s8(values[0], bits.b1.val[2]);
|
||||
//t2 = vqtbl2q_s8(values[1], bits.b1.val[2]);
|
||||
//bits.b1.val[2] = vzip1q_s8(t1, t2);
|
||||
////bits.b1.val[3] = vzip2q_s8(t1, t2);
|
||||
|
||||
//t1 = vqtbl2q_s8(values[0], bits.b2.val[0]);
|
||||
//t2 = vqtbl2q_s8(values[1], bits.b2.val[0]);
|
||||
//bits.b2.val[0] = vzip1q_s8(t1, t2);
|
||||
////bits.b2.val[1] = vzip2q_s8(t1, t2);
|
||||
//t1 = vqtbl2q_s8(values[0], bits.b2.val[2]);
|
||||
//t2 = vqtbl2q_s8(values[1], bits.b2.val[2]);
|
||||
//bits.b2.val[2] = vzip1q_s8(t1, t2);
|
||||
////bits.b2.val[3] = vzip2q_s8(t1, t2);
|
||||
}
|
||||
static inline int16x8_t load_shift() {
|
||||
static const int16_t k_shift[8] = {4, 2, 0, -2, -4, -6, -8, -10};
|
||||
@@ -3777,10 +3793,17 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
|
||||
values[0] = vld1q_s8_x2(k_values+ 0);
|
||||
values[1] = vld1q_s8_x2(k_values+32);
|
||||
}
|
||||
static uint8x16x2_t load_shuffle() {
|
||||
static const uint8_t k_shuff[32] = {
|
||||
0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
|
||||
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31
|
||||
};
|
||||
return vld1q_u8_x2(k_shuff);
|
||||
}
|
||||
|
||||
Q2bits bits;
|
||||
//struct { uint8x16x4_t b1, b2; } bits;
|
||||
struct { uint8x16x4_t b1, b2; } bits;
|
||||
uint8x16_t hbits;
|
||||
const uint8x16x2_t shuff;
|
||||
const int16x8_t shifts;
|
||||
const uint8x16_t m10 = vdupq_n_u8(0x10);
|
||||
int8x16x2_t values[2];
|
||||
|
||||
Reference in New Issue
Block a user