45.7 t/s -> 170.8 t/s. q3_k_r4 is at 110.3 t/s.
This commit is contained in:
Iwan Kawrakow
2025-06-23 18:38:31 +02:00
parent 6818e14184
commit 52ad57b042
2 changed files with 108 additions and 1 deletions

View File

@@ -3797,12 +3797,118 @@ void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
}
}
void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q3_K * x8[8];
block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
uint32_t block[8];
int8x16x2_t xv[8];
uint32_t aux32[4];
auto ml = vdupq_n_s8(0x03);
auto mh = vdupq_n_s8(0x04);
union { int8x16_t vec; int8_t val[16]; } helper;
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
auto sc16 = (const uint16_t *)x8[k][i].scales;
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
uint32_t aux2 = sc16[4] | (sc16[5] << 16);
aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
helper.vec = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
auto hbits = vld1q_u8_x2(x8[k][i].hmask);
auto max_i16 = vdupq_n_u16(0);
for (int i128 = 0; i128 < 2; ++i128) {
auto q2bits = vld1q_u8_x2(x8[k][i].qs + 32*i128);
xv[4*i128+0].val[0] = vsubq_s8(vorrq_s8(vandq_s8(q2bits.val[0], ml), vandq_s8(vshlq_n_u8(hbits.val[0], 2), mh)), mh);
xv[4*i128+0].val[1] = vsubq_s8(vorrq_s8(vandq_s8(q2bits.val[1], ml), vandq_s8(vshlq_n_u8(hbits.val[1], 2), mh)), mh);
xv[4*i128+1].val[0] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[0], 2), ml), vandq_s8(vshlq_n_u8(hbits.val[0], 1), mh)), mh);
xv[4*i128+1].val[1] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[1], 2), ml), vandq_s8(vshlq_n_u8(hbits.val[1], 1), mh)), mh);
xv[4*i128+2].val[0] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[0], 4), ml), vandq_s8(hbits.val[0], mh)), mh);
xv[4*i128+2].val[1] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[1], 4), ml), vandq_s8(hbits.val[1], mh)), mh);
xv[4*i128+3].val[0] = vsubq_s8(vorrq_s8(vshrq_n_u8(q2bits.val[0], 6), vandq_s8(vshrq_n_u8(hbits.val[0], 1), mh)), mh);
xv[4*i128+3].val[1] = vsubq_s8(vorrq_s8(vshrq_n_u8(q2bits.val[1], 6), vandq_s8(vshrq_n_u8(hbits.val[1], 1), mh)), mh);
hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
for (int l = 0; l < 4; ++l) {
auto s1 = vdup_n_s8(helper.val[8*i128+2*l+0]);
auto s2 = vdup_n_s8(helper.val[8*i128+2*l+1]);
auto q16_1 = vmull_s8(s1, vget_low_s8 (xv[4*i128+l].val[0]));
auto q16_2 = vmull_s8(s1, vget_high_s8(xv[4*i128+l].val[0]));
auto q16_3 = vmull_s8(s2, vget_low_s8 (xv[4*i128+l].val[1]));
auto q16_4 = vmull_s8(s2, vget_high_s8(xv[4*i128+l].val[1]));
auto max1 = vmaxq_s16(vabsq_s16(q16_1), vabsq_s16(q16_2));
auto max2 = vmaxq_s16(vabsq_s16(q16_3), vabsq_s16(q16_4));
max_i16 = vmaxq_s16(max_i16, vmaxq_s16(max1, max2));
}
}
auto imax16 = vmaxvq_s16(max_i16);
bool needs_scaling = true;
float dnew = float(imax16) / 127;
if (dnew < 1.f) {
dnew = 1.f; needs_scaling = false;
}
d *= dnew;
y[i].d[k] = GGML_FP32_TO_FP16(d);
auto scale = vdupq_n_f32(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
for (int ib32 = 0; ib32 < 8; ++ib32) {
auto s1 = vdup_n_s8(helper.val[2*ib32+0]);
auto s2 = vdup_n_s8(helper.val[2*ib32+1]);
auto q16_1 = vmull_s8(s1, vget_low_s8 (xv[ib32].val[0]));
auto q16_2 = vmull_s8(s1, vget_high_s8(xv[ib32].val[0]));
auto q16_3 = vmull_s8(s2, vget_low_s8 (xv[ib32].val[1]));
auto q16_4 = vmull_s8(s2, vget_high_s8(xv[ib32].val[1]));
if (needs_scaling) {
int32x4x4_t i1{vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_1))))),
vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_1))))),
vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_2))))),
vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_2)))))};
int32x4x4_t i2{vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_3))))),
vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_3))))),
vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_4))))),
vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_4)))))};
int16x8x4_t i3{vcombine_s16(vmovn_s32(i1.val[0]), vmovn_s32(i1.val[1])),
vcombine_s16(vmovn_s32(i1.val[2]), vmovn_s32(i1.val[3])),
vcombine_s16(vmovn_s32(i2.val[0]), vmovn_s32(i2.val[1])),
vcombine_s16(vmovn_s32(i2.val[2]), vmovn_s32(i2.val[3]))};
vst1q_s8((int8_t *)block + 0, vcombine_s8(vmovn_s16(i3.val[0]), vmovn_s16(i3.val[1])));
vst1q_s8((int8_t *)block + 16, vcombine_s8(vmovn_s16(i3.val[2]), vmovn_s16(i3.val[3])));
} else {
vst1q_s8((int8_t *)block + 0, vcombine_s8(vmovn_s16(q16_1), vmovn_s16(q16_2)));
vst1q_s8((int8_t *)block + 16, vcombine_s8(vmovn_s16(q16_3), vmovn_s16(q16_4)));
}
auto qs = (uint32_t *)y[i].qs + 64*ib32;
for (int l = 0; l < 8; ++l) {
qs[8*l + k] = block[l];
}
}
}
}
y += nb;
}
}
}
bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
switch (ggml_type(type)) {
case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;

View File

@@ -272,6 +272,7 @@ struct MulMat {
#else
switch (type) {
case GGML_TYPE_Q2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;