mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-28 10:21:48 +00:00
iq1bn(no lookup): better version
We have 4 groups of 16 in a block of 64 quants. For each group of 16 we have 3 groups of 5, each using 8 bits. The remaining 16th quants of the 4 groups of 16 are encoded with 8 bits using the same encoding as the groups of 5. The only kernel where we have complications is the CUDA dequantize kernel (because we are dequantizing 8 quants there, and we have different encoding for the 1st and 2nd group of 8 in a group of 16). This achieves better performance on all tested platforms than any previous 1.625 bpw attempt. We have: | model | size | params | backend | threads | test | t/s | | ---------------- | ---------: | ---------: | ---------- | ------: | ------------: | ---------------: | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA | 8 | pp512 | 9613.02 ± 24.54 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA | 8 | tg128 | 229.85 ± 0.33 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 16 | pp512 | 322.59 ± 1.00 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 16 | tg128 | 59.79 ± 0.03 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 8 | tg128 | 57.62 ± 0.21 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 4 | tg128 | 33.66 ± 0.29 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 2 | tg128 | 18.30 ± 0.01 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal | 8 | pp512 | 698.13 ± 0.21 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal | 8 | tg128 | 68.88 ± 0.24 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 8 | pp512 | 196.80 ± 0.50 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 8 | tg128 | 51.58 ± 0.41 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 4 | tg128 | 30.80 ± 0.03 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 2 | tg128 | 16.89 ± 0.01 | It is still slower than 2 bpw Bitnet, but the difference now is not as dramatic.
This commit is contained in:
146
iqk_mul_mat.cpp
146
iqk_mul_mat.cpp
@@ -1342,44 +1342,31 @@ template <int nrc> struct Q8_K64 {
|
||||
|
||||
struct DequantizerIQ1BN {
|
||||
const __m256i m1_8 = _mm256_set1_epi8(1);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
|
||||
#else
|
||||
const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
|
||||
#endif
|
||||
const __m128i maskhh = _mm_set1_epi16(4096);
|
||||
const __m256i shuffles[4] = {
|
||||
_mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
|
||||
_mm256_set_epi64x(0x0706070607060706, 0x0706070607060706, 0x0504050405040504, 0x0504050405040504),
|
||||
_mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0908090809080908),
|
||||
_mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0f0e0f0e0f0e0f0e, 0x0d0c0d0c0d0c0d0c, 0x0d0c0d0c0d0c0d0c),
|
||||
static __m128i load_shuffle(int i) {
|
||||
static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
|
||||
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
|
||||
6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
|
||||
9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
|
||||
return _mm_loadu_si128((const __m128i*)data + i);
|
||||
}
|
||||
const __m128i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
|
||||
const __m256i mult[4] = {
|
||||
_mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
|
||||
_mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
|
||||
_mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
|
||||
_mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
|
||||
};
|
||||
const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
|
||||
const __m256i m3 = _mm256_set1_epi16(3);
|
||||
const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
|
||||
const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
|
||||
const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
|
||||
const __m128i mask_h = _mm_set1_epi16(0x0f00);
|
||||
const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
||||
#endif
|
||||
|
||||
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
|
||||
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
|
||||
auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
|
||||
auto aux1 = _mm_shuffle_epi8(data, shuff_l);
|
||||
auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
|
||||
#else
|
||||
auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
|
||||
#endif
|
||||
auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
|
||||
auto all = MM256_SET_M128I(all128, all128);
|
||||
auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
|
||||
auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
|
||||
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
|
||||
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
|
||||
auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[0])), mult[0]), m3);
|
||||
auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[1])), mult[1]), m3);
|
||||
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[2])), mult[2]), m3);
|
||||
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[3])), mult[3]), m3);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
|
||||
v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
|
||||
@@ -1389,21 +1376,6 @@ struct DequantizerIQ1BN {
|
||||
#endif
|
||||
}
|
||||
|
||||
//IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
|
||||
|
||||
// auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
|
||||
// uint32_t aux32; std::memcpy(&aux32, qh, 4);
|
||||
// auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
|
||||
// auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
|
||||
// auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
|
||||
// auto all = MM256_SET_M128I(all128, all128);
|
||||
// auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
|
||||
// auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
|
||||
// auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
|
||||
// auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
|
||||
// v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
|
||||
// v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
|
||||
//}
|
||||
};
|
||||
|
||||
template <int nrc_y>
|
||||
@@ -1466,9 +1438,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
|
||||
#else
|
||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
|
||||
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
|
||||
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
|
||||
auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
|
||||
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
|
||||
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
|
||||
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
|
||||
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
|
||||
#endif
|
||||
@@ -4376,73 +4348,29 @@ static const uint64_t kall_signs[257] = {
|
||||
struct DequantizerIQ1BN {
|
||||
const uint8x16_t m1 = vdupq_n_u8(1);
|
||||
|
||||
static inline uint8x16_t load_shuffle_l() {
|
||||
static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
|
||||
return vld1q_u8(data);
|
||||
static inline uint8x16x4_t load_shuffles() {
|
||||
static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
|
||||
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
|
||||
6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
|
||||
9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
|
||||
return vld1q_u8_x4(data);
|
||||
}
|
||||
static inline uint8x16_t load_shuffle_h() {
|
||||
static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
|
||||
return vld1q_u8(data);
|
||||
static inline uint8x16x4_t load_mult() {
|
||||
static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
|
||||
81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
|
||||
81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
|
||||
81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
|
||||
return vld1q_u8_x4(data);
|
||||
}
|
||||
static inline uint8x16_t load_shuffle_hh() {
|
||||
static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
|
||||
return vld1q_u8(data);
|
||||
}
|
||||
static inline int16x8_t load_shift_hh() {
|
||||
static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
|
||||
return vld1q_s16(data);
|
||||
}
|
||||
static inline uint16x8_t load_mult() {
|
||||
//static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
|
||||
static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
|
||||
return vld1q_u16(data);
|
||||
}
|
||||
//static inline uint8x16x4_t load_shuffles(uint16_t s0) {
|
||||
// uint8x16x4_t r;
|
||||
// auto step = vdupq_n_u8(4);
|
||||
// r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
|
||||
// r.val[1] = vaddq_u8(r.val[0], step);
|
||||
// r.val[2] = vaddq_u8(r.val[1], step);
|
||||
// r.val[3] = vaddq_u8(r.val[2], step);
|
||||
// return r;
|
||||
//}
|
||||
|
||||
const uint8x16_t shuff_l = load_shuffle_l();
|
||||
const uint8x16_t shuff_h = load_shuffle_h();
|
||||
const int32x4_t shift_h = {8, 8, 4, 4};
|
||||
const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
|
||||
const uint8x16_t shuff_hh = load_shuffle_hh();
|
||||
const uint16x8_t mask_hh = vdupq_n_u16(4096);
|
||||
const int16x8_t shift_hh = load_shift_hh();
|
||||
const uint16x8_t mult = load_mult();
|
||||
const uint8x16_t step = vdupq_n_u8(2);
|
||||
const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
|
||||
//const uint8x16x4_t shuff1 = load_shuffles(0x0100);
|
||||
//const uint8x16x4_t shuff2 = load_shuffles(0x0302);
|
||||
//const uint16x8_t mask = vdupq_n_u16(0x1fff);
|
||||
//const uint16x8_t m3 = vdupq_n_u16(3);
|
||||
const uint8x16x4_t shuff = load_shuffles();
|
||||
const uint8x16x4_t mult = load_mult();
|
||||
|
||||
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
|
||||
auto data = vld1q_u8((const uint8_t *)x);
|
||||
auto aux1 = vqtbl1q_u8(data, shuff_l);
|
||||
auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
|
||||
auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
|
||||
auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
|
||||
auto shuffle = shuff0;
|
||||
//auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
|
||||
//auto step = vdupq_n_u8(2);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
|
||||
auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
|
||||
//auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
|
||||
//auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
|
||||
v1 = vmulq_u16(v1, mult);
|
||||
v2 = vmulq_u16(v2, mult);
|
||||
v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
|
||||
v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
|
||||
//v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
|
||||
//v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
|
||||
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
|
||||
auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
|
||||
val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
|
||||
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user