Performance improves from 85.7 t/s to 168.1 t/s; for comparison, q2_k_r4 runs at 111.2 t/s.
This commit is contained in:
Iwan Kawrakow
2025-06-23 17:50:26 +02:00
parent 548a5f3f0d
commit 6818e14184
2 changed files with 101 additions and 7 deletions

View File

@@ -3703,20 +3703,113 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i
}
}
// Packed block holding 8 interleaved rows of q8_1-style data.
// NOTE(review): not referenced by the code visible in this diff chunk —
// presumably consumed by other r8 conversion/GEMM kernels in this file; confirm.
typedef struct {
ggml_half d[16];       // per-row fp16 metadata; presumably 8 scales + 8 sums (q8_1 carries both) — TODO confirm layout
int8_t qs[8*QK8_1];    // int8 quants for the 8 rows
} block_q8_1_r8;
// Convert q2_K rows to the interleaved q8_k_r8 layout (ARM NEON path).
// Processes 8 source rows at a time: each q2_K super-block is dequantized to
// float, then requantized to int8 with a fresh per-(row, super-block) scale,
// and the 8 rows are stored interleaved in 4-byte groups.
//   n     - number of elements per row (must be a multiple of QK_K)
//   vx    - source q2_K data
//   bx    - byte stride between consecutive source rows
//   vy    - destination buffer of block_q8_k_r8
//   nrc_x - number of rows (must be a multiple of 8)
void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
// Pointers to the 8 source rows currently being converted.
const block_q2_K * x8[8];
block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
// Dequantized floats for one super-block of one row.
float32_t f_values[QK_K];
// Staging area for 32 requantized int8 values, viewed as 8 uint32 words
// so they can be scattered in 4-byte groups during interleaving.
uint32_t block[8];
int8x16x2_t xv[4];
auto ml = vdupq_n_u8(0x03); // mask for the 2-bit quants
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q2_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {       // super-block index
for (int k = 0; k < 8; ++k) {        // row within the group of 8
// q2_K reconstruction: value = d * scale * q - dmin * min,
// hence vm holds the negated dmin so a single fma can be used below.
auto vd = vdupq_n_f32(GGML_FP16_TO_FP32(x8[k][i].d));
auto vm = vdupq_n_f32(-GGML_FP16_TO_FP32(x8[k][i].dmin));
// Running max |value| over the super-block, used for requantization.
auto block_max = vdupq_n_f32(0);
for (int i128 = 0; i128 < 2; ++i128) { // two halves of 128 quants each
auto bits = vld1q_u8_x2(x8[k][i].qs+32*i128);
// Unpack the four 2-bit planes packed into each byte.
xv[0].val[0] = vandq_u8(bits.val[0], ml);
xv[0].val[1] = vandq_u8(bits.val[1], ml);
xv[1].val[0] = vandq_u8(vshrq_n_u8(bits.val[0], 2), ml);
xv[1].val[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), ml);
xv[2].val[0] = vandq_u8(vshrq_n_u8(bits.val[0], 4), ml);
xv[2].val[1] = vandq_u8(vshrq_n_u8(bits.val[1], 4), ml);
xv[3].val[0] = vshrq_n_u8(bits.val[0], 6);
xv[3].val[1] = vshrq_n_u8(bits.val[1], 6);
for (int l = 0; l < 4; ++l) { // two 16-quant sub-blocks per iteration
// Per-sub-block scale (low nibble) applied in int8:
// quant <= 3, scale <= 15, so the product (<= 45) fits in int8.
auto d1 = vdupq_n_s8(x8[k][i].scales[8*i128 + 2*l + 0] & 0xf);
auto d2 = vdupq_n_s8(x8[k][i].scales[8*i128 + 2*l + 1] & 0xf);
auto q1_8 = vmulq_s8(d1, xv[l].val[0]);
auto q2_8 = vmulq_s8(d2, xv[l].val[1]);
// Widen int8 -> int16 -> int32 -> f32 for the fma below.
auto q1_16_1 = vmovl_s8(vget_low_s8 (q1_8));
auto q1_16_2 = vmovl_s8(vget_high_s8(q1_8));
auto q2_16_1 = vmovl_s8(vget_low_s8 (q2_8));
auto q2_16_2 = vmovl_s8(vget_high_s8(q2_8));
float32x4x4_t f1{vcvtq_f32_s32(vmovl_s16(vget_low_s16(q1_16_1))), vcvtq_f32_s32(vmovl_s16(vget_high_s16(q1_16_1))),
vcvtq_f32_s32(vmovl_s16(vget_low_s16(q1_16_2))), vcvtq_f32_s32(vmovl_s16(vget_high_s16(q1_16_2)))};
float32x4x4_t f2{vcvtq_f32_s32(vmovl_s16(vget_low_s16(q2_16_1))), vcvtq_f32_s32(vmovl_s16(vget_high_s16(q2_16_1))),
vcvtq_f32_s32(vmovl_s16(vget_low_s16(q2_16_2))), vcvtq_f32_s32(vmovl_s16(vget_high_s16(q2_16_2)))};
// Per-sub-block min (high nibble); vm is already negated, so m = -dmin*min.
auto m1 = vmulq_f32(vm, vcvtq_f32_s32(vdupq_n_s32(x8[k][i].scales[8*i128 + 2*l + 0] >> 4)));
auto m2 = vmulq_f32(vm, vcvtq_f32_s32(vdupq_n_s32(x8[k][i].scales[8*i128 + 2*l + 1] >> 4)));
for (int j = 0; j < 4; ++j) {
// value = d*(scale*q) + (-dmin*min)
f1.val[j] = vfmaq_f32(m1, vd, f1.val[j]);
f2.val[j] = vfmaq_f32(m2, vd, f2.val[j]);
}
vst1q_f32_x4(f_values + 128*i128 + 32*l + 0, f1);
vst1q_f32_x4(f_values + 128*i128 + 32*l + 16, f2);
auto max1 = vmaxq_f32(vmaxq_f32(vabsq_f32(f1.val[0]), vabsq_f32(f1.val[1])), vmaxq_f32(vabsq_f32(f1.val[2]), vabsq_f32(f1.val[3])));
auto max2 = vmaxq_f32(vmaxq_f32(vabsq_f32(f2.val[0]), vabsq_f32(f2.val[1])), vmaxq_f32(vabsq_f32(f2.val[2]), vabsq_f32(f2.val[3])));
block_max = vmaxq_f32(block_max, vmaxq_f32(max1, max2));
}
}
// Requantize the whole super-block to int8 with scale d = max/127.
auto max = vmaxvq_f32(block_max);
float d = max / 127.f;
auto id = vdupq_n_f32(d != 0.0f ? 1/d : 0.0f);
y[i].d[k] = GGML_FP32_TO_FP16(d);
int16x8x4_t i16;
for (int ib32 = 0; ib32 < 8; ++ib32) { // 32 values per iteration
auto v1 = vld1q_f32_x4(f_values + 32*ib32 + 0);
auto v2 = vld1q_f32_x4(f_values + 32*ib32 + 16);
// Round-to-nearest (vcvtnq) and narrow f32 -> s32 -> s16 -> s8.
i16.val[0] = vcombine_s16(vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v1.val[0]))), vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v1.val[1]))));
i16.val[1] = vcombine_s16(vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v1.val[2]))), vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v1.val[3]))));
i16.val[2] = vcombine_s16(vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v2.val[0]))), vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v2.val[1]))));
i16.val[3] = vcombine_s16(vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v2.val[2]))), vmovn_s32(vcvtnq_s32_f32(vmulq_f32(id, v2.val[3]))));
vst1q_s8((int8_t *)block + 0, vcombine_s8(vmovn_s16(i16.val[0]), vmovn_s16(i16.val[1])));
vst1q_s8((int8_t *)block + 16, vcombine_s8(vmovn_s16(i16.val[2]), vmovn_s16(i16.val[3])));
// Interleave: scatter this row's 32 bytes as 4-byte groups among the
// 8 rows (word index 8*l + k); the two 16-byte halves land 128 bytes apart.
auto q8 = (uint32_t *)y[i].qs + 64*ib32;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
}
y += nb; // advance output past the nb super-blocks of this 8-row group
}
}
// NOTE(review): this is the PRE-change stub of iqk_convert_kquants_q8X_r8 (the
// removed side of the diff); the rendered view lost the '-' markers, so it
// appears alongside the replacement definition further down. It always
// reported "no conversion available".
bool iqk_convert_kquants_q8X_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
return false;
//switch (ggml_type(type)) {
//    case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
}
// Repack k-quant data of the given type into the interleaved q8 layout used
// by the r8 GEMM kernels.
//   type  - ggml type id of the source data
//   n     - elements per row; vx/bx - source data and row stride
//   vy    - destination buffer; nrc_x - number of rows
// Returns true when a conversion kernel exists for `type`, false otherwise.
// Only Q2_K is wired up so far.
bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    if (ggml_type(type) != GGML_TYPE_Q2_K) {
        // TODO: Q3_K, Q4_K, Q5_K, Q6_K, IQ4_XS conversions are not implemented yet.
        return false;
    }
    iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x);
    return true;
}
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {

View File

@@ -271,6 +271,7 @@ struct MulMat {
}
#else
switch (type) {
case GGML_TYPE_Q2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;