mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-03 21:09:39 +00:00
bitnet: NEON improvements for iq1_bn
With these changes we get to TG-128 = 34 t/s, PP-512 = 153 t/s.
This commit is contained in:
@@ -355,6 +355,37 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k)
|
||||
// x += 64;
|
||||
//}
|
||||
|
||||
float * dptr = (float *)y;
|
||||
auto qs = (int8_t *)(dptr + 4);
|
||||
#ifdef __ARM_NEON
|
||||
static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
|
||||
auto shuffle = vld1q_u8(k_shuffle);
|
||||
float32x4_t max[4] = { };
|
||||
for (int j = 0; j < k; j += 16) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto val = vld1q_f32(x + j + 4*i);
|
||||
val = vabsq_f32(val);
|
||||
max[i] = vmaxq_f32(max[i], val);
|
||||
}
|
||||
}
|
||||
float32x4_t vid[4];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
dptr[i] = vmaxvq_f32(max[i])/127;
|
||||
float id = dptr[i] > 0 ? 1/dptr[i] : 0.f;
|
||||
vid[i] = vdupq_n_f32(id);
|
||||
}
|
||||
int8x16x4_t q;
|
||||
for (int j = 0; j < k; j += 16) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto val = vld1q_f32(x + j + 4*i);
|
||||
val = vmulq_f32(vid[i], val);
|
||||
q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val));
|
||||
}
|
||||
auto qi = vqtbl4q_s8(q, shuffle);
|
||||
vst1q_s8(qs, qi);
|
||||
qs += 16;
|
||||
}
|
||||
#else
|
||||
float aux[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int j = 0; j < k; j += 16) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
@@ -364,17 +395,16 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k)
|
||||
}
|
||||
}
|
||||
}
|
||||
float * dptr = (float *)y;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
dptr[i] = aux[i]/127;
|
||||
aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
|
||||
}
|
||||
auto qs = (int8_t *)(dptr + 4);
|
||||
for (int j = 0; j < k; j += 16) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void quantize_row_q8_K64(const float * x, void * y, int64_t k) {
|
||||
|
||||
@@ -4277,6 +4277,74 @@ template <int nrc> struct Q8_K64 {
|
||||
const int8_t * y[nrc_y];
|
||||
};
|
||||
|
||||
static const uint64_t kall_signs[257] = {
|
||||
0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff,
|
||||
0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff,
|
||||
0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff,
|
||||
0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff,
|
||||
0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff,
|
||||
0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff,
|
||||
0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff,
|
||||
0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff,
|
||||
0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff,
|
||||
0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff,
|
||||
0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff,
|
||||
0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff,
|
||||
0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff,
|
||||
0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff,
|
||||
0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff,
|
||||
0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff,
|
||||
0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff,
|
||||
0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff,
|
||||
0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff,
|
||||
0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff,
|
||||
0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff,
|
||||
0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff,
|
||||
0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff,
|
||||
0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff,
|
||||
0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff,
|
||||
0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff,
|
||||
0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff,
|
||||
0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff,
|
||||
0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff,
|
||||
0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff,
|
||||
0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff,
|
||||
0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff,
|
||||
0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff,
|
||||
0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff,
|
||||
0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff,
|
||||
0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff,
|
||||
0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff,
|
||||
0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff,
|
||||
0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff,
|
||||
0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff,
|
||||
0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff,
|
||||
0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff,
|
||||
0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff,
|
||||
0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff,
|
||||
0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff,
|
||||
0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff,
|
||||
0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff,
|
||||
0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff,
|
||||
0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff,
|
||||
0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff,
|
||||
0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff,
|
||||
0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff,
|
||||
0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff,
|
||||
0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff,
|
||||
0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff,
|
||||
0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff,
|
||||
0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff,
|
||||
0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff,
|
||||
0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff,
|
||||
0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff,
|
||||
0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff,
|
||||
0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff,
|
||||
0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff,
|
||||
0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,
|
||||
0xffffffffffffffff
|
||||
};
|
||||
|
||||
struct DequantizerIQ1BN {
|
||||
const uint8x16_t m1 = vdupq_n_u8(1);
|
||||
const uint8x16x4_t sign_shuffles = {
|
||||
@@ -4292,17 +4360,24 @@ struct DequantizerIQ1BN {
|
||||
int8x16x4_t signs;
|
||||
uint64x2x4_t a;
|
||||
inline void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, int8x16x4_t& v) {
|
||||
auto all_signs = vdupq_n_u8(extra);
|
||||
all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
|
||||
auto all_signs = vld1q_u8((const uint8_t *)(kall_signs + extra));
|
||||
//auto all_signs = vdupq_n_u8(extra);
|
||||
//all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
|
||||
signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
|
||||
signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
|
||||
signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
|
||||
signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
|
||||
|
||||
a.val[0] = uint64x2_t{iq1bn_grid_zzz[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[1] | ((qh[0] << 4) & 0x0f00)]};
|
||||
a.val[1] = uint64x2_t{iq1bn_grid_zzz[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[3] | ((qh[1] << 4) & 0x0f00)]};
|
||||
a.val[2] = uint64x2_t{iq1bn_grid_zzz[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[5] | ((qh[2] << 4) & 0x0f00)]};
|
||||
a.val[3] = uint64x2_t{iq1bn_grid_zzz[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[7] | ((qh[3] << 4) & 0x0f00)]};
|
||||
uint32_t aux32[2];
|
||||
std::memcpy(aux32, qh, 4);
|
||||
aux32[1] = aux32[0] & 0xf0f0f0f0;
|
||||
aux32[0] &= 0x0f0f0f0f;
|
||||
const uint8_t * h = (const uint8_t *)aux32;
|
||||
|
||||
a.val[0] = uint64x2_t{iq1bn_grid_zzz[ql[0] | (h[0] << 8)], iq1bn_grid_zzz[ql[1] | (h[4] << 4)]};
|
||||
a.val[1] = uint64x2_t{iq1bn_grid_zzz[ql[2] | (h[1] << 8)], iq1bn_grid_zzz[ql[3] | (h[5] << 4)]};
|
||||
a.val[2] = uint64x2_t{iq1bn_grid_zzz[ql[4] | (h[2] << 8)], iq1bn_grid_zzz[ql[5] | (h[6] << 4)]};
|
||||
a.val[3] = uint64x2_t{iq1bn_grid_zzz[ql[6] | (h[3] << 8)], iq1bn_grid_zzz[ql[7] | (h[7] << 4)]};
|
||||
|
||||
v.val[0] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
|
||||
v.val[1] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
|
||||
@@ -4332,7 +4407,6 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
|
||||
|
||||
|
||||
if constexpr (nrc_y == 1) {
|
||||
int32x4_t acc[4] = {};
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
|
||||
Reference in New Issue
Block a user