iq1bn(no lookup): somewhat better

We now have for Bitnet-3B:
| threads |          test |              t/s |
| ------: | ------------: | ---------------: |
|      16 |         pp512 |    308.97 ± 1.89 |
|      16 |         tg128 |     58.80 ± 0.07 |
|       8 |         tg128 |     49.79 ± 1.23 |
|       4 |         tg128 |     28.85 ± 0.02 |
|       2 |         tg128 |     15.39 ± 0.01 |
This commit is contained in:
Kawrakow
2024-07-15 13:46:07 +03:00
parent 98be184c23
commit 1f3dbbcc19

View File

@@ -1342,8 +1342,11 @@ template <int nrc> struct Q8_K64 {
struct DequantizerIQ1BN {
const __m256i m1_8 = _mm256_set1_epi8(1);
const __m128i mask1 = _mm_set1_epi8(0xf0);
#ifdef HAVE_FANCY_SIMD
const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
#else
const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
#endif
const __m128i maskhh = _mm_set1_epi16(4096);
const __m256i shuffles[4] = {
_mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
@@ -1353,22 +1356,54 @@ struct DequantizerIQ1BN {
};
const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
const __m256i m3 = _mm256_set1_epi16(3);
const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
const __m128i mask_h = _mm_set1_epi16(0x0f00);
const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
#ifdef HAVE_FANCY_SIMD
const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
uint32_t aux32; std::memcpy(&aux32, qh, 4);
auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
auto aux1 = _mm_shuffle_epi8(data, shuff_l);
auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
#ifdef HAVE_FANCY_SIMD
auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
#else
auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
#endif
auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
auto all = MM256_SET_M128I(all128, all128);
auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
#ifdef HAVE_FANCY_SIMD
v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
#else
v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
#endif
}
//IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
// auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
// uint32_t aux32; std::memcpy(&aux32, qh, 4);
// auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
// auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
// auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
// auto all = MM256_SET_M128I(all128, all128);
// auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
// auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
// auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
// auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
// v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
// v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
//}
};
template <int nrc_y>
@@ -1392,8 +1427,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
if constexpr (nrc_y == 1) {
__m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
@@ -1418,8 +1453,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
for (int iy = 0; iy < nrc_y; ++iy) {
#if defined __AVX512VNNI__ && defined __AVX512VL__
@@ -1442,7 +1477,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
}
int i = 2*(nb/2);
if (i < nb) {
deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, val[0], val[1]);
deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);