mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iqk_mul_mat: faster q3_K TG
We get 31 t/s up from 26 t/s, but we need to treat PP differently from TG, else we get a ~10% drop in PP performance.
This commit is contained in:
@@ -1763,6 +1763,7 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
h.bits = vld1q_u8_x2(x[i].hmask);
|
||||
mask = vdupq_n_u8(0x01);
|
||||
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
|
||||
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
|
||||
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
|
||||
@@ -1771,19 +1772,43 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
|
||||
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
|
||||
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
|
||||
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
|
||||
return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d);
|
||||
auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
|
||||
if (nrc > 1) {
|
||||
return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
|
||||
}
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
h.apply(bits.b1, bits.b2, j == 0);
|
||||
if (nrc > 1) {
|
||||
h.apply(bits.b1, bits.b2, j == 0);
|
||||
} else {
|
||||
auto minus4 = vdupq_n_u8(0xfc);
|
||||
auto zero = vdupq_n_u8(0);
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t aux32[4];
|
||||
|
||||
Q2bits bits;
|
||||
|
||||
const uint8x16_t mhb = vdupq_n_u8(0x04);
|
||||
uint8x16_t mask;
|
||||
HighBit3 h;
|
||||
|
||||
float d;
|
||||
|
||||
Reference in New Issue
Block a user