iq3_ks: ARM_NEON

This commit is contained in:
Iwan Kawrakow
2024-10-09 20:04:56 +03:00
parent 7c966a5eb4
commit 42d85c58fb

View File

@@ -4811,6 +4811,64 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
};
struct DequantizerIQ3KS final : public BaseDequantizer<block_iq3_ks, true> {
DequantizerIQ3KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
(void)q8;
(void)acc;
auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127);
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
return scales;
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
if (j == 0) {
hbits = vld1q_u8_x2(x[i].qh);
}
else {
hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
}
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask));
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask));
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask));
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask));
// bits.b1 is 0....63
// bits.b2 is 64..127
bits.b1.val[0] = vqtbl1q_s8(values.val[x[i].scales[4*j+0] & 1], bits.b1.val[0]);
bits.b1.val[1] = vqtbl1q_s8(values.val[x[i].scales[4*j+0] & 1], bits.b1.val[1]);
bits.b1.val[2] = vqtbl1q_s8(values.val[x[i].scales[4*j+1] & 1], bits.b1.val[2]);
bits.b1.val[3] = vqtbl1q_s8(values.val[x[i].scales[4*j+1] & 1], bits.b1.val[3]);
bits.b2.val[0] = vqtbl1q_s8(values.val[x[i].scales[4*j+2] & 1], bits.b2.val[0]);
bits.b2.val[1] = vqtbl1q_s8(values.val[x[i].scales[4*j+2] & 1], bits.b2.val[1]);
bits.b2.val[2] = vqtbl1q_s8(values.val[x[i].scales[4*j+3] & 1], bits.b2.val[2]);
bits.b2.val[3] = vqtbl1q_s8(values.val[x[i].scales[4*j+3] & 1], bits.b2.val[3]);
}
static int8x16x2_t load_values() {
int8x8_t val1 = vld1_s8(iq3nl_values);
int8x8_t val2 = vld1_s8(iq3nl_values+8);
int8x16x2_t result = { vcombine_s8(val1, val1), vcombine_s8(val2, val2) };
return result;
}
Q2bits bits;
uint8x16x2_t hbits;
const int8x16x2_t values;
const uint8x16_t hmask = vdupq_n_u8(4);
const uint16x8_t mask = vdupq_n_u16(254);
const int16x8_t m127 = vdupq_n_s16(-127);
};
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
static int8x16_t load_values() {
@@ -6647,6 +6705,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ3_K:
MulMat::set_functions<DequantizerIQ3K>(m);
break;
case GGML_TYPE_IQ3_KS:
MulMat::set_functions<DequantizerIQ3KS>(m);
break;
case GGML_TYPE_IQ2_XXS:
MulMat::set_functions<DequantizerIQ2XXS>(m);
break;