mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iq2_bn_r4: 1st shot at NEON
PP-512 is already faster than iq2_bn (284 t/s vs 246 t/s for Bitnet-1.58b-3B). TG-128 is ~5% slower.
This commit is contained in:
@@ -7355,6 +7355,98 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc> struct Q8_16 {
|
||||
|
||||
constexpr static int nrc_y = nrc;
|
||||
|
||||
Q8_16(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto ptr = (const float *)info.src1_row(iy);
|
||||
std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
|
||||
y[iy] = (const int8_t *)(ptr + 5);
|
||||
}
|
||||
}
|
||||
|
||||
inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
|
||||
inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
|
||||
inline float scale(int iy, int k) const { return d[5*iy+k]; }
|
||||
inline float sum_row(int iy) const { return d[5*iy + 4]; }
|
||||
inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
|
||||
|
||||
float d[5*nrc_y];
|
||||
const int8_t * y[nrc_y];
|
||||
};
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
if (nrc_x%4) {
|
||||
printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
Q8_16<nrc_y> q8(info);
|
||||
auto m3 = vdupq_n_u8(0x3);
|
||||
int nb = n / QK_IQ1BN;
|
||||
int32x4_t acc[4*nrc_y] = {};
|
||||
uint8x16_t qx[8];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const float * dptr = (const float *)((const char *)vx + ix*bx);
|
||||
auto dl = vld1q_f32(dptr);
|
||||
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
auto bits = vld1q_u8_x2(iq2 + 64*ib);
|
||||
qx[0] = vandq_u8(bits.val[0], m3);
|
||||
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
|
||||
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
|
||||
qx[3] = vshrq_n_u8(bits.val[0], 6);
|
||||
qx[4] = vandq_u8(bits.val[1], m3);
|
||||
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
|
||||
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
|
||||
qx[7] = vshrq_n_u8(bits.val[1], 6);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = q8.load_quants_32(iy, 2*ib+0);
|
||||
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
|
||||
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
|
||||
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
|
||||
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
|
||||
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
|
||||
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
|
||||
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
|
||||
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
|
||||
}
|
||||
bits = vld1q_u8_x2(iq2 + 64*ib + 32);
|
||||
qx[0] = vandq_u8(bits.val[0], m3);
|
||||
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
|
||||
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
|
||||
qx[3] = vshrq_n_u8(bits.val[0], 6);
|
||||
qx[4] = vandq_u8(bits.val[1], m3);
|
||||
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
|
||||
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
|
||||
qx[7] = vshrq_n_u8(bits.val[1], 6);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = q8.load_quants_32(iy, 2*ib+1);
|
||||
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
|
||||
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
|
||||
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
|
||||
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
|
||||
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
|
||||
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
|
||||
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
|
||||
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dy = q8.scale(iy);
|
||||
float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
|
||||
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
|
||||
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
|
||||
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
|
||||
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
|
||||
info.store(ix, iy, sumf);
|
||||
acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_f32(0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const int nb = n / QK_IQ1BN;
|
||||
@@ -7900,6 +7992,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
|
||||
expected_Btype = GGML_TYPE_Q8_K64;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_BN_R4:
|
||||
m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
|
||||
m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
|
||||
m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
|
||||
m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
|
||||
m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
|
||||
m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
|
||||
m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
|
||||
m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
|
||||
expected_Btype = GGML_TYPE_Q8_K16;
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
MulMat::set_functions<DequantizerQ40>(m);
|
||||
expected_Btype = GGML_TYPE_Q8_0;
|
||||
|
||||
Reference in New Issue
Block a user