iqk_mul_mat: faster q3_K TG

We get 31 t/s up from 26 t/s, but we need to treat
PP differently from TG, else we get a ~10% drop in
PP performance.
This commit is contained in:
Iwan Kawrakow
2024-05-27 11:05:44 +02:00
parent 19c578b413
commit b51922530f

View File

@@ -1763,6 +1763,7 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
h.bits = vld1q_u8_x2(x[i].hmask);
mask = vdupq_n_u8(0x01);
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
@@ -1771,19 +1772,43 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d);
auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
if (nrc > 1) {
return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
}
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
return make_wider(scales16);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
h.apply(bits.b1, bits.b2, j == 0);
if (nrc > 1) {
h.apply(bits.b1, bits.b2, j == 0);
} else {
auto minus4 = vdupq_n_u8(0xfc);
auto zero = vdupq_n_u8(0);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
}
}
uint32_t aux32[4];
Q2bits bits;
const uint8x16_t mhb = vdupq_n_u8(0x04);
uint8x16_t mask;
HighBit3 h;
float d;