iq1_tn: faster NEON

This commit is contained in:
Iwan Kawrakow
2024-09-08 21:30:06 +02:00
parent 8d509a7d71
commit 3487e68cc0
2 changed files with 36 additions and 14 deletions

View File

@@ -5848,16 +5848,17 @@ template <int nrc> struct Q8_K64 {
Q8_K64(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
y[iy] = (const int8_t *)(dptr + 4);
std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
y[iy] = (const int8_t *)(dptr + 8);
}
}
inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); }
inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
float d[4*nrc_y];
float d[8*nrc_y];
const int8_t * y[nrc_y];
};
@@ -5889,6 +5890,14 @@ struct DequantizerIQ1BN {
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
}
}
IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
auto data = vld1q_u8((const uint8_t *)x);
for (int k = 0; k < 4; ++k) {
auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6));
}
}
};
template <int nrc_y, bool is_iq1_tn>
@@ -5916,10 +5925,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
if constexpr (nrc_y == 1) {
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants(x+2*i+0, v1);
deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
deq.prepare_iq1bn_quants(x+2*i+1, v2);
deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
q = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
}
@@ -5931,8 +5940,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants(x+2*i+0, v1);
deq.prepare_iq1bn_quants(x+2*i+1, v2);
deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
@@ -5948,7 +5957,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
int i = 2*(nb/2);
if (i < nb) {
deq.prepare_iq1bn_quants(x+i, v1);
deq.prepare_iq1bn_quants_nosub(x+i, v1);
if constexpr (nrc_y == 1) {
auto q = q8.load_quants(0, i/2, 0);
for (int j = 0; j < 4; ++j) {
@@ -5966,9 +5975,9 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int iy = 0; iy < nrc_y; ++iy) {
if constexpr (is_iq1_tn) {
info.store(ix, iy, scale * vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
} else {
info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}

View File

@@ -379,8 +379,10 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
GGML_ASSERT(k >= 8*QK_IQ1BN);
float * dptr = (float *)y;
auto qs = (int8_t *)(dptr + 4);
auto qs = (int8_t *)(dptr + 8);
#ifdef __ARM_NEON
static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
auto shuffle = vld1q_u8(k_shuffle);
@@ -399,16 +401,22 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
vid[i] = vdupq_n_f32(id);
}
int8x16x4_t q;
int32x4_t qsum = {};
const int8x16_t m1 = vdupq_n_s8(1);
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
auto val = vld1q_f32(x + j + 4*i);
val = vmulq_f32(vid[i], val);
q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val));
auto ival = vcvtnq_s32_f32(val);
q.val[i] = vreinterpretq_s8_s32(ival);
}
auto qi = vqtbl4q_s8(q, shuffle);
qsum = ggml_vdotq_s32(qsum, qi, m1);
vst1q_s8(qs, qi);
qs += 16;
}
auto sumf = vmulq_f32(vld1q_f32(dptr), vcvtq_f32_s32(qsum));
vst1q_f32(dptr + 4, sumf);
#elif defined __AVX__
__m128 max[4] = {};
__m128 sign_bit = _mm_set1_ps(-0.f);
@@ -455,11 +463,16 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
dptr[i] = aux[i]/127;
aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
}
int32_t sum[4] = {};
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
for (int l = 0; l < 4; ++l) {
qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
sum[i] += qs[j+4*i+l];
}
}
}
for (int i = 0; i < 4; ++i) dptr[4+i] = dptr[i]*sum[i];
#endif
}