iq1bn(no lookup): somewhat better
We now have for Bitnet-3B:

| threads | test  | t/s           |
| ------: | ----: | ------------: |
|      16 | pp512 | 308.97 ± 1.89 |
|      16 | tg128 |  58.80 ± 0.07 |
|       8 | tg128 |  49.79 ± 1.23 |
|       4 | tg128 |  28.85 ± 0.02 |
|       2 | tg128 |  15.39 ± 0.01 |
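The new `prepare_iq1bn_quants` below replaces the table lookup with pure arithmetic: each 16-bit lane of `all128` carries a 13-bit code assembled from `ql`, `qh` and the `extra` bits, and one base-3 digit is peeled off per lane by `_mm256_mullo_epi16` with `mult` (the constants 8·3^k) followed by `_mm256_mulhi_epu16(..., 3)`; subtracting `m1_8` then maps {0,1,2} to {-1,0,+1}. Here is a minimal scalar sketch of that multiply-and-shift decode. The decode side mirrors the constants visible in the diff; the `pack_ternary8` encoding (code = ceil(N·2^13 / 3^8)) is an illustrative assumption about how 8 ternary digits can be packed so that the trick is exact, not the actual IQ1_BN bit layout.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical packer: store 8 ternary digits d[0..7] (values 0,1,2) as a 13-bit
// code = ceil(N * 2^13 / 3^8), with N = sum d[i]*3^i. (Assumed layout, for illustration.)
static uint16_t pack_ternary8(const uint8_t d[8]) {
    uint32_t N = 0;
    for (int i = 7; i >= 0; --i) N = 3*N + d[i];
    return (uint16_t)((N*8192u + 6560u) / 6561u);
}

// Lookup-free decode of digit k (k = 0 is the most significant digit). This is the
// scalar equivalent of one 16-bit lane of _mm256_mullo_epi16(x, 8*3^k) followed by
// _mm256_mulhi_epu16(., 3): the low 16 bits of code*8*3^k equal (code*3^k mod 2^13)*8,
// and multiplying by 3 and keeping the high 16 bits yields the digit; -1 maps to {-1,0,+1}.
static int decode_digit(uint16_t code, int k) {
    static const uint16_t mult[8] = {8, 24, 72, 216, 648, 1944, 5832, 17496}; // 8*3^k
    uint16_t lo = (uint16_t)(code * mult[k]);        // wraps mod 2^16 on purpose
    return (int)(((uint32_t)lo * 3u) >> 16) - 1;
}

int main() {
    const uint8_t d[8] = {2, 0, 1, 2, 1, 0, 2, 1};   // d[0] is the least significant digit
    uint16_t code = pack_ternary8(d);
    for (int k = 0; k < 8; ++k)
        printf("k=%d: expected %+d, decoded %+d\n", k, (int)d[7-k] - 1, decode_digit(code, k));
    return 0;
}
```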
@@ -1342,8 +1342,11 @@ template <int nrc> struct Q8_K64 {
struct DequantizerIQ1BN {
    const __m256i m1_8 = _mm256_set1_epi8(1);
    const __m128i mask1 = _mm_set1_epi8(0xf0);
#ifdef HAVE_FANCY_SIMD
    const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
#else
    const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
#endif
    const __m128i maskhh = _mm_set1_epi16(4096);
    const __m256i shuffles[4] = {
        _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
@@ -1353,22 +1356,54 @@ struct DequantizerIQ1BN {
    };
    const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
    const __m256i m3 = _mm256_set1_epi16(3);
    const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
    const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
    const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
    const __m128i mask_h = _mm_set1_epi16(0x0f00);
    const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
#ifdef HAVE_FANCY_SIMD
    const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif

    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {

        auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
        uint32_t aux32; std::memcpy(&aux32, qh, 4);
        auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
        auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
        auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
        auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
        auto aux1 = _mm_shuffle_epi8(data, shuff_l);
        auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
#ifdef HAVE_FANCY_SIMD
        auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
#else
        auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
#endif
        auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
        auto all = MM256_SET_M128I(all128, all128);
        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
#ifdef HAVE_FANCY_SIMD
        v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
        v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
#else
        v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
        v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
#endif
    }

    //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {

    // auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
    // uint32_t aux32; std::memcpy(&aux32, qh, 4);
    // auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
    // auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
    // auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
    // auto all = MM256_SET_M128I(all128, all128);
    // auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
    // auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
    // auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
    // auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
    // v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
    // v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
    //}
};

template <int nrc_y>
@@ -1392,8 +1427,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
    if constexpr (nrc_y == 1) {
        __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
        for (int i = 0; i < nb/2; ++i) {
            deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
            deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
            deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
            deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
            auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
            auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
@@ -1418,8 +1453,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
        for (int i = 0; i < nb/2; ++i) {

            deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
            deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
            deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
            deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);

            for (int iy = 0; iy < nrc_y; ++iy) {
#if defined __AVX512VNNI__ && defined __AVX512VL__
@@ -1442,7 +1477,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
        }
        int i = 2*(nb/2);
        if (i < nb) {
            deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, val[0], val[1]);
            deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
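Downstream of the decode, the ternary weights in {-1, 0, +1} only ever multiply `int8` activations, which is what the `_mm256_sign_epi8(q8..., val...)` calls above exploit: applying the sign of the weight to the activation (and zeroing it where the weight is 0) is the product, so no integer multiply is needed. A rough scalar equivalent of the per-block dot product is sketched below; it is a sketch only, assuming 64 quants per block as suggested by the paired `load_quants(..., 0)`/`load_quants(..., 1)` loads, and it omits the accumulation across blocks and the block scales, which are not part of these hunks.

```cpp
#include <cstdint>

// Scalar sketch: one block of 64 decoded ternary weights against 64 int8 activations.
static int dot_iq1bn_block(const int8_t val[64], const int8_t q8[64]) {
    int sum = 0;
    for (int j = 0; j < 64; ++j) {
        // _mm256_sign_epi8(q8, val): negate q8[j] if val[j] < 0, zero it if val[j] == 0,
        // pass it through if val[j] > 0 -- i.e. q8[j]*val[j] for ternary val[j].
        sum += (int)q8[j] * (int)val[j];
    }
    return sum;
}
```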