From d9c4ea48d1e41d8f7215ff1c094d75e7229b65e2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 27 Jan 2025 16:50:07 +0200 Subject: [PATCH 01/14] Interleave 8 rows (Q8_0, IQ4_XS) (#178) * Try interleaving 8 rows for iq4_xs On Zen4, PP-512 goes up from ~260 t/s to 288 t/s for L3-8B. TG-128 reaches max. performance at 2 threads and is slightly higher than 4 interleaved rows (14.48 t/s vs 13.11 t/s @ 2 threads and 14/28 t/s @ 4 threads). * Try interleaving 8 iq4_xs rows It is also faster on AVX2. This is the NEON implementation. It is tiny bit faster than 4 interleaved rows (~0.5%). So, this looks like a winner given the Zen4/AVX2 improvement without associated NEON egression. * Cleanup * 8-rows interleaved q8_0 (AVX2) * 8-rows interleaved q8_0 (Zen4) * 8-rows interleaved q8_0 (Zen4) - slightly better PP-512 is now 284 t/s compared to 257 t/s for 4-rows interleaved. TG-128 reaches peak of 8.16 t/s at just 2 threads compared to 7.95 t/s @ 4 threads before. * 8-rows interleaved q8_0 (NEON) PP-512 is slightly better (138 t/s vs 132.5 t/s), TG-128 is about the same. * FA: repack Q8_0 to Q8_0_R8 * Remove special purpose mul_mat_q8_0_r4_q8_1_128 (Zen4) * FA: repack Q8_0 to Q8_0_R8 (NEON) Very slightly faster than the general purpose gemm, slightly slower than the D = 128 special case gemm mul_mat_q8_0_r4_q8_0_128. Still removing mul_mat_q8_0_r4_q8_0_128 as we simply don't have enough vector registers to hold 8 interleaved rows, so there is no point to have the special purpose implementation. --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-common.h | 15 +- ggml/src/ggml-quants.c | 1 - ggml/src/iqk/iqk_mul_mat.cpp | 690 +++++++++++++++++----------------- ggml/src/iqk/iqk_quantize.cpp | 150 ++++---- ggml/src/iqk/iqk_quantize.h | 4 +- src/llama.cpp | 8 +- 6 files changed, 437 insertions(+), 431 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 7f79b27b..d08870ad 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -236,6 +236,11 @@ typedef struct { int8_t qs[4*QK8_0]; } block_q8_0_x4; static_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), "wrong q8_0_x4 block size/padding"); +typedef struct { + ggml_half d[8]; + int8_t qs[8*QK8_0]; +} block_q8_0_r8; +static_assert(sizeof(block_q8_0_r8) == 8*sizeof(block_q8_0), "wrong q8_0_r8 block size/padding"); typedef struct { ggml_half d[4]; // deltas for 4 q4_0 blocks @@ -534,12 +539,12 @@ typedef struct { static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); typedef struct { - ggml_half d[4]; - uint8_t scales_h[QK_K/32]; - uint8_t scales_l[QK_K/16]; - uint8_t qs[QK_K*2]; + ggml_half d[8]; + uint8_t scales_h[QK_K/16]; + uint8_t scales_l[QK_K/ 8]; + uint8_t qs[QK_K*4]; } block_iq4_xs_r4; -static_assert(sizeof(block_iq4_xs_r4) == 4*sizeof(ggml_half) + QK_K/32 + QK_K/16 + QK_K*2, "wrong iq4_xs_rs block size/padding"); +static_assert(sizeof(block_iq4_xs_r4) == 8*sizeof(block_iq4_xs), "wrong iq4_xs_rs block size/padding"); typedef struct { uint8_t scales[QK_K/32]; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 23ac9915..391d9e2e 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -936,7 +936,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #if defined(__ARM_NEON) for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; float32x4_t srcv [8]; float32x4_t asrcv[8]; float32x4_t amaxv[8]; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7ddaee2a..d8273415 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -245,9 +245,7 @@ struct MulMat { case GGML_TYPE_Q4_0_R4: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: - case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: @@ -259,6 +257,8 @@ struct MulMat { case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_BN_R4: return 4; + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; default: return 1; @@ -2902,91 +2902,103 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn #ifdef HAVE_FANCY_SIMD template static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(nrc_x%16 == 0); Q8 q8(info); int nb = n / QK8_0; GGML_ASSERT(nb%4 == 0); if constexpr (nrc_y == 1) { auto m127 = _mm256_set1_epi8(127); - auto m1 = _mm256_set1_epi16(1); - __m256 acc[nrc_y] = {}; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + __m256 acc[2] = {}; + __m256i qx[8]; + float d8[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + _mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d))); for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-63.5f)); - auto q1 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127); - auto q2 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127); - auto q3 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127); - auto q4 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)))); - auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)))); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); - acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]); - } + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-127.f)); + qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127); + qx[1] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127); + qx[2] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127); + qx[3] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127); + qx[4] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4), m127); + qx[5] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5), m127); + qx[6] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6), m127); + qx[7] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7), m127); + auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0); + auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1); + auto yl = MM256_SET_M128I(y4l, y4l); + auto yh = MM256_SET_M128I(y4h, y4h); + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff)); + sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k])); + acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); + acc[1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[k+4]), acc[1]); } } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); - info.store(ix, iy, sum); - acc[iy] = _mm256_setzero_ps(); - } + info.store(ix, 0, _mm256_add_ps(acc[0], acc[1])); + acc[0] = acc[1] = _mm256_setzero_ps(); } } else { __m512 acc[2*nrc_y] = {}; - __m512i qx[4]; + __m512i qx[8]; + float d8[8*nrc_y]; auto m127 = _mm512_set1_epi8(127); - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx); - const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx); + for (int ix = 0; ix < nrc_x; ix += 16) { + const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx); + const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-63.5f)); - qx[0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+0), 1); - qx[1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+1), 1); - qx[2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+2)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+2), 1); - qx[3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+3)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+3), 1); - qx[0] = _mm512_add_epi8(qx[0], m127); - qx[1] = _mm512_add_epi8(qx[1], m127); - qx[2] = _mm512_add_epi8(qx[2], m127); - qx[3] = _mm512_add_epi8(qx[3], m127); + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d)); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-127.f)); + for (int j = 0; j < 8; ++j) { + qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)), + _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], m127); + } for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); + auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); + auto y8l = MM256_SET_M128I(y4l, y4l); + auto y8h = MM256_SET_M128I(y4h, y4h); + auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); + auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + info.store(ix, iy, sum512); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); + //auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + //auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + //info.store(ix+0, iy, sum1); + //info.store(ix+4, iy, sum2); } } } @@ -2994,45 +3006,72 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn #else template static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m1 = _mm256_set1_epi16(1); int nb = n / QK8_0; GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; float d8[4*nrc_y]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)); _mm_storeu_ps(d8 + 4*iy, scales); } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); - auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); - auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); - auto q4 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); + auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); + auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); + auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); + auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); + auto s0 = _mm256_sign_epi8(q0, q0); auto s1 = _mm256_sign_epi8(q1, q1); auto s2 = _mm256_sign_epi8(q2, q2); auto s3 = _mm256_sign_epi8(q3, q3); - auto s4 = _mm256_sign_epi8(q4, q4); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2)))); - auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4)))); + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); + auto y = MM256_SET_M128I(y128, y128); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1))) + ); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3))) + ); + auto sumi = _mm256_add_epi32(sumi1, sumi2); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); - acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4); + q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5); + q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6); + q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7); + s0 = _mm256_sign_epi8(q0, q0); + s1 = _mm256_sign_epi8(q1, q1); + s2 = _mm256_sign_epi8(q2, q2); + s3 = _mm256_sign_epi8(q3, q3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); + auto y = MM256_SET_M128I(y128, y128); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1))) + ); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3))) + ); + auto sumi = _mm256_add_epi32(sumi1, sumi2); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); - info.store(ix, iy, sum); + info.store(ix, iy, acc[iy]); acc[iy] = _mm256_setzero_ps(); } } @@ -3041,9 +3080,11 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn template static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); #ifndef HAVE_FANCY_SIMD auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); @@ -3052,40 +3093,40 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const auto values = load_iq4nl_values_256(); #endif int nbl = n / QK_K; - using helper_t = union { __m256i vec; uint32_t val[8]; }; + using helper_t = union { __m256i vec[2]; uint64_t val[8]; }; helper_t h; __m256 acc[nrc_y] = {}; - __m256i isum[nrc_y] = {}; __m256i qx[4]; - for (int ix = 0; ix < nrc_x; ix += 4) { + for (int ix = 0; ix < nrc_x; ix += 8) { const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx); for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 - auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d)); - auto d4 = _mm256_set_m128(dl, dl); - auto slbits = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_l); - auto sl = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(slbits, 4), slbits), _mm256_set1_epi8(0xf)); - auto aux64 = (const uint64_t *)iq4[ibl].scales_h; - auto shbits = _mm_set_epi64x(aux64[0] >> 2, aux64[0]); - auto sh = _mm256_and_si256(MM256_SET_M128I(shbits, _mm_slli_epi16(shbits, 4)), _mm256_set1_epi8(0x30)); - h.vec = _mm256_sub_epi8(_mm256_or_si256(sl, sh), _mm256_set1_epi8(32)); + auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d)); + auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l); + auto sl1 = _mm256_and_si256(slbits, m4); + auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4); + auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h); + auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32); + h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32); + __m256i isum[nrc_y] = {}; for (int ib = 0; ib < QK_K/32; ++ib) { #ifdef HAVE_FANCY_SIMD - auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib])); auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales)); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f)); for (int iy = 0; iy < nrc_y; ++iy) { float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); } #else - auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); + auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle); #endif - auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); - qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); - qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); - qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); - qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1); + qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4))); + qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4))); #ifndef HAVE_FANCY_SIMD auto s1 = _mm256_sign_epi8(qx[0], qx[0]); auto s2 = _mm256_sign_epi8(qx[1], qx[1]); @@ -3093,7 +3134,8 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const auto s4 = _mm256_sign_epi8(qx[3], qx[3]); #endif for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0); + auto y = MM256_SET_M128I(y128, y128); #ifdef HAVE_FANCY_SIMD auto sumi = _mm256_setzero_si256(); sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); @@ -3106,20 +3148,51 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); - isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2))); - isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); + auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)), + _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); + isum[iy] = _mm256_add_epi32(isum[iy], sumi); +#endif + } + bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2); + bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3); + qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4))); + qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4))); +#ifndef HAVE_FANCY_SIMD + s1 = _mm256_sign_epi8(qx[0], qx[0]); + s2 = _mm256_sign_epi8(qx[1], qx[1]); + s3 = _mm256_sign_epi8(qx[2], qx[2]); + s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)), + _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); + isum[iy] = _mm256_add_epi32(isum[iy], sumi); #endif } } for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); - isum[iy] = _mm256_setzero_si256(); } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, acc[iy]); acc[iy] = _mm256_setzero_ps(); - info.store(ix+0, iy, sum); } } } @@ -3127,6 +3200,8 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const #ifdef HAVE_FANCY_SIMD template static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_iq4_xs_r4_q8_k_avx2(n, vx, bx, info, nrc_x); + return; if constexpr (nrc_y == 1){ mul_mat_iq4_xs_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); } else { @@ -10529,6 +10604,13 @@ IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 } +IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) { + qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4)); + qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); + qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4)); + qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); +} + template void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -10539,43 +10621,92 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i auto values = vld1q_s8(iq4k_values); int nbl = n / QK_K; int8x16_t qx[8]; - int8x16x2_t iscales; - int32x4x4_t scales; - float32x4_t acc[nrc_y] = {}; - for (int ix = 0; ix < nrc_x; ix += 4) { + int8x16x4_t iscales; + int32x4x2_t scales; + float32x4_t acc[2*nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 8) { const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx); for (int ibl = 0; ibl < nbl; ++ibl) { - auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d)); - auto sl = vld1q_u8(iq4[ibl].scales_l); - auto sh8 = vld1_u8(iq4[ibl].scales_h); - auto sh = vcombine_u8(sh8, vshr_n_u8(sh8, 2)); - iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl, m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); - iscales.val[1] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl, 4), vandq_u8(sh, m3)), m32); + auto d4_f16 = vld1q_f16((const float16_t *)iq4[ibl].d); + auto d4l = vcvt_f32_f16(vget_low_f16 (d4_f16)); + auto d4h = vcvt_f32_f16(vget_high_f16(d4_f16)); + auto sl = vld1q_u8_x2(iq4[ibl].scales_l); + auto sh = vld1q_u8(iq4[ibl].scales_h); + iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + iscales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32); + iscales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32); + iscales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); int32x4_t isum[nrc_y] = {}; - for (int is = 0; is < 2; ++is) { - auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is])); - auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is])); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64])); scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1)); - scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1)); - scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2)); - scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2)); - for (int ib = 0; ib < 4; ++ib) { - auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); - prepare_iq4_nl_quants(values, m4, bits, qx); + scales.val[1] = vmovl_s16(vget_low_s16(iscales16_2)); + for (int l = 0; l < 2; ++l) { + uint8x16x2_t bits; + bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 32); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+0); + bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 64); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 96); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+4); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); - auto sumi = interleaved_dotq(qx, y); - isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + auto d8 = vdupq_n_f32(q8.scale(iy, ibl)); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64])); + scales.val[0] = vmovl_s16(vget_high_s16(iscales16_1)); + scales.val[1] = vmovl_s16(vget_high_s16(iscales16_2)); + for (int l = 0; l < 2; ++l) { + uint8x16x2_t bits; + bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 16); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 48); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+0); + bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 80); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l +112); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto d8 = vdupq_n_f32(q8.scale(iy, ibl)); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[iy])); } } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, acc[iy]); - acc[iy] = vdupq_n_f32(0.f); + info.store(ix+0, iy, acc[2*iy+0]); + info.store(ix+4, iy, acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); } } } @@ -12045,81 +12176,54 @@ struct Q6_0_R4_Dequantizer { template void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); int nb = n / QK8_0; GGML_ASSERT(nb%4 == 0); - float32x4_t acc[nrc_y] = {}; + float32x4_t acc[2*nrc_y] = {}; + int8x16_t qx[16]; float d8[4*nrc_y]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); } for (int k = 0; k < 4; ++k) { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[4*ib4+k].d)); - auto qx1 = vld1q_s8_x4(iq8[4*ib4+k].qs); - auto qx2 = vld1q_s8_x4(iq8[4*ib4+k].qs+64); + auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d); + auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16)); + auto scales2 = vcvt_f32_f16(vget_high_f16(scales16)); + for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx1.val[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx1.val[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx1.val[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx1.val[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx2.val[0], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx2.val[1], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx2.val[2], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx2.val[3], y.val[1], 3); - auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k])); - acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + auto sumi1 = vdupq_n_s32(0); + auto sumi2 = vdupq_n_s32(0); + sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0); + sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0); + sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1); + sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1); + sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2); + sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2); + sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3); + sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3); + auto dy = vdupq_n_f32(d8[4*iy+k]); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1)); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2)); } } } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, acc[iy]); - acc[iy] = vdupq_n_f32(0.f); - } - } -} - -template -void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); - GGML_ASSERT(n == 128); - int8x16x4_t qx[8]; - float32x4_t scales[4]; - float32x4_t scales_y[4]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); - for (int k = 0; k < 4; ++k) { - scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d)); - qx[2*k+0] = vld1q_s8_x4(iq8[k].qs); - qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64); - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto by = (const block_q8_0_x4 *)info.src1_row(iy); - auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d)); - scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0); - scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1); - scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2); - scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3); - auto sumf = vdupq_n_f32(0.f); - for (int k = 0; k < 4; ++k) { - auto y = vld1q_s8_x2(by->qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3); - sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi)); - } - info.store(ix, iy, sumf); + info.store(ix+0, iy, acc[2*iy+0]); + info.store(ix+4, iy, acc[2*iy+1]); + acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f); } } } @@ -12763,22 +12867,25 @@ struct HelperQ80R4 : public BaseHelper { Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector repack(int nk, const HelperQ80 q8) { + static std::vector repack(int nk, const HelperQ80 q8) { static_assert(D%QK8_0 == 0); - GGML_ASSERT(nk%4 == 0); + GGML_ASSERT(nk%8 == 0); constexpr int nblock = D/QK8_0; - std::vector result(nblock * nk/4); + std::vector result(nblock * nk/8); auto y = result.data(); - const block_q8_0 * x4[4]; - for (int row = 0; row < nk; row += 4) { - for (int k = 0; k < 4; ++k) x4[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); + const block_q8_0 * x8[8]; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = 0; row < nk; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); for (int ib = 0; ib < nblock; ++ib) { - for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; #ifdef __AVX2__ - auto m0 = _mm256_loadu_si256((const __m256i *)x4[0][ib].qs); - auto m1 = _mm256_loadu_si256((const __m256i *)x4[1][ib].qs); - auto m2 = _mm256_loadu_si256((const __m256i *)x4[2][ib].qs); - auto m3 = _mm256_loadu_si256((const __m256i *)x4[3][ib].qs); + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs)); auto t0 = _mm256_unpacklo_epi32(m0, m1); auto t1 = _mm256_unpacklo_epi32(m2, m3); auto t2 = _mm256_unpackhi_epi32(m0, m1); @@ -12791,32 +12898,50 @@ struct HelperQ80R4 : public BaseHelper { _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3); + m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1)); + m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1)); + m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1)); + m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1)); + t0 = _mm256_unpacklo_epi32(m0, m1); + t1 = _mm256_unpacklo_epi32(m2, m3); + t2 = _mm256_unpackhi_epi32(m0, m1); + t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3); #elif defined __ARM_NEON - auto m0 = vld1q_s8_x2(x4[0][ib].qs); - auto m1 = vld1q_s8_x2(x4[1][ib].qs); - auto m2 = vld1q_s8_x2(x4[2][ib].qs); - auto m3 = vld1q_s8_x2(x4[3][ib].qs); - auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); - auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); - m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); - row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); - m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - vst1q_s8_x2(y[ib].qs + 0, m0); - vst1q_s8_x2(y[ib].qs + 32, m1); - vst1q_s8_x2(y[ib].qs + 64, m2); - vst1q_s8_x2(y[ib].qs + 96, m3); + for (int l = 0; l < 2; ++l) { + m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l); + m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l); + m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l); + m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0); + vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1); + vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2); + vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3); + } #else for (int l = 0; l < 4; ++l) { - for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0]; - y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16]; + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; } } #endif @@ -12826,7 +12951,7 @@ struct HelperQ80R4 : public BaseHelper { return result; } - std::vector r4; + std::vector r4; }; template @@ -13370,78 +13495,6 @@ struct FlashQKV { qkv_cache_t qkv_cache[D*q_step] = {}; }; -#ifdef HAVE_FANCY_SIMD -template -static void mul_mat_q8_0_r4_q8_1_128([[maybe_unused]] int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); - GGML_ASSERT(n == 128); - //Q8 q8(info); - __m512i qx[16]; - __m512 scales[4]; - __m512 scales_m[4]; - __m512 dy[4]; - auto m127 = _mm512_set1_epi8(127); - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx); - const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx); - for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - scales[k] = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - scales_m[k] = _mm512_mul_ps(scales[k], _mm512_set1_ps(-63.5f)); - qx[4*k+0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+0)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+0), 1); - qx[4*k+1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+1)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+1), 1); - qx[4*k+2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+2)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+2), 1); - qx[4*k+3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+3)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+3), 1); - qx[4*k+0] = _mm512_add_epi8(qx[4*k+0], m127); - qx[4*k+1] = _mm512_add_epi8(qx[4*k+1], m127); - qx[4*k+2] = _mm512_add_epi8(qx[4*k+2], m127); - qx[4*k+3] = _mm512_add_epi8(qx[4*k+3], m127); - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto by = (const block_q8_1_x4 *)info.src1_row(iy); - //auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][0].d)); - auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)by->d)); - auto d128 = _mm256_castps256_ps128(dall); - auto m128 = _mm256_extractf128_ps(dall, 1); - auto m256 = _mm256_set_m128(m128, m128); - auto m512 = _mm512_insertf32x8(_mm512_castps256_ps512(m256), m256, 1); - auto sumf = _mm512_mul_ps(scales_m[0], _mm512_shuffle_ps(m512, m512, 0x00)); - sumf = _mm512_fmadd_ps(scales_m[1], _mm512_shuffle_ps(m512, m512, 0x55), sumf); - sumf = _mm512_fmadd_ps(scales_m[2], _mm512_shuffle_ps(m512, m512, 0xaa), sumf); - sumf = _mm512_fmadd_ps(scales_m[3], _mm512_shuffle_ps(m512, m512, 0xff), sumf); - auto d256 = _mm256_set_m128(d128, d128); - auto d512 = _mm512_insertf32x8(_mm512_castps256_ps512(d256), d256, 1); - dy[0] = _mm512_mul_ps(scales[0], _mm512_shuffle_ps(d512, d512, 0x00)); - dy[1] = _mm512_mul_ps(scales[1], _mm512_shuffle_ps(d512, d512, 0x55)); - dy[2] = _mm512_mul_ps(scales[2], _mm512_shuffle_ps(d512, d512, 0xaa)); - dy[3] = _mm512_mul_ps(scales[3], _mm512_shuffle_ps(d512, d512, 0xff)); - for (int k = 0; k < 4; ++k) { - //auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][0].qs+k); - auto y8 = _mm256_loadu_si256((const __m256i*)by->qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - sumf = _mm512_fmadd_ps(dy[k], _mm512_cvtepi32_ps(sumi), sumf); - } - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 0), _mm512_extractf32x4_ps(sumf, 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 2), _mm512_extractf32x4_ps(sumf, 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); - } - } -} -#endif - template struct FlashQKfp32 { static_assert(D%F16::block_size == 0 && D <= 256); @@ -13706,44 +13759,9 @@ struct FlashQKfp32 { } else if constexpr (std::is_same_v>) { #ifdef __aarch64__ - if constexpr (D == 128) { - if (q_step >= 64 && nq >= 64) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64); - } - else if (q_step >= 32 && nq >= 32) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32); - } - else if (q_step >= 16 && nq >= 16) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16); - } - else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq); - } - } else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); - } - //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); -#else -#ifdef HAVE_FANCY_SIMD - if constexpr (D == 128) { - if (q_step >= 64 && nq >= 64) { - return std::make_pair(mul_mat_q8_0_r4_q8_1_128<64>, 64); - } - else if (q_step >= 32 && nq >= 32) { - return std::make_pair(mul_mat_q8_0_r4_q8_1_128<32>, 32); - } - else if (q_step >= 16 && nq >= 16) { - return std::make_pair(mul_mat_q8_0_r4_q8_1_128<16>, 16); - } - else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1_128, nq); - } - } else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); - } + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); #else MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); -#endif #endif } else if constexpr (std::is_same_v>) { diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 221bc48c..59a36c5c 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3709,63 +3709,63 @@ void vec_dot_q4_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b // // ========================================= q8_0_r4 // -void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_x4 * y, int64_t k) { +void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_r8 * y, int64_t k) { // we assume we are called with 4 rows - quantize_q8_0_r4(x, (void *)y, 4, k/4, nullptr); + quantize_q8_0_r4(x, (void *)y, 8, k/8, nullptr); } void quantize_row_q8_0_r4(const float * x, void * y, int64_t k) { // we assume we are called with 4 rows - quantize_q8_0_r4(x, y, 4, k/4, nullptr); + quantize_q8_0_r4(x, y, 8, k/8, nullptr); } -static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_x4 * y) { - GGML_ASSERT(nrows%4 == 0); +static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_r8 * y) { + GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK8_0 == 0); int nblock = n_per_row/QK8_0; - const block_q8_0 * x4[4]; - for (int row = 0; row < nrows; row += 4) { - for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + const block_q8_0 * x8[8]; + for (int row = 0; row < nrows; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k; for (int ib = 0; ib < nblock; ++ib) { - for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; for (int l = 0; l < 4; ++l) { - for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0]; - y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16]; + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; } } } - x += 4*nblock; + x += 8*nblock; y += nblock; } } size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { - GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(nrows%8 == 0); auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); - std::vector qtmp(4*row_size_0); + std::vector qtmp(8*row_size_0); char * qrow = (char *)dst; - for (int row = 0; row < nrows; row += 4) { - quantize_q8_0(src, qtmp.data(), 4, n_per_row, imatrix); - repack_q8_0(4, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_x4 *)qrow); - src += 4*n_per_row; - qrow += 4*row_size_0; + for (int row = 0; row < nrows; row += 8) { + quantize_q8_0(src, qtmp.data(), 8, n_per_row, imatrix); + repack_q8_0(8, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_r8 *)qrow); + src += 8*n_per_row; + qrow += 8*row_size_0; } return nrows*row_size_0; } -void dequantize_row_q8_0_r4(const block_q8_0_x4 * x, float * y, int64_t k) { +void dequantize_row_q8_0_r4(const block_q8_0_r8 * x, float * y, int64_t k) { // we assume we are called with 4 rows - int n_per_row = k/4; + int n_per_row = k/8; int nb = n_per_row/QK8_0; - float * yk[4]; - for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + float * yk[8]; + for (int k = 0; k < 8; ++k) yk[k] = y + k*n_per_row; for (int ib = 0; ib < nb; ++ib) { - for (int k = 0; k < 4; ++k) { + for (int k = 0; k < 8; ++k) { float scale = GGML_FP16_TO_FP32(x[ib].d[k]); for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) { - yk[k][QK8_0*ib+4*l+i+ 0] = scale * x[ib].qs[QK8_0*l+4*k+i+ 0]; - yk[k][QK8_0*ib+4*l+i+16] = scale * x[ib].qs[QK8_0*l+4*k+i+16]; + yk[k][QK8_0*ib+4*l+i+ 0] = scale * x[ib].qs[32*l+4*k+i+ 0]; + yk[k][QK8_0*ib+4*l+i+16] = scale * x[ib].qs[32*l+4*k+i+128]; } } } @@ -3987,93 +3987,77 @@ void vec_dot_q6_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b // void quantize_row_iq4_xs_r4_ref(const float * x, block_iq4_xs_r4 * y, int64_t k) { - quantize_iq4_xs_r4(x, (void *)y, 4, k/4, nullptr); + quantize_iq4_xs_r4(x, (void *)y, 8, k/8, nullptr); } void quantize_row_iq4_xs_r4(const float * x, void * y, int64_t k) { - quantize_iq4_xs_r4(x, y, 4, k/4, nullptr); + quantize_iq4_xs_r4(x, y, 8, k/8, nullptr); } static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y) { - GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; - const block_iq4_xs * x4[4]; - for (int row = 0; row < nrows; row += 4) { - for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + const block_iq4_xs * x8[8]; + for (int row = 0; row < nrows; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k; for (int ibl = 0; ibl < nblock; ++ibl) { - std::memset(y[ibl].scales_l, 0, QK_K/16); - std::memset(y[ibl].scales_h, 0, QK_K/32); - for (int k = 0; k < 4; ++k) { - y[ibl].d[k] = x4[k][ibl].d; + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 8; ++k) { + y[ibl].d[k] = x8[k][ibl].d; for (int ib = 0; ib < QK_K/32; ++ib) { - uint8_t sl = (x4[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf; - uint8_t sh = (x4[k][ibl].scales_h >> 2*ib) & 3; - int i = 4*ib + k; - y[ibl].scales_l[i%16] |= (sl << 4*(i/16)); - y[ibl].scales_h[i%8 ] |= (sh << 2*(i/8)); - } - } - for (int ib = 0; ib < QK_K/32; ++ib) { - for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row - y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 from each row - y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row - y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row + uint8_t sl = (x8[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf; + uint8_t sh = (x8[k][ibl].scales_h >> 2*ib) & 3; + int i = 8*ib + k; + y[ibl].scales_l[i%32] |= (sl << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh << 2*(i/16)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[128*ib+4*k+i+ 0] = (x8[k][ibl].qs[16*ib+i+0] & 0xf) | ((x8[k][ibl].qs[16*ib+i+ 4] & 0xf) << 4); + y[ibl].qs[128*ib+4*k+i+32] = (x8[k][ibl].qs[16*ib+i+8] & 0xf) | ((x8[k][ibl].qs[16*ib+i+12] & 0xf) << 4); + y[ibl].qs[128*ib+4*k+i+64] = (x8[k][ibl].qs[16*ib+i+0] >> 4) | ((x8[k][ibl].qs[16*ib+i+ 4] >> 4) << 4); + y[ibl].qs[128*ib+4*k+i+96] = (x8[k][ibl].qs[16*ib+i+8] >> 4) | ((x8[k][ibl].qs[16*ib+i+12] >> 4) << 4); + } } } } - x += 4*nblock; + x += 8*nblock; y += nblock; } } size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { - GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); char * qcur = (char *)dst; auto row_size = ggml_row_size(GGML_TYPE_IQ4_XS, n_per_row); - std::vector qtmp(4*row_size); - for (int row = 0; row < nrows; row += 4) { - quantize_iq4_xs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq4_xs(4, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur); - qcur += 4*row_size; - src += 4*n_per_row; + std::vector qtmp(8*row_size); + for (int row = 0; row < nrows; row += 8) { + quantize_iq4_xs(src, (void *)qtmp.data(), 8, n_per_row, imatrix); + repack_iq4_xs(8, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur); + qcur += 8*row_size; + src += 8*n_per_row; } return nrows*row_size; } void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * x, float * y, int64_t k) { - auto n_per_row = k/4; - float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + auto n_per_row = k/8; + float * y8[8]; + for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; int nblock = n_per_row/QK_K; for (int ibl = 0; ibl < nblock; ++ibl) { - for (int k = 0; k < 4; ++k) { + for (int k = 0; k < 8; ++k) { const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); for (int ib = 0; ib < QK_K/32; ++ib) { - int is = 4*ib + k; - float dl = d * ((((x[ibl].scales_l[is%16] >> 4*(is/16)) & 0xf) | (((x[ibl].scales_h[is%8] >> 2*(is/8)) & 3) << 4)) - 32); - for (int i = 0; i < 4; ++i) { - y4[k][QK_K*ibl+32*ib+i+ 0] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+ 8] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] >> 4]; - y4[k][QK_K*ibl+32*ib+i+16] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+24] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] >> 4]; - y4[k][QK_K*ibl+32*ib+i+ 4] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+12] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] >> 4]; - y4[k][QK_K*ibl+32*ib+i+20] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+28] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] >> 4]; + int is = 8*ib + k; + float dl = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) { + y8[k][QK_K*ibl+32*ib+8*l+i+0] = dl * iq4k_values[x[ibl].qs[128*ib+4*k+i+32*l] & 0xf]; + y8[k][QK_K*ibl+32*ib+8*l+i+4] = dl * iq4k_values[x[ibl].qs[128*ib+4*k+i+32*l] >> 4]; } } } - //dequantize_row_iq4_xs(x + ib, ytmp, QK_K); - //for (int k = 0; k < 4; ++k) { - // for (int l = 0; l < 16; ++l) { - // for (int i = 0; i < 4; ++i) { - // //y4[k][ib*kBlockSize + i + 16*(l%4) + 4*(l/4)] = ytmp[16*l + 4*k + i]; - // y4[k][ib*kBlockSize + i + 8*(l%8) + 4*(l/8)] = ytmp[16*l + 4*k + i]; - // } - // } - //} } } @@ -6063,7 +6047,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, { GGML_TYPE_IQ4_K, { GGML_TYPE_IQ4_K_R4, 4, (Repack::repack_func)repack_iq4_k} }, { GGML_TYPE_IQ5_K, { GGML_TYPE_IQ5_K_R4, 4, (Repack::repack_func)repack_iq5_k} }, - { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 4, (Repack::repack_func)repack_iq4_xs} }, + { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 8, (Repack::repack_func)repack_iq4_xs} }, { GGML_TYPE_IQ4_KS, { GGML_TYPE_IQ4_KS_R4, 4, (Repack::repack_func)repack_iq4_ks} }, { GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} }, { GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} }, @@ -6080,7 +6064,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 4, (Repack::repack_func)repack_q4_0} }, { GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} }, { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, - { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 4, (Repack::repack_func)repack_q8_0} }, + { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 8, (Repack::repack_func)repack_q8_0} }, { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, #ifdef __AVX512BF16__ { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16}}, diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 729b0ec0..64860b4d 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -73,10 +73,10 @@ size_t quantize_q4_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_q4_0_r4(const block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q4_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_x4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_r8 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_q8_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -void dequantize_row_q8_0_r4(const block_q8_0_x4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_0_r4(const block_q8_0_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q8_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void quantize_row_q5_0_r4_ref(const float * GGML_RESTRICT x, block_q5_0_r4 * GGML_RESTRICT y, int64_t k); diff --git a/src/llama.cpp b/src/llama.cpp index c2bc5cc0..836fd97a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16906,8 +16906,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s else chunk_size_multiplier = 4; } else if (new_type == GGML_TYPE_IQ4_XS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_XS; - else chunk_size_multiplier = 4; + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_IQ4_XS; + else chunk_size_multiplier = 8; } else if (new_type == GGML_TYPE_Q4_0_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; @@ -16922,8 +16922,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s else chunk_size_multiplier = 4; } else if (new_type == GGML_TYPE_Q8_0_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q8_0; - else chunk_size_multiplier = 4; + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; + else chunk_size_multiplier = 8; } else if (new_type == GGML_TYPE_Q2_K_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q2_K; From f725576345582144dfebd7f1e6c8ac93eb1eb0ca Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 27 Jan 2025 18:53:47 +0200 Subject: [PATCH 02/14] Minor performance improvements (#179) * Try interleaving 8 rows for iq4_xs On Zen4, PP-512 goes up from ~260 t/s to 288 t/s for L3-8B. TG-128 reaches max. performance at 2 threads and is slightly higher than 4 interleaved rows (14.48 t/s vs 13.11 t/s @ 2 threads and 14/28 t/s @ 4 threads). * Try interleaving 8 iq4_xs rows It is also faster on AVX2. This is the NEON implementation. It is tiny bit faster than 4 interleaved rows (~0.5%). So, this looks like a winner given the Zen4/AVX2 improvement without associated NEON egression. * Cleanup * 8-rows interleaved q8_0 (AVX2) * 8-rows interleaved q8_0 (Zen4) * 8-rows interleaved q8_0 (Zen4) - slightly better PP-512 is now 284 t/s compared to 257 t/s for 4-rows interleaved. TG-128 reaches peak of 8.16 t/s at just 2 threads compared to 7.95 t/s @ 4 threads before. * 8-rows interleaved q8_0 (NEON) PP-512 is slightly better (138 t/s vs 132.5 t/s), TG-128 is about the same. * FA: repack Q8_0 to Q8_0_R8 * Remove special purpose mul_mat_q8_0_r4_q8_1_128 (Zen4) * FA: repack Q8_0 to Q8_0_R8 (NEON) Very slightly faster than the general purpose gemm, slightly slower than the D = 128 special case gemm mul_mat_q8_0_r4_q8_0_128. Still removing mul_mat_q8_0_r4_q8_0_128 as we simply don't have enough vector registers to hold 8 interleaved rows, so there is no point to have the special purpose implementation. * q4_0_r8 (AVX2) * q4_0_r8 (NEON) Tiny bit faster PP (~128 vs ~126 t/s), same TG. * q4_0_r8 (Zen4) Somehow only marginally faster? 268 t/s vs 261 t/s * q4_0_r8 (Zen4) - slightly better 282 t/s for a pure q4_0 L3-8B quantization. * Apply platform specific modifications when repacking E.g., on NEON it is useful to pre-apply q ^ 0x88 to q4_0. This results in a ~3% performance improvement. Hence, * Changed the signature of the repack_X functions to take a bool argument indicating if the repacking is done online and, if so, apply modifications as appropriate while repacking. * Added iqk_modify_tensor to apply modifications to models that have already been repacked while loading the model. Caveat: just like rtr, this needs to have mmap disabled (else one would need to move the data to a not mmap-ed buffer, so much more complicated). * Apply platform specific modifications when repacking On Zen4 we can pre-convert the signed quants in q8_0_r4 and q8_k_r8 to unsigned thus avoiding these operations in matrix multiplications. With this change we hit PP-512 = 382.40 t/s (q8_k_r8) PP-512 = 306.92 t/s (q8_0_r4) for L3-8B on a Ryzen-7950X using q8_0 KV-cache. * Process up to 16 columns per kernel call for q8_k_r8 This brings PP-512 up to 389 t/s. * Be able to load Deepseek-v2-Lite --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-common.h | 7 +- ggml/src/iqk/iqk_mul_mat.cpp | 308 +++++++++++++++++++++++++--------- ggml/src/iqk/iqk_quantize.cpp | 305 ++++++++++++++++++++++----------- ggml/src/iqk/iqk_quantize.h | 5 +- src/llama.cpp | 16 +- 5 files changed, 455 insertions(+), 186 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index d08870ad..023b0b63 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -528,7 +528,12 @@ typedef struct { ggml_half d[4]; uint8_t qs[2*QK4_NL]; } block_iq4_nl_r4; -static_assert(sizeof(block_iq4_nl_r4) == 4*sizeof(ggml_half) + 2*QK4_NL, "wrong iq4_nl_x4 block size/padding"); +static_assert(sizeof(block_iq4_nl_r4) == 4*sizeof(ggml_half) + 2*QK4_NL, "wrong iq4_nl_r4 block size/padding"); +typedef struct { + ggml_half d[8]; + uint8_t qs[4*QK4_NL]; +} block_iq4_nl_r8; +static_assert(sizeof(block_iq4_nl_r8) == 8*sizeof(ggml_half) + 4*QK4_NL, "wrong iq4_nl_r8 block size/padding"); typedef struct { ggml_half d; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d8273415..8d2b4090 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -242,7 +242,6 @@ struct MulMat { case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K_R4: - case GGML_TYPE_Q4_0_R4: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_IQ4_NL_R4: @@ -258,6 +257,7 @@ struct MulMat { case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_BN_R4: return 4; case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_Q8_0_R4: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; @@ -2538,52 +2538,119 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data } #endif +inline void prepare_q4_0_quants_avx2(const uint8_t * qs, __m256i * v, const __m256i& m4) { + auto bits1 = _mm256_loadu_si256((const __m256i *)qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)qs+1); + auto bits3 = _mm256_loadu_si256((const __m256i *)qs+2); + auto bits4 = _mm256_loadu_si256((const __m256i *)qs+3); + v[0] = _mm256_and_si256(bits1, m4); + v[1] = _mm256_and_si256(bits2, m4); + v[2] = _mm256_and_si256(bits3, m4); + v[3] = _mm256_and_si256(bits4, m4); + v[4] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4); + v[5] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4); + v[6] = _mm256_and_si256(_mm256_srli_epi16(bits3, 4), m4); + v[7] = _mm256_and_si256(_mm256_srli_epi16(bits4, 4), m4); +} + +inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) { + auto y4l = _mm_loadu_si128((const __m128i*)qs+0); + auto y4h = _mm_loadu_si128((const __m128i*)qs+1); + auto yl = MM256_SET_M128I(y4l, y4l); + auto yh = MM256_SET_M128I(y4h, y4h); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, v[0], _mm256_shuffle_epi32(yl, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, v[1], _mm256_shuffle_epi32(yl, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, v[2], _mm256_shuffle_epi32(yl, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, v[3], _mm256_shuffle_epi32(yl, 0xff)); + sumi = _mm256_dpbusd_epi32(sumi, v[4], _mm256_shuffle_epi32(yh, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, v[5], _mm256_shuffle_epi32(yh, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, v[6], _mm256_shuffle_epi32(yh, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, v[7], _mm256_shuffle_epi32(yh, 0xff)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(v[0], _mm256_shuffle_epi32(yl, 0x00)), + _mm256_maddubs_epi16(v[1], _mm256_shuffle_epi32(yl, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(v[2], _mm256_shuffle_epi32(yl, 0xaa)), + _mm256_maddubs_epi16(v[3], _mm256_shuffle_epi32(yl, 0xff))); + auto sumi3 = _mm256_add_epi16(_mm256_maddubs_epi16(v[4], _mm256_shuffle_epi32(yh, 0x00)), + _mm256_maddubs_epi16(v[5], _mm256_shuffle_epi32(yh, 0x55))); + auto sumi4 = _mm256_add_epi16(_mm256_maddubs_epi16(v[6], _mm256_shuffle_epi32(yh, 0xaa)), + _mm256_maddubs_epi16(v[7], _mm256_shuffle_epi32(yh, 0xff))); + auto sumi = _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(sumi1, sumi2)), + _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(sumi3, sumi4))); +#endif + return sumi; +} + template static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); - auto m1 = _mm256_set1_epi16(1); int nb = n / QK4_NL; + __m256i v[8]; GGML_ASSERT(nb%4 == 0); + if constexpr (nrc_y == 1) { + union { __m256 vec; float val[8]; } helper; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx); + auto acc1 = _mm256_setzero_ps(); + auto acc2 = _mm256_setzero_ps(); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + helper.vec = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)); + for (int k = 0; k < 4; ++k) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); + prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4); + auto sumi = accum_q4_0_quants(v, q8.y[0][ib4].qs+32*k); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(helper.val[k])); + acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1); + acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2); + } + } + acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1); + info.store(ix, 0, acc1); + } + } + else { __m256 acc[nrc_y] = {}; float d8[8*nrc_y]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { - for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); - _mm256_storeu_ps(d8 + 8*iy, scales); + { + __m256 d4[4]; + for (int k = 0; k < 4; ++k) { + d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + _mm256_storeu_ps(d8 + 8*iy, scales); + auto m4 = _mm256_extractf128_ps(scales, 1); + auto m8 = _mm256_set_m128(m4, m4); + auto sumf = _mm256_mul_ps(d4[0], _mm256_shuffle_ps(m8, m8, 0x00)); + sumf = _mm256_fmadd_ps(d4[1], _mm256_shuffle_ps(m8, m8, 0x55), sumf); + sumf = _mm256_fmadd_ps(d4[2], _mm256_shuffle_ps(m8, m8, 0xaa), sumf); + sumf = _mm256_fmadd_ps(d4[3], _mm256_shuffle_ps(m8, m8, 0xff), sumf); + acc[iy] = _mm256_fmadd_ps(sumf, _mm256_set1_ps(-8.f), acc[iy]); + } } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-4.f)); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1); - auto q1 = _mm256_and_si256(bits1, m4); - auto q2 = _mm256_and_si256(bits2, m4); - auto q3 = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4); - auto q4 = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4); + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); + prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); - auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + auto sumi = accum_q4_0_quants(v, q8.y[iy][ib4].qs+32*k); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+4+k]), acc[iy]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); - info.store(ix, iy, sum); + info.store(ix, iy, acc[iy]); acc[iy] = _mm256_setzero_ps(); } } + } } #ifdef HAVE_FANCY_SIMD @@ -2593,53 +2660,67 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn mul_mat_q4_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); return; } - GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(nrc_x%16 == 0); Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); int nb = n / QK4_NL; GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; - __m512i qx[4]; - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx); - const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); + __m512i qx[8]; + float d8[8*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 16) { + const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx); + const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[4*ib4+k].d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[4*ib4+k].d)); auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-4.f)); auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)), _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1); auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)), _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1); + auto bits3 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+2)), + _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+2), 1); + auto bits4 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+3)), + _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+3), 1); qx[0] = _mm512_and_si512(bits1, m4); qx[1] = _mm512_and_si512(bits2, m4); - qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4); - qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4); + qx[2] = _mm512_and_si512(bits3, m4); + qx[3] = _mm512_and_si512(bits4, m4); + qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4); + qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4); + qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits3, 4), m4); + qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits4, 4), m4); for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); + auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); + auto y8l = MM256_SET_M128I(y4l, y4l); + auto y8h = MM256_SET_M128I(y4h, y4h); + auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); + auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); + info.store(ix, iy, sum); } } } @@ -2907,7 +2988,6 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn int nb = n / QK8_0; GGML_ASSERT(nb%4 == 0); if constexpr (nrc_y == 1) { - auto m127 = _mm256_set1_epi8(127); __m256 acc[2] = {}; __m256i qx[8]; float d8[8]; @@ -2917,15 +2997,14 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn _mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d))); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-127.f)); - qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127); - qx[1] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127); - qx[2] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127); - qx[3] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127); - qx[4] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4), m127); - qx[5] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5), m127); - qx[6] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6), m127); - qx[7] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7), m127); + qx[0] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); + qx[1] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); + qx[2] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); + qx[3] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); + qx[4] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4); + qx[5] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5); + qx[6] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6); + qx[7] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7); auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0); auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1); auto yl = MM256_SET_M128I(y4l, y4l); @@ -2941,17 +3020,16 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff)); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k])); acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); - acc[1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[k+4]), acc[1]); + acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]); } } - info.store(ix, 0, _mm256_add_ps(acc[0], acc[1])); + info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0])); acc[0] = acc[1] = _mm256_setzero_ps(); } } else { __m512 acc[2*nrc_y] = {}; __m512i qx[8]; float d8[8*nrc_y]; - auto m127 = _mm512_set1_epi8(127); for (int ix = 0; ix < nrc_x; ix += 16) { const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx); const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx); @@ -2963,11 +3041,9 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d)); auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d)); auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-127.f)); for (int j = 0; j < 8; ++j) { qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)), _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1); - qx[j] = _mm512_add_epi8(qx[j], m127); } for (int iy = 0; iy < nrc_y; ++iy) { auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); @@ -2987,18 +3063,14 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]); info.store(ix, iy, sum512); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); - //auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); - //auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); - //info.store(ix+0, iy, sum1); - //info.store(ix+4, iy, sum2); } } } @@ -4995,12 +5067,7 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1); qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2); qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3); -#ifdef HAVE_FANCY_SIMD - qx[0] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0), _mm256_set1_epi8(-128)); - qx[1] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1), _mm256_set1_epi8(-128)); - qx[2] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2), _mm256_set1_epi8(-128)); - qx[3] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3), _mm256_set1_epi8(-128)); -#else +#ifndef HAVE_FANCY_SIMD auto s0 = _mm256_sign_epi8(qx[0], qx[0]); auto s1 = _mm256_sign_epi8(qx[1], qx[1]); auto s2 = _mm256_sign_epi8(qx[2], qx[2]); @@ -7924,6 +7991,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[5] = mul_mat_q8_k_r8_q8_k<6>; mm.funcs[6] = mul_mat_q8_k_r8_q8_k<7>; mm.funcs[7] = mul_mat_q8_k_r8_q8_k<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_q8_k_r8_q8_k<16>; +#endif expected_typeB = GGML_TYPE_Q8_KR8; break; case GGML_TYPE_IQ4_K_R4: @@ -7989,6 +8059,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[5] = mul_mat_q4_0_r4_q8_1<6>; mm.funcs[6] = mul_mat_q4_0_r4_q8_1<7>; mm.funcs[7] = mul_mat_q4_0_r4_q8_1<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_q4_0_r4_q8_1<16>; +#endif expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0_R4: @@ -12067,6 +12140,42 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, } } +template +void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + Dequantizer deq(vx, bx); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[16]; + float d8[4*nrc_y]; + float32x4_t acc[2*nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 8) { + deq.new_row(ix); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); + } + for (int k = 0; k < 4; ++k) { + auto scales = deq.prepare(ib4, k, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); + auto sumi1 = interleaved_dotq(qx+0, y); + auto sumi2 = interleaved_dotq(qx+8, y); + auto dy = vdupq_n_f32(d8[4*iy+k]); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1)); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2)); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix+0, iy, deq.result(acc[2*iy+0])); + info.store(ix+4, iy, deq.result(acc[2*iy+1])); + acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f); + } + } +} + struct IQ4_NL_R4_Dequantizer { IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {} inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); } @@ -12116,6 +12225,35 @@ struct Q4_0_R4_Dequantizer { const float32x4_t norm = vdupq_n_f32(1.f/16); }; +struct Q4_0_R8_Dequantizer { + Q4_0_R8_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} + inline void new_row(int ix) { iq4 = (const block_iq4_nl_r8 *)(cx + ix*bx); } + inline float32x4x2_t prepare(int ib4, int k, int8x16_t * qx) const { + auto scales16 = vld1q_f16((const float16_t *)iq4[4*ib4+k].d); + float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) }; + for (int j = 0; j < 4; ++j) { + auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j); + //bits.val[0] = veorq_u8(m88, bits.val[0]); + //bits.val[1] = veorq_u8(m88, bits.val[1]); + qx[2*j+0] = vshlq_n_u8(bits.val[0], 4); + qx[2*j+1] = vandq_u8(bits.val[0], m4); + qx[2*j+8] = vshlq_n_u8(bits.val[1], 4); + qx[2*j+9] = vandq_u8(bits.val[1], m4); + } + return scales; + } + inline float32x4_t result(float32x4_t acc) const { + return vmulq_f32(norm, acc); + } + + const char * cx; + const size_t bx; + const block_iq4_nl_r8 * iq4; + const uint8x16_t m4 = vdupq_n_u8(0xf0); + const uint8x16_t m88 = vdupq_n_u8(0x88); + const float32x4_t norm = vdupq_n_f32(1.f/16); +}; + struct Q5_0_R4_Dequantizer { Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); } @@ -12471,7 +12609,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { expected_Btype = GGML_TYPE_Q8_K; break; case GGML_TYPE_Q4_0_R4: - SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer); + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer); expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q5_0_R4: @@ -12894,6 +13032,12 @@ struct HelperQ80R4 : public BaseHelper { m1 = _mm256_unpackhi_epi64(t0, t1); m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); +#ifdef HAVE_FANCY_SIMD + m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128)); + m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128)); + m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128)); + m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128)); +#endif _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); @@ -12910,6 +13054,12 @@ struct HelperQ80R4 : public BaseHelper { m1 = _mm256_unpackhi_epi64(t0, t1); m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); +#ifdef HAVE_FANCY_SIMD + m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128)); + m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128)); + m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128)); + m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128)); +#endif _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 59a36c5c..c1e7771f 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -43,6 +43,15 @@ constexpr int popcount(uint32_t x) { return __builtin_popcount(x); } constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); } #endif +#if defined __x86_64__ +#if defined HAVE_FANCY_SIMD + #undef HAVE_FANCY_SIMD +#endif +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + #define HAVE_FANCY_SIMD +#endif +#endif + namespace { inline int nearest_int(float fval) { @@ -3541,7 +3550,7 @@ void quantize_row_iq4_nl_r4(const float * x, void * y, int64_t k) { quantize_iq4_nl_r4(x, y, 4, k/4, nullptr); } -static void repack_iq4_nl(int nrows, int n_per_row, const block_iq4_nl * x, block_iq4_nl_r4 * y) { +static void repack_iq4_nl(int nrows, int n_per_row, const block_iq4_nl * x, block_iq4_nl_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK4_NL == 0); int nblock = n_per_row/QK4_NL; @@ -3569,7 +3578,7 @@ size_t quantize_iq4_nl_r4(const float * src, void * dst, int64_t nrows, int64_t char * qrow = (char *)dst; for (int row = 0; row < nrows; row += 4) { quantize_iq4_nl(src, qtmp.data(), 4, n_per_row, imatrix); - repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_r4 *)qrow); + repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_r4 *)qrow, false); src += 4*n_per_row; qrow += 4*row_size_nl; } @@ -3615,77 +3624,89 @@ void vec_dot_iq4_nl_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t // // ========================================= q4_0_r4 // -void quantize_row_q4_0_r4_ref(const float * x, block_iq4_nl_r4 * y, int64_t k) { - // we assume we are called with 4 rows - quantize_q4_0_r4(x, (void *)y, 4, k/4, nullptr); +void quantize_row_q4_0_r4_ref(const float * x, block_iq4_nl_r8 * y, int64_t k) { + // we assume we are called with 8 rows + quantize_q4_0_r4(x, (void *)y, 8, k/8, nullptr); } void quantize_row_q4_0_r4(const float * x, void * y, int64_t k) { - // we assume we are called with 4 rows - quantize_q4_0_r4(x, y, 4, k/4, nullptr); + // we assume we are called with 8 rows + quantize_q4_0_r4(x, y, 8, k/8, nullptr); } -static void repack_q4_0(int nrows, int n_per_row, const block_q4_0 * x, block_iq4_nl_r4 * y) { - GGML_ASSERT(nrows%4 == 0); - GGML_ASSERT(n_per_row%QK4_NL == 0); - int nblock = n_per_row/QK4_NL; - const block_q4_0 * x4[4]; - for (int row = 0; row < nrows; row += 4) { - for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; +static void repack_q4_0(int nrows, int n_per_row, const block_q4_0 * x, block_iq4_nl_r8 * y, [[maybe_unused]] bool online) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%QK4_0 == 0); + int nblock = n_per_row/QK4_0; + const block_q4_0 * x8[8]; + for (int row = 0; row < nrows; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k; for (int ib = 0; ib < nblock; ++ib) { - //for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; - //for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - // y[ib].qs[4*k+i+ 0] = (x4[k][ib].qs[i+0] & 0xf) | ((x4[k][ib].qs[i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row - // y[ib].qs[4*k+i+16] = (x4[k][ib].qs[i+0] >> 4) | ((x4[k][ib].qs[i+ 8] & 0xf0)); // 16...19 + 24...27 from each row - // y[ib].qs[4*k+i+32] = (x4[k][ib].qs[i+4] & 0xf) | ((x4[k][ib].qs[i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row - // y[ib].qs[4*k+i+48] = (x4[k][ib].qs[i+4] >> 4) | ((x4[k][ib].qs[i+12] & 0xf0)); // 20...23 + 28...31 from each row - //} - for (int k = 0; k < 4; ++k) { - y[ib].d[k] = x4[k][ib].d; + for (int k = 0; k < 8; ++k) { + y[ib].d[k] = x8[k][ib].d; for (int l = 0; l < 4; ++l) { - // l = 0 -> 0, 8 with shift 0 -> 4*(l/2), 4*(l/2)+8 with shift 4*(l%2) - // l = 1 -> 0, 8 with shift 4 - // l = 2 -> 4, 12 with shift 0 - // l = 3 -> 4, 12 with shift 4 for (int i = 0; i < 4; ++i) { - y[ib].qs[4*k+i+16*l] = ((x4[k][ib].qs[i+4*(l/2)] >> 4*(l%2)) & 0xf) | (((x4[k][ib].qs[i+4*(l/2)+8] >> 4*(l%2)) & 0xf) << 4); + y[ib].qs[32*l+4*k+i] = x8[k][ib].qs[4*l + i]; } } } +#ifdef __ARM_NEON + if (online) { + for (int l = 0; l < 8; ++l) { + auto v = vld1q_u8(y[ib].qs + 16*l); + vst1q_u8(y[ib].qs + 16*l, veorq_u8(v, vdupq_n_u8(0x88))); + } + } +#endif } - x += 4*nblock; + x += 8*nblock; y += nblock; } } +#ifdef __ARM_NEON +static void modify_q4_0_r4(int64_t k, char * cy) { + auto y = (block_iq4_nl_r8 *)cy; + int nb = k/(32*8); + for (int ib = 0; ib < nb; ++ib) { + auto v1 = vld1q_u8_x4(y[ib].qs); + auto v2 = vld1q_u8_x4(y[ib].qs+64); + for (int j = 0; j < 4; ++j) { + v1.val[j] = veorq_u8(v1.val[j], vdupq_n_u8(0x88)); + v2.val[j] = veorq_u8(v2.val[j], vdupq_n_u8(0x88)); + } + vst1q_u8_x4(y[ib].qs+ 0, v1); + vst1q_u8_x4(y[ib].qs+64, v2); + } +} +#endif size_t quantize_q4_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { - GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(nrows%8 == 0); auto row_size_nl = ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row); - std::vector qtmp(4*row_size_nl); + std::vector qtmp(8*row_size_nl); char * qrow = (char *)dst; - for (int row = 0; row < nrows; row += 4) { - quantize_q4_0(src, qtmp.data(), 4, n_per_row, imatrix); - repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_r4 *)qrow); - src += 4*n_per_row; - qrow += 4*row_size_nl; + for (int row = 0; row < nrows; row += 8) { + quantize_q4_0(src, qtmp.data(), 8, n_per_row, imatrix); + repack_q4_0(8, n_per_row, (const block_q4_0 *)qtmp.data(), (block_iq4_nl_r8 *)qrow, false); + src += 8*n_per_row; + qrow += 8*row_size_nl; } return nrows*row_size_nl; } -void dequantize_row_q4_0_r4(const block_iq4_nl_r4 * x, float * y, int64_t k) { - // we assume we are called with 4 rows - int n_per_row = k/4; +void dequantize_row_q4_0_r4(const block_iq4_nl_r8 * x, float * y, int64_t k) { + // we assume we are called with 8 rows + int n_per_row = k/8; int nb = n_per_row/QK4_0; - float * yk[4]; - for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + float * yk[8]; + for (int k = 0; k < 8; ++k) yk[k] = y + k*n_per_row; for (int ib = 0; ib < nb; ++ib) { - for (int k = 0; k < 4; ++k) { + for (int k = 0; k < 8; ++k) { float scale = GGML_FP16_TO_FP32(x[ib].d[k]); for (int l = 0; l < 4; ++l) { - int ll = 16*(l%2) + 4*(l/2); for (int i = 0; i < 4; ++i) { - yk[k][QK4_0*ib+i+ll+0] = scale * ((x[ib].qs[4*k+i+16*l] & 0xf) - 8); - yk[k][QK4_0*ib+i+ll+8] = scale * ((x[ib].qs[4*k+i+16*l] >> 4) - 8); + yk[k][QK4_0*ib+4*l+i+ 0] = scale * ((x[ib].qs[32*l+4*k+i] & 0xf) - 8); + yk[k][QK4_0*ib+4*l+i+16] = scale * ((x[ib].qs[32*l+4*k+i] >> 4) - 8); } } } @@ -3719,7 +3740,7 @@ void quantize_row_q8_0_r4(const float * x, void * y, int64_t k) { quantize_q8_0_r4(x, y, 8, k/8, nullptr); } -static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_r8 * y) { +static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_r8 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK8_0 == 0); int nblock = n_per_row/QK8_0; @@ -3734,12 +3755,33 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8 y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; } } +#ifdef HAVE_FANCY_SIMD + if (online) { + for (int l = 0; l < 4; ++l) { + auto v = _mm512_add_epi8(_mm512_loadu_si512((const __m512i *)y[ib].qs + l), _mm512_set1_epi8(127)); + _mm512_storeu_si512((__m512i *)y[ib].qs + l, v); + } + } +#endif } x += 8*nblock; y += nblock; } } +#ifdef HAVE_FANCY_SIMD +static void modify_q8_0_r4(int64_t k, char * cy) { + auto y = (block_iq4_nl_r8 *)cy; + int nb = k/(32*8); + for (int ib = 0; ib < nb; ++ib) { + for (int l = 0; l < 4; ++l) { + auto v = _mm512_add_epi8(_mm512_loadu_si512((const __m512i *)y[ib].qs + l), _mm512_set1_epi8(127)); + _mm512_storeu_si512((__m512i *)y[ib].qs + l, v); + } + } +} +#endif + size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { GGML_ASSERT(nrows%8 == 0); auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); @@ -3747,7 +3789,7 @@ size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_ char * qrow = (char *)dst; for (int row = 0; row < nrows; row += 8) { quantize_q8_0(src, qtmp.data(), 8, n_per_row, imatrix); - repack_q8_0(8, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_r8 *)qrow); + repack_q8_0(8, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_r8 *)qrow, false); src += 8*n_per_row; qrow += 8*row_size_0; } @@ -3810,7 +3852,7 @@ static inline void convert_q5_0(const block_q5_0& x, uint8_t * L) { } } -static void repack_q5_0(int nrows, int n_per_row, const block_q5_0 * x, block_q5_0_r4 * y) { +static void repack_q5_0(int nrows, int n_per_row, const block_q5_0 * x, block_q5_0_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK5_0 == 0); int nblock = n_per_row/QK5_0; @@ -3844,7 +3886,7 @@ size_t quantize_q5_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_ char * qrow = (char *)dst; for (int row = 0; row < nrows; row += 4) { quantize_q5_0(src, qtmp.data(), 4, n_per_row, imatrix); - repack_q5_0(4, n_per_row, (const block_q5_0 *)qtmp.data(), (block_q5_0_r4 *)qrow); + repack_q5_0(4, n_per_row, (const block_q5_0 *)qtmp.data(), (block_q5_0_r4 *)qrow, false); src += 4*n_per_row; qrow += 4*row_size_0; } @@ -3907,7 +3949,7 @@ static inline void convert_q6_0(const block_q6_0& x, uint8_t * L) { } } -static void repack_q6_0(int nrows, int n_per_row, const block_q6_0 * x, block_q6_0_r4 * y) { +static void repack_q6_0(int nrows, int n_per_row, const block_q6_0 * x, block_q6_0_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK5_0 == 0); int nblock = n_per_row/QK6_0; @@ -3941,7 +3983,7 @@ size_t quantize_q6_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_ char * qrow = (char *)dst; for (int row = 0; row < nrows; row += 4) { quantize_q6_0(src, qtmp.data(), 4, n_per_row, imatrix); - repack_q6_0(4, n_per_row, (const block_q6_0 *)qtmp.data(), (block_q6_0_r4 *)qrow); + repack_q6_0(4, n_per_row, (const block_q6_0 *)qtmp.data(), (block_q6_0_r4 *)qrow, false); src += 4*n_per_row; qrow += 4*row_size_0; } @@ -3994,7 +4036,7 @@ void quantize_row_iq4_xs_r4(const float * x, void * y, int64_t k) { quantize_iq4_xs_r4(x, y, 8, k/8, nullptr); } -static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y) { +static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4034,7 +4076,7 @@ size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t std::vector qtmp(8*row_size); for (int row = 0; row < nrows; row += 8) { quantize_iq4_xs(src, (void *)qtmp.data(), 8, n_per_row, imatrix); - repack_iq4_xs(8, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur); + repack_iq4_xs(8, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur, false); qcur += 8*row_size; src += 8*n_per_row; } @@ -4086,7 +4128,7 @@ void quantize_row_iq4_ks_r4(const float * x, void * y, int64_t k) { quantize_iq4_ks_r4(x, y, 4, k/4, nullptr); } -static void repack_iq4_ks(int nrows, int n_per_row, const block_iq4_ks * x, block_iq4_ks_r4 * y) { +static void repack_iq4_ks(int nrows, int n_per_row, const block_iq4_ks * x, block_iq4_ks_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); auto row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row); @@ -4128,7 +4170,7 @@ size_t quantize_iq4_ks_r4(const float * src, void * dst, int64_t nrows, int64_t std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq4_ks(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq4_ks(4, n_per_row, (const block_iq4_ks *)qtmp.data(), (block_iq4_ks_r4 *)qcur); + repack_iq4_ks(4, n_per_row, (const block_iq4_ks *)qtmp.data(), (block_iq4_ks_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -4187,7 +4229,7 @@ void quantize_row_iq2_bn_r4(const float * x, void * y, int64_t k) { } namespace { -void repack_iq2_bn(int nrows, int n_per_row, const char * x, char * y) { +void repack_iq2_bn(int nrows, int n_per_row, const char * x, char * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_IQ1BN == 0); int nblock = n_per_row/QK_IQ1BN; @@ -4256,7 +4298,7 @@ size_t quantize_iq2_bn_r4(const float * src, void * dst, int64_t nrows, int64_t std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq2_bn(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq2_bn(4, n_per_row, qtmp.data(), qcur); + repack_iq2_bn(4, n_per_row, qtmp.data(), qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -4330,7 +4372,7 @@ inline void convert_q4_k(const block_q4_K& x, uint8_t * L, uint8_t * Ld, uint8_t } } -static void repack_q4_k(int nrows, int n_per_row, const block_q4_K * x, block_q4_k_r4 * y) { +static void repack_q4_k(int nrows, int n_per_row, const block_q4_K * x, block_q4_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4371,7 +4413,7 @@ size_t quantize_q4_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_ std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_q4_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_q4_k(4, n_per_row, (const block_q4_K *)qtmp.data(), (block_q4_k_r4 *)qcur); + repack_q4_k(4, n_per_row, (const block_q4_K *)qtmp.data(), (block_q4_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -4448,7 +4490,7 @@ inline void convert_q6_k(const block_q6_K& x, uint8_t * L) { } } -static void repack_q6_k(int nrows, int n_per_row, const block_q6_K * x, block_q6_k_r4 * y) { +static void repack_q6_k(int nrows, int n_per_row, const block_q6_K * x, block_q6_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4487,7 +4529,7 @@ size_t quantize_q6_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_ std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_q6_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_q6_k(4, n_per_row, (const block_q6_K *)qtmp.data(), (block_q6_k_r4 *)qcur); + repack_q6_k(4, n_per_row, (const block_q6_K *)qtmp.data(), (block_q6_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -4562,7 +4604,7 @@ inline void convert_q5_k(const block_q5_K& x, uint8_t * L, uint8_t * Ld, uint8_t } } -static void repack_q5_k(int nrows, int n_per_row, const block_q5_K * x, block_q5_k_r4 * y) { +static void repack_q5_k(int nrows, int n_per_row, const block_q5_K * x, block_q5_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4605,7 +4647,7 @@ size_t quantize_q5_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_ std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_q5_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_q5_k(4, n_per_row, (const block_q5_K *)qtmp.data(), (block_q5_k_r4 *)qcur); + repack_q5_k(4, n_per_row, (const block_q5_K *)qtmp.data(), (block_q5_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -4698,7 +4740,7 @@ inline void convert_q3_k(const block_q3_K& x, uint8_t * L, uint8_t * Ld) { } } -static void repack_q3_k(int nrows, int n_per_row, const block_q3_K * x, block_q3_k_r4 * y) { +static void repack_q3_k(int nrows, int n_per_row, const block_q3_K * x, block_q3_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4741,7 +4783,7 @@ size_t quantize_q3_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_ std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_q3_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_q3_k(4, n_per_row, (const block_q3_K *)qtmp.data(), (block_q3_k_r4 *)qcur); + repack_q3_k(4, n_per_row, (const block_q3_K *)qtmp.data(), (block_q3_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -4820,7 +4862,7 @@ inline void convert_q2_k(const block_q2_K& x, uint8_t * L) { } } -static void repack_q2_k(int nrows, int n_per_row, const block_q2_K * x, block_q2_k_r4 * y) { +static void repack_q2_k(int nrows, int n_per_row, const block_q2_K * x, block_q2_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4857,7 +4899,7 @@ size_t quantize_q2_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_ std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_q2_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_q2_k(4, n_per_row, (const block_q2_K *)qtmp.data(), (block_q2_k_r4 *)qcur); + repack_q2_k(4, n_per_row, (const block_q2_K *)qtmp.data(), (block_q2_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -4919,7 +4961,7 @@ void quantize_row_iq4_k_r4(const float * x, void * y, int64_t k) { quantize_iq4_k_r4(x, y, 4, k/4, nullptr); } -static void repack_iq4_k(int nrows, int n_per_row, const block_iq4_k * x, block_iq4_k_r4 * y) { +static void repack_iq4_k(int nrows, int n_per_row, const block_iq4_k * x, block_iq4_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4972,7 +5014,7 @@ size_t quantize_iq4_k_r4(const float * src, void * dst, int64_t nrows, int64_t n std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq4_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq4_k(4, n_per_row, (const block_iq4_k *)qtmp.data(), (block_iq4_k_r4 *)qcur); + repack_iq4_k(4, n_per_row, (const block_iq4_k *)qtmp.data(), (block_iq4_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5053,7 +5095,7 @@ inline void convert_iq5_k(const block_iq5_k& x, uint8_t * L) { } } -static void repack_iq5_k(int nrows, int n_per_row, const block_iq5_k * x, block_iq5_k_r4 * y) { +static void repack_iq5_k(int nrows, int n_per_row, const block_iq5_k * x, block_iq5_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5108,7 +5150,7 @@ size_t quantize_iq5_k_r4(const float * src, void * dst, int64_t nrows, int64_t n std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq5_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq5_k(4, n_per_row, (const block_iq5_k *)qtmp.data(), (block_iq5_k_r4 *)qcur); + repack_iq5_k(4, n_per_row, (const block_iq5_k *)qtmp.data(), (block_iq5_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5169,7 +5211,7 @@ void quantize_row_q8_k_r8(const float * x, void * y, int64_t k) { quantize_q8_k_r8(x, y, 8, k/8, nullptr); } -static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8_k_r8 * y) { +static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8_k_r8 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5183,11 +5225,31 @@ static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8 for (int i = 0; i < 4; ++i) y[ibl].qs[32*ib + 4*k + i] = x8[k][ibl].qs[4*ib+i]; } } +#ifdef HAVE_FANCY_SIMD + if (online) { + for (int l = 0; l < 32; ++l) { + auto v = _mm512_xor_si512(_mm512_loadu_si512((const __m512i *)y[ibl].qs + l), _mm512_set1_epi8(-128)); + _mm512_storeu_si512((__m512i *)y[ibl].qs + l, v); + } + } +#endif } x += 8*nblock; y += nblock; } } +#ifdef HAVE_FANCY_SIMD +static void modify_q8_k_r8(int64_t k, char * cy) { + auto y = (block_q8_k_r8 *)cy; + int nb = k/(256*8); + for (int ib = 0; ib < nb; ++ib) { + for (int l = 0; l < 32; ++l) { + auto v = _mm512_xor_si512(_mm512_loadu_si512((const __m512i *)y[ib].qs + l), _mm512_set1_epi8(-128)); + _mm512_storeu_si512((__m512i *)y[ib].qs + l, v); + } + } +} +#endif size_t quantize_q8_k_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) { GGML_ASSERT(nrows%8 == 0); @@ -5198,7 +5260,7 @@ size_t quantize_q8_k_r8(const float * src, void * dst, int64_t nrows, int64_t n_ std::vector qtmp(8*row_size_0); for (int row = 0; row < nrows; row += 8) { quantize_row_q8_K32(src, (void *)qtmp.data(), 8*n_per_row); - repack_q8_k(8, n_per_row, (const block_q8_K *)qtmp.data(), (block_q8_k_r8 *)qcur); + repack_q8_k(8, n_per_row, (const block_q8_K *)qtmp.data(), (block_q8_k_r8 *)qcur, false); qcur += 8*row_size_1; src += 8*n_per_row; } @@ -5247,7 +5309,7 @@ inline ggml_bf16_t to_bf16(const float& x) { inline ggml_bf16_t to_bf16(const ggml_half& x) { return to_bf16(GGML_FP16_TO_FP32(x)); } inline ggml_bf16_t to_bf16(const ggml_bf16_t& x) { return x; } template -void repack_bf16(int nrows, int n_per_row, const T * x, ggml_bf16_t * y) { +void repack_bf16(int nrows, int n_per_row, const T * x, ggml_bf16_t * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%16 == 0); GGML_ASSERT(n_per_row%2 == 0); for (int row = 0; row < nrows; row += 16) { @@ -5265,11 +5327,11 @@ void repack_bf16(int nrows, int n_per_row, const T * x, ggml_bf16_t * y) { } void repack_f32_bf16_r16(const void * src, void * dst, int64_t nrows, int64_t n_per_row) { - repack_bf16(nrows, n_per_row, (const float *)src, (ggml_bf16_t *)dst); + repack_bf16(nrows, n_per_row, (const float *)src, (ggml_bf16_t *)dst, false); } void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row) { - repack_bf16(nrows, n_per_row, (const ggml_bf16_t *)src, (ggml_bf16_t *)dst); + repack_bf16(nrows, n_per_row, (const ggml_bf16_t *)src, (ggml_bf16_t *)dst, false); } // @@ -5301,7 +5363,7 @@ inline void convert_iq3_k(const block_iq3_k& x, uint8_t * L) { } } -static void repack_iq3_k(int nrows, int n_per_row, const block_iq3_k * x, block_iq3_k_r4 * y) { +static void repack_iq3_k(int nrows, int n_per_row, const block_iq3_k * x, block_iq3_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5355,7 +5417,7 @@ size_t quantize_iq3_k_r4(const float * src, void * dst, int64_t nrows, int64_t n std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq3_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq3_k(4, n_per_row, (const block_iq3_k *)qtmp.data(), (block_iq3_k_r4 *)qcur); + repack_iq3_k(4, n_per_row, (const block_iq3_k *)qtmp.data(), (block_iq3_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5435,7 +5497,7 @@ inline void convert_iq2_k(const block_iq2_k& x, uint8_t * L) { } } -static void repack_iq2_k(int nrows, int n_per_row, const block_iq2_k * x, block_iq2_k_r4 * y) { +static void repack_iq2_k(int nrows, int n_per_row, const block_iq2_k * x, block_iq2_k_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5480,7 +5542,7 @@ size_t quantize_iq2_k_r4(const float * src, void * dst, int64_t nrows, int64_t n std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq2_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq2_k(4, n_per_row, (const block_iq2_k *)qtmp.data(), (block_iq2_k_r4 *)qcur); + repack_iq2_k(4, n_per_row, (const block_iq2_k *)qtmp.data(), (block_iq2_k_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5531,15 +5593,6 @@ void vec_dot_iq2_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } -namespace { -struct Repack { - using repack_func = void (*) (int nrows, int n_per_row, const char * src, char * dst); - ggml_type new_type; - int num_rows; - repack_func repack; -}; -} - namespace { inline uint8_t scrambled_sign(uint8_t s) { static const uint8_t k_table[128] = { @@ -5568,7 +5621,7 @@ void quantize_row_iq2_xxs_r4(const float * x, void * y, int64_t k) { quantize_iq2_xxs_r4(x, y, 4, k/4, nullptr); } -static void repack_iq2_xxs(int nrows, int n_per_row, const block_iq2_xxs * x, block_iq2_xxs_r4 * y) { +static void repack_iq2_xxs(int nrows, int n_per_row, const block_iq2_xxs * x, block_iq2_xxs_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5609,7 +5662,7 @@ size_t quantize_iq2_xxs_r4(const float * src, void * dst, int64_t nrows, int64_t std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq2_xxs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq2_xxs(4, n_per_row, (const block_iq2_xxs *)qtmp.data(), (block_iq2_xxs_r4 *)qcur); + repack_iq2_xxs(4, n_per_row, (const block_iq2_xxs *)qtmp.data(), (block_iq2_xxs_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5668,7 +5721,7 @@ void quantize_row_iq2_xs_r4(const float * x, void * y, int64_t k) { quantize_iq2_xs_r4(x, y, 4, k/4, nullptr); } -static void repack_iq2_xs(int nrows, int n_per_row, const block_iq2_xs * x, block_iq2_xs_r4 * y) { +static void repack_iq2_xs(int nrows, int n_per_row, const block_iq2_xs * x, block_iq2_xs_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5701,7 +5754,7 @@ size_t quantize_iq2_xs_r4(const float * src, void * dst, int64_t nrows, int64_t std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq2_xs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq2_xs(4, n_per_row, (const block_iq2_xs *)qtmp.data(), (block_iq2_xs_r4 *)qcur); + repack_iq2_xs(4, n_per_row, (const block_iq2_xs *)qtmp.data(), (block_iq2_xs_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5755,7 +5808,7 @@ void quantize_row_iq2_s_r4(const float * x, void * y, int64_t k) { quantize_iq2_s_r4(x, y, 4, k/4, nullptr); } -static void repack_iq2_s(int nrows, int n_per_row, const block_iq2_s * x, block_iq2_s_r4 * y) { +static void repack_iq2_s(int nrows, int n_per_row, const block_iq2_s * x, block_iq2_s_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5789,7 +5842,7 @@ size_t quantize_iq2_s_r4(const float * src, void * dst, int64_t nrows, int64_t n std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq2_s(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq2_s(4, n_per_row, (const block_iq2_s *)qtmp.data(), (block_iq2_s_r4 *)qcur); + repack_iq2_s(4, n_per_row, (const block_iq2_s *)qtmp.data(), (block_iq2_s_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5845,7 +5898,7 @@ void quantize_row_iq3_xxs_r4(const float * x, void * y, int64_t k) { namespace { } -static void repack_iq3_xxs(int nrows, int n_per_row, const block_iq3_xxs * x, block_iq3_xxs_r4 * y) { +static void repack_iq3_xxs(int nrows, int n_per_row, const block_iq3_xxs * x, block_iq3_xxs_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5886,7 +5939,7 @@ size_t quantize_iq3_xxs_r4(const float * src, void * dst, int64_t nrows, int64_t std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq3_xxs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq3_xxs(4, n_per_row, (const block_iq3_xxs *)qtmp.data(), (block_iq3_xxs_r4 *)qcur); + repack_iq3_xxs(4, n_per_row, (const block_iq3_xxs *)qtmp.data(), (block_iq3_xxs_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -5945,7 +5998,7 @@ void quantize_row_iq3_s_r4(const float * x, void * y, int64_t k) { quantize_iq3_s_r4(x, y, 4, k/4, nullptr); } -static void repack_iq3_s(int nrows, int n_per_row, const block_iq3_s * x, block_iq3_s_r4 * y) { +static void repack_iq3_s(int nrows, int n_per_row, const block_iq3_s * x, block_iq3_s_r4 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%4 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -5991,7 +6044,7 @@ size_t quantize_iq3_s_r4(const float * src, void * dst, int64_t nrows, int64_t n std::vector qtmp(4*row_size); for (int row = 0; row < nrows; row += 4) { quantize_iq3_s(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq3_s(4, n_per_row, (const block_iq3_s *)qtmp.data(), (block_iq3_s_r4 *)qcur); + repack_iq3_s(4, n_per_row, (const block_iq3_s *)qtmp.data(), (block_iq3_s_r4 *)qcur, false); qcur += 4*row_size; src += 4*n_per_row; } @@ -6036,6 +6089,56 @@ void vec_dot_iq3_s_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t //================================================ +namespace { +struct Repack { + using repack_func = void (*) (int nrows, int n_per_row, const char * src, char * dst, bool online); + ggml_type new_type; + int num_rows; + repack_func repack; +}; +struct Modify { + using modify_func_t = void (*)(int64_t k, char * src_dst); + modify_func_t mod_func; + int nrows; +}; +} + +bool iqk_modify_tensor(struct ggml_tensor * tensor) { + static const std::unordered_map k_mod_map = { +#ifdef __ARM_NEON + { GGML_TYPE_Q4_0_R4, {modify_q4_0_r4, 8} }, +#endif +#ifdef HAVE_FANCY_SIMD + { GGML_TYPE_Q8_0_R4, {modify_q8_0_r4, 8} }, + { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, +#endif + }; + auto it = k_mod_map.find(tensor->type); + if (it == k_mod_map.end()) return false; + + auto& m = it->second; + int nrows = ggml_nrows(tensor); + int nchunks = nrows/m.nrows; + int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int nthread = std::min(nchunks, max_thread); + auto row_size = ggml_row_size(tensor->type, tensor->ne[0]); + std::atomic counter(0); + auto compute = [&counter, &m, tensor, row_size, nchunks] () { + int64_t n_per_call = m.nrows*tensor->ne[0]; + while (true) { + int row = counter.fetch_add(1); + if (row >= nchunks) break; + m.mod_func(n_per_call, (char *)tensor->data + row_size*row*m.nrows); + } + }; + std::vector workers(nthread-1); + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + + return true; +} + void iqk_repack_tensor(struct ggml_tensor * tensor) { constexpr int kChunk = 8; if (!tensor) return; @@ -6061,7 +6164,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q4_K, { GGML_TYPE_Q4_K_R4, 4, (Repack::repack_func)repack_q4_k} }, { GGML_TYPE_Q5_K, { GGML_TYPE_Q5_K_R4, 4, (Repack::repack_func)repack_q5_k} }, { GGML_TYPE_Q6_K, { GGML_TYPE_Q6_K_R4, 4, (Repack::repack_func)repack_q6_k} }, - { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 4, (Repack::repack_func)repack_q4_0} }, + { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 8, (Repack::repack_func)repack_q4_0} }, { GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} }, { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 8, (Repack::repack_func)repack_q8_0} }, @@ -6099,7 +6202,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { int last_row = std::min(first_row + chunkSize*r.num_rows, nrows); for (int row = first_row; row < last_row; row += r.num_rows) { std::memcpy(qtmp.data(), data + row*row_size, r.num_rows*row_size); - r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size); + r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size, true); } } }; diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 64860b4d..1a991787 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -67,10 +67,10 @@ size_t quantize_iq4_nl_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT void dequantize_row_iq4_nl_r4(const block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_nl_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void quantize_row_q4_0_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_r8 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_q4_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -void dequantize_row_q4_0_r4(const block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_0_r4(const block_iq4_nl_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q4_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_r8 * GGML_RESTRICT y, int64_t k); @@ -218,6 +218,7 @@ void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT d void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); void iqk_repack_tensor(struct ggml_tensor * tensor); +bool iqk_modify_tensor(struct ggml_tensor * tensor); // So we can re-pack Microsoft's BitNet I2_S quants void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/src/llama.cpp b/src/llama.cpp index 836fd97a..b6a4a06d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7650,7 +7650,7 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_exp_probs_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert} ); + layer.ffn_exp_probs_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 1); GGML_ASSERT(n_expert > 0); GGML_ASSERT(n_expert_used > 0); @@ -8014,6 +8014,16 @@ static bool llm_load_tensors( } } + if (!ml.use_mmap) { + int n_modified = 0; + for (auto& it : model.tensors_by_name) { + if (ggml_backend_buffer_is_host(it.second->buffer)) { + if (iqk_modify_tensor(it.second)) ++n_modified; + } + } + if (n_modified > 0) printf("============ Modified %d tensors\n", n_modified); + } + if (!ml.use_mmap && ml.repack_tensors) { int n_repacked = 0; for (auto& it : model.tensors_by_name) { @@ -16910,8 +16920,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s else chunk_size_multiplier = 8; } else if (new_type == GGML_TYPE_Q4_0_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; - else chunk_size_multiplier = 4; + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q4_0; + else chunk_size_multiplier = 8; } else if (new_type == GGML_TYPE_Q5_0_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0; From 4a73c250023a74bb1665875bbced7f1a3857b7f6 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 29 Jan 2025 14:05:41 +0200 Subject: [PATCH 03/14] Various (#181) * Adding gp option to llama-bench Similar to pg, but it only looks at TG speed with a given prompt length. * Make q8_0_r4 work with tensor row sizes that are not a multiple of 128 They still need to be divisible by 32. * Make q8_0_r4 work with tensor row sizes that are not a multiple of 128 .. on NEON * Make q8_0_r4 work with tensor row sizes that are not a multiple of 128 .., on AVX2 * Make q4_0_r4 work with tensor row sizes that are not a multiple of 128 .., on AVX2 * Make q4_0_r4 work with tensor row sizes that are not a multiple of 128 ... on NEON * Make q4_0_r4 work with tensor row sizes that are not a multiple of 128 ... on Zen4. Also fix q8_0 K-cache for head sizes that are not multiple of 128. --------- Co-authored-by: Iwan Kawrakow --- examples/llama-bench/llama-bench.cpp | 106 +++++- ggml/src/iqk/iqk_mul_mat.cpp | 494 ++++++++++++++++++--------- 2 files changed, 434 insertions(+), 166 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 75fe40d1..b46bd855 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -220,6 +220,7 @@ struct cmd_params { std::vector n_prompt; std::vector n_gen; std::vector> n_pg; + std::vector> n_gp; std::vector n_batch; std::vector n_ubatch; std::vector type_k; @@ -248,6 +249,7 @@ static const cmd_params cmd_params_defaults = { /* n_prompt */ {512}, /* n_gen */ {128}, /* n_pg */ {}, + /* n_gp */ {}, /* n_batch */ {2048}, /* n_ubatch */ {512}, /* type_k */ {GGML_TYPE_F16}, @@ -280,6 +282,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); printf(" -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); printf(" -pg (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str()); + printf(" -gp (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_gp, pair_str), ",").c_str()); printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); printf(" -ub, --ubatch-size (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str()); printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); @@ -393,6 +396,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.n_pg.push_back({std::stoi(p[0]), std::stoi(p[1])}); + } else if (arg == "-gp") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], ','); + if (p.size() != 2) { + invalid_param = true; + break; + } + params.n_gp.push_back({ std::stoi(p[0]), std::stoi(p[1]) }); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -596,6 +610,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; } if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; } if (params.n_pg.empty()) { params.n_pg = cmd_params_defaults.n_pg; } + if (params.n_gp.empty()) { params.n_gp = cmd_params_defaults.n_gp; } if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; } if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; } if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; } @@ -614,7 +629,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { return params; } +enum test_kind_type { + // measure mean prompt processing rate without token generation + TEST_KIND_PP, + // measure mean token generation rate without prompt processing + TEST_KIND_TG, + // measure mean prompt processing and token generation rate + TEST_KIND_PG, + // measure mean token generation rate after processing prompt of given length + TEST_KIND_GP, +}; + struct cmd_params_instance { + test_kind_type test_kind; std::string model; int n_prompt; int n_gen; @@ -701,6 +728,7 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { + /* .test_kind = */ TEST_KIND_PP, /* .model = */ m, /* .n_prompt = */ n_prompt, /* .n_gen = */ 0, @@ -728,6 +756,7 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { + /* .test_kind = */ TEST_KIND_PP, /* .model = */ m, /* .n_prompt = */ 0, /* .n_gen = */ n_gen, @@ -755,6 +784,7 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { + /* .test_kind = */ TEST_KIND_PP, /* .model = */ m, /* .n_prompt = */ n_pg.first, /* .n_gen = */ n_pg.second, @@ -776,6 +806,34 @@ static std::vector get_cmd_params_instances(const cmd_param }; instances.push_back(instance); } + + for (const auto & n_gp : params.n_gp) { + if (n_gp.first == 0 && n_gp.second == 0) { + continue; + } + cmd_params_instance instance = { + /* .test_kind = */ TEST_KIND_GP, + /* .model = */ m, + /* .n_prompt = */ n_gp.first, + /* .n_gen = */ n_gp.second, + /* .n_batch = */ nb, + /* .n_ubatch = */ nub, + /* .type_k = */ tk, + /* .type_v = */ tv, + /* .n_threads = */ nt, + /* .n_gpu_layers = */ nl, + /* .rpc_servers = */ rpc, + /* .split_mode = */ sm, + /* .main_gpu = */ mg, + /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, + /* .tensor_split = */ ts, + /* .use_mmap = */ mmp, + /* .embeddings = */ embd, + /* .repack = */ params.repack, + }; + instances.push_back(instance); + } } return instances; @@ -816,6 +874,8 @@ struct test { int n_gen; std::string test_time; std::vector samples_ns; + test_kind_type test_kind; + std::string test_label; test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) { model_filename = inst.model; @@ -841,11 +901,32 @@ struct test { repack = inst.repack; n_prompt = inst.n_prompt; n_gen = inst.n_gen; + test_kind = inst.test_kind; // RFC 3339 date-time format time_t t = time(NULL); std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t)); test_time = buf; + // prepare test label for printing + switch (test_kind) { + case TEST_KIND_PP: + snprintf(buf, sizeof(buf), "pp%d", n_prompt); + break; + case TEST_KIND_TG: + snprintf(buf, sizeof(buf), "tg%d", n_gen); + break; + case TEST_KIND_PG: + snprintf(buf, sizeof(buf), "pp%d+tg%d", n_prompt, n_gen); + break; + case TEST_KIND_GP: + snprintf(buf, sizeof(buf), "tg%d@pp%d", n_gen, n_prompt); + break; + default: + snprintf(buf, sizeof(buf), "unknown"); + break; + } + test_label = buf; + (void) ctx; } @@ -858,7 +939,7 @@ struct test { } std::vector get_ts() const { - int n_tokens = n_prompt + n_gen; + int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen; std::vector ts; std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts), [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; }); return ts; @@ -911,7 +992,7 @@ struct test { "tensor_split", "use_mmap", "embeddings", "repack", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", - "avg_ts", "stddev_ts" + "avg_ts", "stddev_ts", "test", }; return fields; } @@ -967,7 +1048,8 @@ struct test { tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), - std::to_string(avg_ts()), std::to_string(stdev_ts()) + std::to_string(avg_ts()), std::to_string(stdev_ts()), + test_label }; return values; } @@ -1269,14 +1351,15 @@ struct markdown_printer : public printer { value += "+RPC"; } } else if (field == "test") { - if (t.n_prompt > 0 && t.n_gen == 0) { - snprintf(buf, sizeof(buf), "pp%d", t.n_prompt); - } else if (t.n_gen > 0 && t.n_prompt == 0) { - snprintf(buf, sizeof(buf), "tg%d", t.n_gen); - } else { - snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen); - } - value = buf; + //if (t.n_prompt > 0 && t.n_gen == 0) { + // snprintf(buf, sizeof(buf), "pp%d", t.n_prompt); + //} else if (t.n_gen > 0 && t.n_prompt == 0) { + // snprintf(buf, sizeof(buf), "tg%d", t.n_gen); + //} else { + // snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen); + //} + //value = buf; + value = t.test_label; } else if (field == "t/s") { snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts()); value = buf; @@ -1489,6 +1572,7 @@ int main(int argc, char ** argv) { if (t.n_prompt > 0) { test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); } + if (t.test_kind == TEST_KIND_GP) t_start = get_time_ns(); if (t.n_gen > 0) { test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 8d2b4090..308d0dca 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -111,6 +111,15 @@ struct Perf { #define IQK_ALWAYS_INLINE __attribute__((__always_inline__)) #endif +#if defined __x86_64__ +#if defined HAVE_FANCY_SIMD + #undef HAVE_FANCY_SIMD +#endif +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + #define HAVE_FANCY_SIMD +#endif +#endif + namespace { typedef struct { @@ -236,6 +245,35 @@ struct MulMat { } static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny); static inline int num_rows(ggml_type type) { +#ifdef HAVE_FANCY_SIMD + switch (type) { + case GGML_TYPE_Q2_K_R4: + case GGML_TYPE_Q3_K_R4: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ2_XXS_R4: + case GGML_TYPE_IQ2_XS_R4: + case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ3_XXS_R4: + case GGML_TYPE_IQ3_S_R4: return 4; + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_IQ2_BN_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_K_R4: + case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q8_K_R8: return 8; + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_BF16_R16: return 16; + default: return 1; + } +#else switch (type) { case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K_R4: @@ -263,6 +301,7 @@ struct MulMat { case GGML_TYPE_BF16_R16: return 16; default: return 1; } +#endif } private: template static void set_functions(MulMat& m); @@ -377,13 +416,6 @@ const uint64_t keven_signs[128] = { #if defined __x86_64__ -#if defined HAVE_FANCY_SIMD - #undef HAVE_FANCY_SIMD -#endif -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) - #define HAVE_FANCY_SIMD -#endif - namespace { inline float hsum_float_4(__m128 x) { @@ -2608,6 +2640,15 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2); } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto qy = (const block_q8_1 *)q8.y[0]; + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); + prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); + auto sumi = accum_q4_0_quants(v, qy[ib].qs); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1); + acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc2); + } acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1); info.store(ix, 0, acc1); } @@ -2645,6 +2686,18 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f)); + prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = accum_q4_0_quants(v, qy[ib].qs); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, acc[iy]); acc[iy] = _mm256_setzero_ps(); @@ -2664,9 +2717,38 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; __m512i qx[8]; + auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) { + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d)); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + for (int j = 0; j < 4; ++j) { + auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1); + qx[j+0] = _mm512_and_si512(bits, m4); + qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4); + } + return scales; + }; + auto dot = [&qx] (const int8_t * qy) { + auto y4l = _mm_loadu_si128((const __m128i*)qy+0); + auto y4h = _mm_loadu_si128((const __m128i*)qy+1); + auto y8l = MM256_SET_M128I(y4l, y4l); + auto y8h = MM256_SET_M128I(y4h, y4h); + auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); + auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); + return sumi; + }; float d8[8*nrc_y]; for (int ix = 0; ix < nrc_x; ix += 16) { const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx); @@ -2676,47 +2758,25 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); } for (int k = 0; k < 4; ++k) { - auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[4*ib4+k].d)); - auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[4*ib4+k].d)); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1); - auto bits3 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+2)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+2), 1); - auto bits4 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+3)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+3), 1); - qx[0] = _mm512_and_si512(bits1, m4); - qx[1] = _mm512_and_si512(bits2, m4); - qx[2] = _mm512_and_si512(bits3, m4); - qx[3] = _mm512_and_si512(bits4, m4); - qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4); - qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4); - qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits3, 4), m4); - qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits4, 4), m4); + auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); - auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); - auto y8l = MM256_SET_M128I(y4l, y4l); - auto y8h = MM256_SET_M128I(y4h, y4h); - auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); - auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); + auto sumi = dot(q8.y[iy][ib4].qs+32*k); auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq4l[ib], iq4h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(qy[ib].qs); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); @@ -2981,12 +3041,56 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn #endif #ifdef HAVE_FANCY_SIMD +inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) { + auto y4l = _mm_loadu_si128((const __m128i*)y+0); + auto y4h = _mm_loadu_si128((const __m128i*)y+1); + auto y8l = MM256_SET_M128I(y4l, y4l); + auto y8h = MM256_SET_M128I(y4h, y4h); + auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); + auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); + return sumi; +} +inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) { + auto y4l = _mm_loadu_si128((const __m128i*)y+0); + auto y4h = _mm_loadu_si128((const __m128i*)y+1); + auto yl = MM256_SET_M128I(y4l, y4l); + auto yh = MM256_SET_M128I(y4h, y4h); + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff)); + sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff)); + return sumi; +} +inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) { + qx[0] = _mm256_loadu_si256((const __m256i *)x+0); + qx[1] = _mm256_loadu_si256((const __m256i *)x+1); + qx[2] = _mm256_loadu_si256((const __m256i *)x+2); + qx[3] = _mm256_loadu_si256((const __m256i *)x+3); + qx[4] = _mm256_loadu_si256((const __m256i *)x+4); + qx[5] = _mm256_loadu_si256((const __m256i *)x+5); + qx[6] = _mm256_loadu_si256((const __m256i *)x+6); + qx[7] = _mm256_loadu_si256((const __m256i *)x+7); + return qx_r8_q8_dot_product(qx, y); +} template static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%16 == 0); Q8 q8(info); int nb = n / QK8_0; - GGML_ASSERT(nb%4 == 0); if constexpr (nrc_y == 1) { __m256 acc[2] = {}; __m256i qx[8]; @@ -2997,32 +3101,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn _mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d))); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); - qx[0] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); - qx[1] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); - qx[2] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); - qx[3] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); - qx[4] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4); - qx[5] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5); - qx[6] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6); - qx[7] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7); - auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0); - auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1); - auto yl = MM256_SET_M128I(y4l, y4l); - auto yh = MM256_SET_M128I(y4h, y4h); - auto sumi = _mm256_setzero_si256(); - sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55)); - sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff)); - sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55)); - sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff)); + auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k])); acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]); } } + if (4*(nb/4) < nb) { + auto qy = (const block_q8_1 *)q8.y[0]; + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d)); + auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); + acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]); + } + } info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0])); acc[0] = acc[1] = _mm256_setzero_ps(); } @@ -3046,27 +3140,29 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1); } for (int iy = 0; iy < nrc_y; ++iy) { - auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); - auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); - auto y8l = MM256_SET_M128I(y4l, y4l); - auto y8h = MM256_SET_M128I(y4h, y4h); - auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); - auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); + auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k); auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d)); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + for (int j = 0; j < 8; ++j) { + qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)), + _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]); info.store(ix, iy, sum512); @@ -3082,9 +3178,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn Q8 q8(info); auto m1 = _mm256_set1_epi16(1); int nb = n / QK8_0; - GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; float d8[4*nrc_y]; + __m256i qx[4], sx[4]; + auto dot = [&qx, &sx, &m1] (const int8_t * qy) { + auto y128 = _mm_loadu_si128((const __m128i*)qy); + auto y = MM256_SET_M128I(y128, y128); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]))) + ); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]))) + ); + return _mm256_add_epi32(sumi1, sumi2); + }; for (int ix = 0; ix < nrc_x; ix += 8) { const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { @@ -3094,54 +3203,49 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn } for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); - auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); - auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); - auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); - auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); - auto s0 = _mm256_sign_epi8(q0, q0); - auto s1 = _mm256_sign_epi8(q1, q1); - auto s2 = _mm256_sign_epi8(q2, q2); - auto s3 = _mm256_sign_epi8(q3, q3); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } for (int iy = 0; iy < nrc_y; ++iy) { - auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); - auto y = MM256_SET_M128I(y128, y128); - auto sumi1 = _mm256_add_epi32( - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1))) - ); - auto sumi2 = _mm256_add_epi32( - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3))) - ); - auto sumi = _mm256_add_epi32(sumi1, sumi2); + auto sumi = dot(q8.y[iy][ib4].qs+32*k); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } - q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4); - q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5); - q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6); - q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7); - s0 = _mm256_sign_epi8(q0, q0); - s1 = _mm256_sign_epi8(q1, q1); - s2 = _mm256_sign_epi8(q2, q2); - s3 = _mm256_sign_epi8(q3, q3); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } for (int iy = 0; iy < nrc_y; ++iy) { - auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); - auto y = MM256_SET_M128I(y128, y128); - auto sumi1 = _mm256_add_epi32( - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1))) - ); - auto sumi2 = _mm256_add_epi32( - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3))) - ); - auto sumi = _mm256_add_epi32(sumi1, sumi2); + auto sumi = dot(q8.y[iy][ib4].qs+32*k+16); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(qy[ib].qs); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(qy[ib].qs+16); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, acc[iy]); acc[iy] = _mm256_setzero_ps(); @@ -7080,6 +7184,7 @@ struct QFBase { static inline Acc acc_first(const Data& y, const Data& x) { return _mm512_mul_ps(y, x); } + static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); } static inline float hsum(Acc acc) { return _mm512_reduce_add_ps(acc); } @@ -7118,6 +7223,7 @@ struct QFBase { static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm256_fmadd_ps(y, x, prev); } + static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); } static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc); acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc); @@ -7190,6 +7296,44 @@ template struct QFT final : public QFBase { const Float * y[nrc]; }; +// TBD if we want this +//template +//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { +// static_assert(Qy::nrc == 1); +// int nb = n/QFBase::k_step; +// int nb4 = n/4; +// Qy y(info); +// Qx x(cx + ix0*bx, bx); +// QFBase::Data xv[2*Qx::nrc]; +// QFBase::Acc acc[2*Qx::nrc]; +// auto yv1 = y.load1(0, 0); +// auto yv2 = y.load1(0, 1); +// for (int ix = 0; ix < Qx::nrc; ++ix) { +// xv[2*ix+0] = x.load1(ix, 0); +// xv[2*ix+1] = x.load1(ix, 1); +// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]); +// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]); +// } +// for (int i = 1; i < nb/2; ++i) { +// yv1 = y.load1(0, 2*i+0); +// yv2 = y.load1(0, 2*i+1); +// for (int ix = 0; ix < Qx::nrc; ++ix) { +// xv[2*ix+0] = x.load1(ix, 2*i+0); +// xv[2*ix+1] = x.load1(ix, 2*i+1); +// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]); +// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]); +// } +// } +// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) { +// yv1 = y.load_tail(0, i); +// for (int ix = 0; ix < Qx::nrc; ++ix) { +// xv[ix] = x.load_tail(ix, i); +// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]); +// } +// } +// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1]))); +//} + template IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { int nb = n/QFBase::k_step; @@ -7287,12 +7431,29 @@ inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, co // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. template void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const char * cx = (const char *)vx; + // TBD if we want this + //if constexpr (nrc_y == 1) { + // constexpr int k_nx = 2; + // for (int ix = 0; ix < nrc_x/k_nx; ++ix) { + // mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, ix*k_nx, info); + // } + // if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) { + // int nx = nrc_x - lastx; + // switch (nx) { + // case 1: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; + // case 2: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; + // case 3: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; + // } + // //mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); + // } + // return; + //} #ifdef __AVX512F__ constexpr int k_nx = 5; #else constexpr int k_nx = nrc_y == 1 ? 4 : 2; #endif - const char * cx = (const char *)vx; for (int ix = 0; ix < nrc_x/k_nx; ++ix) { mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, ix*k_nx, info); } @@ -12146,7 +12307,6 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, Q8 q8(info); Dequantizer deq(vx, bx); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); int8x16_t qx[16]; float d8[4*nrc_y]; float32x4_t acc[2*nrc_y] = {}; @@ -12168,6 +12328,18 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = deq.prepare(ib, 0, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_0 *)q8.y[iy]; + auto y = vld1q_s8_x2(qy[ib].qs); + auto sumi1 = interleaved_dotq(qx+0, y); + auto sumi2 = interleaved_dotq(qx+8, y); + auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1)); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2)); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix+0, iy, deq.result(acc[2*iy+0])); info.store(ix+4, iy, deq.result(acc[2*iy+1])); @@ -12312,12 +12484,32 @@ struct Q6_0_R4_Dequantizer { const int8x16_t m32 = vdupq_n_s8(-32); }; +inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) { + auto y = vld1q_s8_x2(qy); + sumi1 = sumi2 = vdupq_n_s32(0); + sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0); + sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0); + sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1); + sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1); + sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2); + sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2); + sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3); + sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3); +} + template void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); int nb = n / QK8_0; - GGML_ASSERT(nb%4 == 0); float32x4_t acc[2*nrc_y] = {}; int8x16_t qx[16]; float d8[4*nrc_y]; @@ -12332,32 +12524,29 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16)); auto scales2 = vcvt_f32_f16(vget_high_f16(scales16)); for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j); + int32x4_t sumi1, sumi2; for (int iy = 0; iy < nrc_y; ++iy) { - auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); - auto sumi1 = vdupq_n_s32(0); - auto sumi2 = vdupq_n_s32(0); - sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0); - sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0); - sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1); - sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1); - sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2); - sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2); - sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3); - sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3); - sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0); - sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0); - sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1); - sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1); - sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2); - sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2); - sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3); - sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3); + qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2); auto dy = vdupq_n_f32(d8[4*iy+k]); acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1)); acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2)); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d); + auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16)); + auto scales2 = vcvt_f32_f16(vget_high_f16(scales16)); + for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j); + int32x4_t sumi1, sumi2; + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_0 *)q8.y[iy]; + qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2); + auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1)); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2)); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix+0, iy, acc[2*iy+0]); info.store(ix+4, iy, acc[2*iy+1]); @@ -13033,10 +13222,10 @@ struct HelperQ80R4 : public BaseHelper { m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); #ifdef HAVE_FANCY_SIMD - m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128)); - m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128)); - m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128)); - m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128)); + m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); #endif _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); @@ -13055,10 +13244,10 @@ struct HelperQ80R4 : public BaseHelper { m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); #ifdef HAVE_FANCY_SIMD - m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128)); - m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128)); - m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128)); - m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128)); + m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); #endif _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); @@ -13895,16 +14084,11 @@ struct FlashQKfp32 { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0= 128) { #ifdef HAVE_FANCY_SIMD - MAKE_FUNCS(mul_mat_qX_1_q8_1_T>) { From 2e6b523853a8659c63283a6deca805051ecd713a Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 30 Jan 2025 09:28:53 +0200 Subject: [PATCH 04/14] Faster Q4_K_R4 and Q5_K_R4 on AVX2/Zen4 (#182) * Slightly faster AVX2 implementation for q4_k_r4 * Even better AVX2 implementation for q4_k_r4 We now arrive at PP-512 = 328 t/s for LLaMA-3.1-8B on a Ryzen-5975WX CPU, up from 291 t/s when I last measured on 3c5f8722. With FA and Q8_0 K-cache we get to 339.5 t/s. * Fix llama-bench labels that I broke with #181 * Faster AVX2 implementation for q5_k_q4 We arrive at 302 t/s for LLaMA-3.1-8B on a Ryzen-5975WX CPU, up from 273 t/s. * Use AVX2 implementation of q4_k_r4 and q5_k_r4 also on Zen4 After the changes I made to AVX2, it ends up being slightly faster compared to what I had for Zen4. * Minor tweak * Cleanup --------- Co-authored-by: Iwan Kawrakow --- examples/llama-bench/llama-bench.cpp | 4 +- ggml/src/iqk/iqk_mul_mat.cpp | 319 ++++++--------------------- 2 files changed, 68 insertions(+), 255 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index b46bd855..42320da8 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -756,7 +756,7 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { - /* .test_kind = */ TEST_KIND_PP, + /* .test_kind = */ TEST_KIND_TG, /* .model = */ m, /* .n_prompt = */ 0, /* .n_gen = */ n_gen, @@ -784,7 +784,7 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { - /* .test_kind = */ TEST_KIND_PP, + /* .test_kind = */ TEST_KIND_PG, /* .model = */ m, /* .n_prompt = */ n_pg.first, /* .n_gen = */ n_pg.second, diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 308d0dca..7fd56c42 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4430,17 +4430,47 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI } template -static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8& q8, __m256 * acc) { + auto mins_l = _mm256_castsi256_si128(mins); + auto mins_h = _mm256_extracti128_si256(mins, 1); + auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h); + auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h); + auto ic1 = _mm256_cvtepi8_epi32(aux1); + auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee)); + auto ic3 = _mm256_cvtepi8_epi32(aux2); + auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee)); + if constexpr (nrc_y == 1) { + auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums); + auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00)); + sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf); + sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf); + sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf); + acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]); + } else { + auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1)); + auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2)); + auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3)); + auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]); + } + } +} + +template +static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8 q8(info); auto mf = _mm256_set1_epi8(0xf); auto m3 = _mm256_set1_epi8(0x30); -#ifndef HAVE_FANCY_SIMD - auto m1 = _mm256_set1_epi16(1); -#endif int nbl = n / QK_K; union { __m256i vec; uint32_t val[8]; } hd; __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; __m256i qx[4]; for (int ix = 0; ix < nrc_x; ix += 4) { const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx); @@ -4448,31 +4478,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d)); auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl)); auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1))); - if constexpr (nrc_y == 1) { - d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl))); - } auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l); auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h); auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4)); hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3)); auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3)); - auto shuffle = _mm256_set1_epi64x(0x0000000400000000); - auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); - auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); - auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); - auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - for (int iy = 0; iy < nrc_y; ++iy) { - auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); - acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]); - } + process_min_r4_b32(ibl, m4, mins, q8, acc); for (int ib = 0; ib < QK_K/32; ++ib) { - auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib])))); +#ifdef HAVE_FANCY_SIMD + auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib])); +#else + auto aux = _mm_set1_epi32(hd.val[ib]); + aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux)); + auto scales_d = MM256_SET_M128I(aux, aux); +#endif auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); qx[0] = _mm256_and_si256(bits1, mf); @@ -4487,21 +4506,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi)); #else auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); - auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2))); #endif - if constexpr (nrc_y == 1) { - acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]); - } else { - float d8 = q8.scale(iy, ibl); - acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]); - } } } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); @@ -4511,113 +4529,17 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D } } -#ifdef HAVE_FANCY_SIMD template -static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - //mul_mat_q4_k_r4_q8_k_avx2(n, vx, bx, info, nrc_x); - if constexpr (nrc_y == 1){ - mul_mat_q4_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); - } else { - GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); - auto mf = _mm512_set1_epi8(0xf); - int nbl = n / QK_K; - using helper_t = union { __m512i vec; uint32_t val[16]; }; - helper_t hd, hm; - __m512 acc[nrc_y] = {}; - __m512i isum[nrc_y] = {}; - __m512i qx[4]; - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_q4_k_r4 * iq4l = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx); - const block_q4_k_r4 * iq4h = (const block_q4_k_r4 *)((const char *)vx + (ix+4)*bx); - for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 - auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[ibl].d)); - auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[ibl].d)); - auto dl = _mm256_castps256_ps128(d1); - auto ml = _mm256_extractf128_ps(d1, 1); - auto dh = _mm256_castps256_ps128(d2); - auto mh = _mm256_extractf128_ps(d2, 1); - auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1); - auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1); - m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f)); - auto slbits_l = _mm256_loadu_si256((const __m256i *)iq4l[ibl].scales_l); - auto shbits_l = _mm256_loadu_si256((const __m256i *)iq4h[ibl].scales_l); - auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1); - auto sld = _mm512_and_si512(slb, mf); - auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf); - auto slbits_h = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_h); - auto shbits_h = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_h); - auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h); - auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h); - auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1); - auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30)); - auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30)); - hd.vec = _mm512_or_si512(sld, shd); - hm.vec = _mm512_or_si512(slm, shm); - for (int ib = 0; ib < QK_K/32; ++ib) { - auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0])); - auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8])); - auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); - scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0])); - scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8])); - auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m)); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)), - _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)), - _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1); - qx[0] = _mm512_and_si512(bits1, mf); - qx[1] = _mm512_and_si512(bits2, mf); - qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), mf); - qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), mf); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi)); - float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; - acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]); - isum[iy] = _mm512_setzero_si512(); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); - acc[iy] = _mm512_setzero_ps(); - } - } - } -} -#else -template -static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q4_k_r4_q8_k_avx2(n, vx, bx, info, nrc_x); -} -#endif - -template -static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8 q8(info); auto mf = _mm256_set1_epi8(0xf); auto m10 = _mm256_set1_epi8(0x10); auto m30 = _mm256_set1_epi8(0x30); -#ifndef HAVE_FANCY_SIMD - auto m1 = _mm256_set1_epi16(1); -#endif int nbl = n / QK_K; union { __m256i vec; uint32_t val[8]; } hd; __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; __m256i qx[4]; for (int ix = 0; ix < nrc_x; ix += 4) { const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx); @@ -4625,31 +4547,20 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d)); auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl)); auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1))); - if constexpr (nrc_y == 1) { - d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl))); - } auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l); auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h); auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4)); hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30)); auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30)); - auto shuffle = _mm256_set1_epi64x(0x0000000400000000); - auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); - auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); - auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); - auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); - for (int iy = 0; iy < nrc_y; ++iy) { - auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); - acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]); - } + process_min_r4_b32(ibl, m4, mins, q8, acc); for (int ib = 0; ib < QK_K/32; ++ib) { - auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib])))); +#ifdef HAVE_FANCY_SIMD + auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib])); +#else + auto aux = _mm_set1_epi32(hd.val[ib]); + aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux)); + auto scales_d = MM256_SET_M128I(aux, aux); +#endif auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0); auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1); auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib); @@ -4666,21 +4577,22 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi)); #else auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); - auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + // To avoid overflow, we can only add up to 4 q5 x q8 products. + auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2)); + isum[iy] = _mm256_add_epi32(isum[iy], sumi); #endif - if constexpr (nrc_y == 1) { - acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]); - } else { - float d8 = q8.scale(iy, ibl); - acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]); - } } } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); @@ -4690,105 +4602,6 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D } } -#ifdef HAVE_FANCY_SIMD -template -static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - if constexpr (nrc_y == 1){ - mul_mat_q5_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); - } else { - GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); - auto mf = _mm512_set1_epi8(0xf); - auto m10 = _mm512_set1_epi8(0x10); - int nbl = n / QK_K; - using helper_t = union { __m512i vec; uint32_t val[16]; }; - helper_t hd, hm; - __m512 acc[nrc_y] = {}; - __m512i isum[nrc_y] = {}; - __m512i qx[4]; - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_q5_k_r4 * iq5l = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx); - const block_q5_k_r4 * iq5h = (const block_q5_k_r4 *)((const char *)vx + (ix+4)*bx); - for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 - auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5l[ibl].d)); - auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5h[ibl].d)); - auto dl = _mm256_castps256_ps128(d1); - auto ml = _mm256_extractf128_ps(d1, 1); - auto dh = _mm256_castps256_ps128(d2); - auto mh = _mm256_extractf128_ps(d2, 1); - auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1); - auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1); - m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f)); - auto slbits_l = _mm256_loadu_si256((const __m256i *)iq5l[ibl].scales_l); - auto shbits_l = _mm256_loadu_si256((const __m256i *)iq5h[ibl].scales_l); - auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1); - auto sld = _mm512_and_si512(slb, mf); - auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf); - auto slbits_h = _mm_loadu_si128((const __m128i *)iq5l[ibl].scales_h); - auto shbits_h = _mm_loadu_si128((const __m128i *)iq5h[ibl].scales_h); - auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h); - auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h); - auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1); - auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30)); - auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30)); - hd.vec = _mm512_or_si512(sld, shd); - hm.vec = _mm512_or_si512(slm, shm); - for (int ib = 0; ib < QK_K/32; ++ib) { - auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0])); - auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8])); - auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); - scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0])); - scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8])); - auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m)); - auto lbits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+0)), - _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+0), 1); - auto lbits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+1)), - _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+1), 1); - auto hbits1 = _mm_loadu_si128((const __m128i*)iq5l[ibl].qh+ib); - auto hbits2 = _mm_loadu_si128((const __m128i*)iq5h[ibl].qh+ib); - auto hbl = MM256_SET_M128I(hbits1, _mm_slli_epi16(hbits1, 4)); - auto hbh = MM256_SET_M128I(hbits2, _mm_slli_epi16(hbits2, 4)); - auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbl), hbh, 1); - qx[0] = _mm512_or_si512(_mm512_and_si512(lbits1, mf), _mm512_and_si512(m10, hbits)); - qx[1] = _mm512_or_si512(_mm512_and_si512(lbits2, mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 2))); - qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits1, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 1))); - qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits2, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 3))); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi)); - float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; - acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]); - isum[iy] = _mm512_setzero_si512(); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); - acc[iy] = _mm512_setzero_ps(); - } - } - } -} -#else -template -static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q5_k_r4_q8_k_avx2(n, vx, bx, info, nrc_x); -} -#endif - template static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); From ecf111a11ca56ff0731308f94bd6c5e96658b6ef Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 30 Jan 2025 18:36:24 +0200 Subject: [PATCH 05/14] Deepseek-Lite (#184) * Quantization mixes tweaks * Make iq4_nl_r4 work with row size that are not a multiple of 128 ... on Zen4 * Make iq4_nl_r4 work with row size that are not a multiple of 128 ... on AVX2 * Make iq4_nl_r4 work with row size that are not a multiple of 128 ... on AVX2 * Make q6_0_w4 work with row size that are not a multiple of 128 ... on Zen4 * Make q6_0_w4 work with row size that are not a multiple of 128 ... on Zen4 * Make q5_0_r4 work with row size that are not a multiple of 128 ... on Zen4 and AVX2 * Make q5,6_0_r4, iq4_nl_e4 work with row size that are not a multiple of 128 also on NEON. --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 467 ++++++++++++++++++++++------------- src/llama.cpp | 18 +- 2 files changed, 315 insertions(+), 170 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7fd56c42..f633229d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2474,44 +2474,63 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data auto m4 = _mm512_set1_epi8(0xf); auto values = load_iq4nl_values_512(); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; __m512i qx[4]; + float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1); + qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); + qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); + qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); + qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; for (int ix = 0; ix < nrc_x; ix += 8) { const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx); const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f)); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1); - qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); - qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); - qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); - qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); + auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq4l[ib], iq4h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); @@ -2530,37 +2549,57 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); auto values = MM256_SET_M128I(values128, values128); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; - //__m256 acc[2*nrc_y] = {}; + __m256i qs[4]; + float d8[4*nrc_y]; + auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1); + qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); + qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); + qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); + qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); + return scales; + }; + auto dot = [&qs, &m1] (__m256i y) { + auto u1 = _mm256_sign_epi8(qs[0], qs[0]); + auto u2 = _mm256_sign_epi8(qs[1], qs[1]); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1])))); + u1 = _mm256_sign_epi8(qs[2], qs[2]); + u2 = _mm256_sign_epi8(qs[3], qs[3]); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3])))); + return _mm256_add_epi32(sumi1, sumi2); + }; for (int ix = 0; ix < nrc_x; ix += 4) { const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm_storeu_ps(d8+4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1); - auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); - auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); - auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); - auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); - auto s1 = _mm256_sign_epi8(q1, q1); - auto s2 = _mm256_sign_epi8(q2, q2); - auto s3 = _mm256_sign_epi8(q3, q3); - auto s4 = _mm256_sign_epi8(q4, q4); - + auto scales = prepare(iq4[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2)))); - auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4)))); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); - acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq4[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); info.store(ix, iy, sum); @@ -2797,43 +2836,73 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m5 = _mm256_set1_epi8(0x10); +#ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); +#endif + auto mscale = _mm256_set_m128(_mm_set1_ps(-8.f), _mm_set1_ps(1.f)); int nb = n / QK5_0; - GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; + __m256i qx[4]; float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1); + auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits); + qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5)); + qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5)); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5)); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));; + return scales; + }; +#ifdef HAVE_FANCY_SIMD + auto dot = [&qx] (__m256i y) { + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + return sumi; + }; +#else + auto dot = [&qx, &m1] (__m256i y) { + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + return sumi; + }; +#endif for (int ix = 0; ix < nrc_x; ix += 4) { const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); - _mm256_storeu_ps(d8 + 8*iy, scales); + _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales)); } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f)); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+1); - auto hbits = _mm_loadu_si128((const __m128i *)iq5[4*ib4+k].qh); - auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits); - auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5)); - auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5)); - auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5)); - auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));; + auto scales = prepare(iq5[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); - auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq5[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); info.store(ix, iy, sum); @@ -2853,50 +2922,68 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn auto m4 = _mm512_set1_epi8(0xf); auto m5 = _mm512_set1_epi8(0x10); int nb = n / QK5_0; - GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; __m512i qx[4]; + float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1); + auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh); + auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh); + auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1); + auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2); + auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1); + qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5)); + qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5)); + qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5)); + qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5)); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; for (int ix = 0; ix < nrc_x; ix += 8) { const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx); const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-8.f)); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+1), 1); - auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l[4*ib4+k].qh); - auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h[4*ib4+k].qh); - auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1); - auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2); - auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1); - qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5)); - qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5)); - //qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5); - qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5)); - qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5)); + auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq5l[ib], iq5h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); @@ -2919,51 +3006,72 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m6 = _mm256_set1_epi8(0x30); + auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f)); #ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); #endif int nb = n / QK6_0; - GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; float d8[8*nrc_y]; + __m256i qx[4]; + auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1); + auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh); + qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6)); + qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6)); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6)); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6)); + return scales; + }; +#ifdef HAVE_FANCY_SIMD + auto dot = [&qx] (__m256i y) { + auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + return sumi; + }; +#else + auto dot = [&qx, &m1] (__m256i y) { + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); + return sumi; + }; +#endif for (int ix = 0; ix < nrc_x; ix += 4) { const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); - _mm256_storeu_ps(d8 + 8*iy, scales); + _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale)); } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-16.f)); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+1); - auto hbits = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qh); - auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6)); - auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6)); - auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6)); - auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6)); + auto scales = prepare(iq6[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); -#ifdef HAVE_FANCY_SIMD - auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55)); - sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff)); -#else - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); - auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); -#endif + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq6[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); info.store(ix, iy, sum); @@ -2983,47 +3091,67 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn auto m4 = _mm512_set1_epi8(0xf); auto m6 = _mm512_set1_epi8(0x30); int nb = n / QK6_0; - GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; __m512i qx[4]; + float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1); + auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh); + auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh); + auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1); + qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6); + qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);; + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; for (int ix = 0; ix < nrc_x; ix += 8) { const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx); const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + _mm256_storeu_ps(d8 + 8*iy, scales); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-16.f)); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+1), 1); - auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qh); - auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qh); - auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1); - qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6); - qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);; - qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6); - qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6); + auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq6l[ib], iq6h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); @@ -12087,7 +12215,6 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, Q8 q8(info); Dequantizer deq(vx, bx); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; float d8[4*nrc_y]; float32x4_t acc[nrc_y] = {}; @@ -12098,7 +12225,7 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); } for (int k = 0; k < 4; ++k) { - auto scales = deq.prepare(ib4, k, qx); + auto scales = deq.prepare(4*ib4+k, qx); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); auto sumi = interleaved_dotq(qx, y); @@ -12107,6 +12234,16 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = deq.prepare(ib, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_0 *)q8.y[iy]; + auto y = vld1q_s8_x2(qy[ib].qs); + auto sumi = interleaved_dotq(qx, y); + auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, deq.result(acc[iy])); acc[iy] = vdupq_n_f32(0.f); @@ -12164,9 +12301,9 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, struct IQ4_NL_R4_Dequantizer { IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {} inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); - auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d)); + auto bits = vld1q_u8_x4(iq4[ib].qs); prepare_iq4_nl_quants(values, m4, bits, qx); return scales; } @@ -12242,10 +12379,10 @@ struct Q4_0_R8_Dequantizer { struct Q5_0_R4_Dequantizer { Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d)); - auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs); - auto hbits = vld1q_u8(iq5[4*ib4+k].qh); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d)); + auto lbits = vld1q_u8_x4(iq5[ib].qs); + auto hbits = vld1q_u8(iq5[ib].qh); qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3 qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19 qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7 @@ -12271,10 +12408,10 @@ struct Q5_0_R4_Dequantizer { struct Q6_0_R4_Dequantizer { Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d)); - auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs); - auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d)); + auto lbits = vld1q_u8_x4(iq6[ib].qs); + auto hbits = vld1q_u8_x2(iq6[ib].qh); qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3 qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19 qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7 diff --git a/src/llama.cpp b/src/llama.cpp index b6a4a06d..570c056c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16075,7 +16075,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } - else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q") != std::string::npos) { new_type = GGML_TYPE_Q4_K; } else if (name.find("attn_qkv.weight") != std::string::npos) { @@ -16088,7 +16091,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { - if (qs.model.hparams.n_expert == 8) { + if (qs.model.hparams.n_expert >= 4) { new_type = GGML_TYPE_Q5_K; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K; @@ -16188,9 +16191,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K; } ++qs.i_attention_wv; - } else if (name.find("attn_k.weight") != std::string::npos) { + } else if (name.find("attn_k") != std::string::npos) { if (qs.params->attn_k_type < GGML_TYPE_COUNT) new_type = qs.params->attn_k_type; - else if (qs.model.hparams.n_expert == 8) { + else if (qs.model.hparams.n_expert >= 8) { // for the 8-expert model, bumping this to Q8_0 trades just ~128MB // TODO: explore better strategies new_type = GGML_TYPE_Q8_0; @@ -16201,8 +16204,13 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("attn_q.weight") != std::string::npos) { + } else if (name.find("attn_q") != std::string::npos) { if (qs.params->attn_q_type < GGML_TYPE_COUNT) new_type = qs.params->attn_q_type; + else if (qs.model.hparams.n_expert >= 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } From 8b7536bda8b65107794c4df710f14ddfde430160 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 5 Feb 2025 13:49:39 +0200 Subject: [PATCH 06/14] IQ1_S_R4: better 1.5 bpw quants (#185) * iq1_s_r4: basics - quantize/dequantize * iq1_s_r4: gemm/gemv works on AVX2/Zen4 * Don't forget to make sure we have a multiple of 4 rows per thread * iq1_s_r4: this is better * iq1_s_r4: fix Zen4 after AVX2 changes * iq1_s_r4: NEON gemm/gemv * iq1_s_r4: more bits for shared experts With this mix we arrive at PPL(512) = 9.4140 for Deepseek-Lite using 1.766 bpw for the repeating layers. On the Ryzen-7950X we get PP-512 = 494 t/s and TG-128 = 52 t/s @ 16 threads. * Forgotten counter increment * iq1_s_r4: slightly faster AVX2/Zen4 gemm/gemv * Compiler warnings --------- Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 2 + ggml/include/ggml.h | 2 + ggml/src/ggml-common.h | 6 + ggml/src/ggml-quants.c | 289 ++++++++++---- ggml/src/ggml-quants.h | 5 + ggml/src/ggml.c | 27 +- ggml/src/iqk/iqk_mul_mat.cpp | 703 ++++++++++++++++++++++++++++++++- ggml/src/iqk/iqk_quantize.cpp | 106 +++++ ggml/src/iqk/iqk_quantize.h | 6 + include/llama.h | 1 + src/llama.cpp | 50 ++- 11 files changed, 1104 insertions(+), 93 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 5ffdbc84..1c847e6b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -29,6 +29,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ2_M_R4", LLAMA_FTYPE_MOSTLY_IQ2_M_R4, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, + { "IQ1_S_R4", LLAMA_FTYPE_MOSTLY_IQ1_S_R4, " 1.5 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, @@ -510,6 +511,7 @@ int main(int argc, char ** argv) { params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || + params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M)) { fprintf(stderr, "\n==========================================================================================================\n"); fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5eea7dcd..9668dc32 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -427,6 +427,7 @@ extern "C" { GGML_TYPE_IQ2_XXS_R4= 216, GGML_TYPE_IQ2_XS_R4 = 217, GGML_TYPE_IQ3_XXS_R4= 218, + GGML_TYPE_IQ1_S_R4 = 219, GGML_TYPE_IQ4_NL_R4 = 220, GGML_TYPE_IQ3_S_R4 = 221, GGML_TYPE_IQ2_S_R4 = 222, @@ -510,6 +511,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_XXS_R4= 215, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_XS_R4 = 216, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_XXS_R4= 217, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ1_S_R4 = 218, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_NL_R4 = 219, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_S_R4 = 220, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_S_R4 = 221, // except 1d tensors diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 023b0b63..14813161 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -485,6 +485,12 @@ typedef struct { } block_iq1_s; static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); +typedef struct { + uint8_t qs[16]; + uint16_t qh[4]; +} block_iq1_s_r4; +static_assert(sizeof(block_iq1_s_r4) == 24, "wrong iq1_s_r4 block size/padding"); + // 1.75 bpw typedef struct { uint8_t qs[QK_K/8]; // grid index, low 8 bits diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 391d9e2e..3c4711f3 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -13991,6 +13991,105 @@ static int iq1_sort_helper(const void * left, const void * right) { return *l < *r ? -1 : *l > *r ? 1 : 0; } +void iq1s_process_1block(int block_size, const float * xb, const float * weight, int8_t * L, float * the_scale, uint16_t * the_index, int * the_shift, + float * pairs, float * sumx, float * sumw) { + float max = fabsf(xb[0]); + for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + if (max < GROUP_MAX_EPS_IQ1_S) { + *the_scale = 0; + *the_shift = 1; + for (int k = 0; k < block_size/8; ++k) the_index[k] = 0; + return; + } + const int gindex = iq2_data_index(GGML_TYPE_IQ1_S); + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + + const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA}; + const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA}; + + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + int * idx = (int *)(pairs + 1); + for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + { + sumx[0] = sumw[0] = 0; + for (int j = 0; j < block_size; ++j) { + int i = idx[2*j]; + sumx[j+1] = sumx[j] + weight[i]*xb[i]; + sumw[j+1] = sumw[j] + weight[i]; + } + } + float best_score = -FLT_MIN, scale = max; + int besti1 = -1, besti2 = -1, best_shift = 0; + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2]; + float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = 1; + } + sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2]; + sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = -1; + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; best_shift = -best_shift; + } + bool all_on_grid = true; + const float * xx = best_shift == 1 ? x_p : x_m; + for (int k = 0; k < block_size/8; ++k) { + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + the_index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx = 0, sumq2 = 0; + for (int k = 0; k < block_size/8; ++k) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + the_index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } + if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; + } + *the_scale = scale; + *the_shift = best_shift; +} + #define IQ1S_BLOCK_SIZE 32 #define IQ1M_BLOCK_SIZE 16 static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, @@ -14021,11 +14120,10 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy const int block_size = IQ1S_BLOCK_SIZE; - const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA}; - const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA}; + //const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA}; + //const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA}; - - int * idx = (int *)(pairs + 1); + //int * idx = (int *)(pairs + 1); for (int ibl = 0; ibl < nbl; ++ibl) { @@ -14044,95 +14142,100 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy const float * xb = xbl + block_size*ib; const float * qw = quant_weights + QK_K*ibl + block_size*ib; for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); - float max = fabsf(xb[0]); - for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); - if (max < GROUP_MAX_EPS_IQ1_S) { - scales[ib] = 0; - memset(L, 1, block_size); - continue; - } - // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. - // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two - // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights - // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and - // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale - // for each possible and score for each split. - for (int j = 0; j < block_size; ++j) { - pairs[2*j] = xb[j]; - idx[2*j] = j; - } - qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - { - sumx[0] = sumw[0] = 0; - for (int j = 0; j < block_size; ++j) { - int i = idx[2*j]; - sumx[j+1] = sumx[j] + weight[i]*xb[i]; - sumw[j+1] = sumw[j] + weight[i]; - } - } - float best_score = -FLT_MIN, scale = max; - int besti1 = -1, besti2 = -1, best_shift = 0; - for (int i1 = 0; i1 <= block_size; ++i1) { - for (int i2 = i1; i2 <= block_size; ++i2) { - float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2]; - float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2]; - if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { - scale = sumqx/sumq2; best_score = scale*sumqx; - besti1 = i1; besti2 = i2; best_shift = 1; - } - sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2]; - sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2]; - if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { - scale = sumqx/sumq2; best_score = scale*sumqx; - besti1 = i1; besti2 = i2; best_shift = -1; - } - } - } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); - for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; - for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; - for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; - if (scale < 0) { - for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; - scale = -scale; best_shift = -best_shift; - } - bool all_on_grid = true; - const float * xx = best_shift == 1 ? x_p : x_m; - for (int k = 0; k < block_size/8; ++k) { - uint16_t u = 0; - for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); - int grid_index = kmap_q2xs[u]; - if (grid_index < 0) { - all_on_grid = false; - const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; - grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); - GGML_ASSERT(grid_index >= 0); - } - index[k] = grid_index; - } - if (!all_on_grid) { - float sumqx = 0, sumq2 = 0; - for (int k = 0; k < block_size/8; ++k) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); - for (int j = 0; j < 8; ++j) { - float w = weight[8*k + j]; - float q = xx[(pg[j] - 1)/2]; - sumqx += w*q*xb[8*k+j]; - sumq2 += w*q*q; - } - } - if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; - } + int best_shift; + iq1s_process_1block(block_size, xb, weight, L, &scales[ib], index, &best_shift, pairs, sumx, sumw); + +// float max = fabsf(xb[0]); +// for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); +// if (max < GROUP_MAX_EPS_IQ1_S) { +// scales[ib] = 0; +// memset(L, 1, block_size); +// continue; +// } +// // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. +// // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two +// // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights +// // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and +// // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale +// // for each possible and score for each split. +// for (int j = 0; j < block_size; ++j) { +// pairs[2*j] = xb[j]; +// idx[2*j] = j; +// } +// qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); +// { +// sumx[0] = sumw[0] = 0; +// for (int j = 0; j < block_size; ++j) { +// int i = idx[2*j]; +// sumx[j+1] = sumx[j] + weight[i]*xb[i]; +// sumw[j+1] = sumw[j] + weight[i]; +// } +// } +// float best_score = -FLT_MIN, scale = max; +// int besti1 = -1, besti2 = -1, best_shift = 0; +// for (int i1 = 0; i1 <= block_size; ++i1) { +// for (int i2 = i1; i2 <= block_size; ++i2) { +// float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2]; +// float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2]; +// if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { +// scale = sumqx/sumq2; best_score = scale*sumqx; +// besti1 = i1; besti2 = i2; best_shift = 1; +// } +// sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2]; +// sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2]; +// if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { +// scale = sumqx/sumq2; best_score = scale*sumqx; +// besti1 = i1; besti2 = i2; best_shift = -1; +// } +// } +// } +// GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); +// for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; +// for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; +// for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; +// if (scale < 0) { +// for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; +// scale = -scale; best_shift = -best_shift; +// } +// bool all_on_grid = true; +// const float * xx = best_shift == 1 ? x_p : x_m; +// for (int k = 0; k < block_size/8; ++k) { +// uint16_t u = 0; +// for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); +// int grid_index = kmap_q2xs[u]; +// if (grid_index < 0) { +// all_on_grid = false; +// const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; +// grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); +// GGML_ASSERT(grid_index >= 0); +// } +// index[k] = grid_index; +// } +// if (!all_on_grid) { +// float sumqx = 0, sumq2 = 0; +// for (int k = 0; k < block_size/8; ++k) { +// const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); +// for (int j = 0; j < 8; ++j) { +// float w = weight[8*k + j]; +// float q = xx[(pg[j] - 1)/2]; +// sumqx += w*q*xb[8*k+j]; +// sumq2 += w*q*q; +// } +// } +// if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; +// } uint16_t h = 0; for (int k = 0; k < block_size/8; ++k) { y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255; h |= (index[k] >> 8) << 3*k; } y[ibl].qh[ib] = h; - GGML_ASSERT(scale >= 0); - scales[ib] = scale; + GGML_ASSERT(scales[ib] >= 0); + max_scale = MAX(max_scale, scales[ib]); + //GGML_ASSERT(scale >= 0); + //scales[ib] = scale; shifts[ib] = best_shift; - max_scale = MAX(max_scale, scale); + //max_scale = MAX(max_scale, scale); } if (!max_scale) { @@ -14171,6 +14274,19 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq1_s); } +void quantize_row_iq1_s_ref (const float * GGML_RESTRICT x, block_iq1_s * GGML_RESTRICT y, int64_t k) { + int nblock = k/QK_K; + float qw[QK_K]; + for (int j = 0; j < QK_K; ++j) qw[j] = 1; + for (int ibl = 0; ibl < nblock; ++ibl) { + quantize_iq1_s(x + ibl*QK_K, &y[ibl], 1, QK_K, qw); + } +} + +void quantize_row_iq1_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_iq1_s_ref(x, (block_iq1_s *)y, k); +} + static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, float * scales, float * weight, @@ -15129,6 +15245,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ3_XXS_R4: break; case GGML_TYPE_IQ3_S_R4: break; case GGML_TYPE_IQ2_S_R4: break; + case GGML_TYPE_IQ1_S_R4: break; case GGML_TYPE_Q4_0_R4: break; case GGML_TYPE_Q5_0_R4: break; case GGML_TYPE_Q6_0_R4: break; diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index b6d69011..4753f342 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -42,6 +42,7 @@ void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGM void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_bn_ref (const float * GGML_RESTRICT x, block_iq1_bn * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_s_ref (const float * GGML_RESTRICT x, block_iq1_s * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -66,6 +67,7 @@ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -148,6 +150,9 @@ void iq2xs_free_impl(enum ggml_type type); void iq3xs_init_impl(int grid_size); void iq3xs_free_impl(int grid_size); +void iq1s_process_1block(int block_size, const float * xb, const float * weight, int8_t * L, + float * the_scale, uint16_t * the_index, int * the_shift, float * pairs, float * sumx, float * sumw); + #if defined(__ARM_FEATURE_SVE) extern int ggml_sve_cnt_b; #endif diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b3c8a951..64b7d3ce 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1176,13 +1176,26 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq1_s), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_s, - .from_float = NULL, - .from_float_ref = NULL, + .from_float = quantize_row_iq1_s, + .from_float_ref = (ggml_from_float_t)quantize_row_iq1_s_ref, .vec_dot = ggml_vec_dot_iq1_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ1_S_R4] = { + .type_name = "iq1_s_r4", + .blck_size = 32, + .type_size = sizeof(block_iq1_s_r4)/4, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq1_s_r4, + .from_float = quantize_row_iq1_s_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq1_s_r4_ref, + .vec_dot = vec_dot_iq1_s_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_1_X4, + .nrows = 1, + .row_meta_size = 2, + }, [GGML_TYPE_IQ1_M] = { .type_name = "iq1_m", .blck_size = QK_K, @@ -4387,6 +4400,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ3_S_R4: wtype = GGML_TYPE_IQ3_S_R4; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_IQ2_S_R4: wtype = GGML_TYPE_IQ2_S_R4; break; + case GGML_FTYPE_MOSTLY_IQ1_S_R4: wtype = GGML_TYPE_IQ1_S_R4; break; case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break; case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; @@ -10934,6 +10948,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -11402,6 +11417,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -11567,6 +11583,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -14805,6 +14822,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -15210,6 +15228,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -15509,6 +15528,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -16137,6 +16157,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K64: case GGML_TYPE_Q8_K16: @@ -22893,6 +22914,7 @@ void ggml_quantize_init(enum ggml_type type) { case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break; + case GGML_TYPE_IQ1_S_R4:iq2xs_init_impl(GGML_TYPE_IQ1_S); break; case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; case GGML_TYPE_IQ3_S_R4: @@ -22975,6 +22997,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ3_S_R4:result = quantize_iq3_s_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_S_R4:result = quantize_iq2_s_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ1_S_R4:result = quantize_iq1_s_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f633229d..559cff05 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -259,6 +259,7 @@ struct MulMat { case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ3_XXS_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ3_S_R4: return 4; case GGML_TYPE_IQ4_NL_R4: case GGML_TYPE_Q5_0_R4: @@ -293,6 +294,7 @@ struct MulMat { case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ2_BN_R4: return 4; case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_Q4_0_R4: @@ -375,6 +377,523 @@ inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { aux32[0] = a0 & 0x3f3f3f3f; } +#ifdef __AVX2__ +static const uint64_t iq1s_grid_us[2048] = { + 0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200, + 0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000, + 0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101, + 0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101, + 0x0000000002000000, 0x0000000002000002, 0x0000000002000200, 0x0000000002000202, + 0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200, + 0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001, + 0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202, + 0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201, + 0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001, + 0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101, + 0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101, + 0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202, + 0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200, + 0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201, + 0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002, + 0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101, + 0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200, + 0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102, + 0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101, + 0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001, + 0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 0x0000010100000100, + 0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200, + 0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101, + 0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100, + 0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000, + 0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202, + 0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 0x0000010102010200, + 0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101, + 0x0000010102020102, 0x0000010102020201, 0x0000010200010100, 0x0000010200010201, + 0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002, + 0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001, + 0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001, + 0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002, + 0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000, + 0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101, + 0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000, + 0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101, + 0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202, + 0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201, + 0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000, + 0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100, + 0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102, + 0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002, + 0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000, + 0x0000020200020002, 0x0000020200020200, 0x0000020200020202, 0x0000020201000101, + 0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101, + 0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200, + 0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002, + 0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001, + 0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101, + 0x0001000001000001, 0x0001000001000100, 0x0001000001010000, 0x0001000001010101, + 0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101, + 0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102, + 0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100, + 0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002, + 0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100, + 0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000, + 0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101, + 0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101, + 0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001, + 0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102, + 0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201, + 0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202, + 0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001, + 0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001, + 0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101, + 0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102, + 0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 0x0001010001000200, + 0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101, + 0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101, + 0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000, + 0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201, + 0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101, + 0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 0x0001010100000202, + 0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102, + 0x0001010100010201, 0x0001010100020000, 0x0001010100020002, 0x0001010100020101, + 0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100, + 0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002, + 0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201, + 0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101, + 0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002, + 0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202, + 0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101, + 0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000, + 0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100, + 0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102, + 0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102, + 0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101, + 0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101, + 0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001, + 0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201, + 0x0001010201020202, 0x0001010202000102, 0x0001010202000202, 0x0001010202010002, + 0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001, + 0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100, + 0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101, + 0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001, + 0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101, + 0x0001020002010001, 0x0001020002010100, 0x0001020002020101, 0x0001020100010000, + 0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001, + 0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101, + 0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101, + 0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000, + 0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001, + 0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001, + 0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102, + 0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102, + 0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101, + 0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201, + 0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202, + 0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202, + 0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101, + 0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001, + 0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000, + 0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101, + 0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 0x0002000002020200, + 0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100, + 0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100, + 0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202, + 0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102, + 0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201, + 0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 0x0002000200000202, + 0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002, + 0x0002000200020200, 0x0002000200020202, 0x0002000201000101, 0x0002000201010001, + 0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001, + 0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101, + 0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202, + 0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201, + 0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102, + 0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200, + 0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001, + 0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101, + 0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201, + 0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001, + 0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002, + 0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000, + 0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202, + 0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201, + 0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201, + 0x0002010200020101, 0x0002010201000001, 0x0002010201000201, 0x0002010201010101, + 0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100, + 0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000, + 0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101, + 0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202, + 0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101, + 0x0002020002000000, 0x0002020002000002, 0x0002020002000200, 0x0002020002000202, + 0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202, + 0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201, + 0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002, + 0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102, + 0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102, + 0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000, + 0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000, + 0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101, + 0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101, + 0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202, + 0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200, + 0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102, + 0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101, + 0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100, + 0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001, + 0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100, + 0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 0x0100000100010101, + 0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001, + 0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200, + 0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101, + 0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101, + 0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100, + 0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 0x0100000102010101, + 0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101, + 0x0100000200010001, 0x0100000200010100, 0x0100000200010102, 0x0100000200020101, + 0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202, + 0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100, + 0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201, + 0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202, + 0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102, + 0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200, + 0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201, + 0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000, + 0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002, + 0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100, + 0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000, + 0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100, + 0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000, + 0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102, + 0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100, + 0x0100010101020101, 0x0100010101020102, 0x0100010101020201, 0x0100010102000002, + 0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001, + 0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201, + 0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202, + 0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100, + 0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001, + 0x0100010200020100, 0x0100010200020201, 0x0100010201000000, 0x0100010201000002, + 0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001, + 0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201, + 0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001, + 0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101, + 0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101, + 0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101, + 0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101, + 0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102, + 0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100, + 0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001, + 0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000, + 0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001, + 0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101, + 0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100, + 0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000, + 0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202, + 0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 0x0100020102010101, + 0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100, + 0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100, + 0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200, + 0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100, + 0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101, + 0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 0x0101000000010101, + 0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201, + 0x0101000001000000, 0x0101000001000101, 0x0101000001000200, 0x0101000001010001, + 0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201, + 0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201, + 0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001, + 0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200, + 0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100, + 0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201, + 0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200, + 0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101, + 0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001, + 0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102, + 0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001, + 0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201, + 0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100, + 0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000, + 0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102, + 0x0101000200010002, 0x0101000200010101, 0x0101000200010202, 0x0101000200020001, + 0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202, + 0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102, + 0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101, + 0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201, + 0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101, + 0x0101010000010001, 0x0101010000010100, 0x0101010000010101, 0x0101010000010102, + 0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101, + 0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100, + 0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202, + 0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101, + 0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202, + 0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101, + 0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200, + 0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101, + 0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100, + 0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002, + 0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201, + 0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100, + 0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202, + 0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102, + 0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002, + 0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200, + 0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 0x0101010101010002, + 0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200, + 0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001, + 0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200, + 0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100, + 0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000, + 0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 0x0101010102010102, + 0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100, + 0x0101010102020101, 0x0101010102020102, 0x0101010102020201, 0x0101010200000000, + 0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102, + 0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100, + 0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000, + 0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101, + 0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001, + 0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201, + 0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002, + 0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200, + 0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100, + 0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101, + 0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202, + 0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002, + 0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201, + 0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201, + 0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001, + 0x0101020000020202, 0x0101020001000101, 0x0101020001000200, 0x0101020001000202, + 0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102, + 0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002, + 0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201, + 0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200, + 0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002, + 0x0101020100000101, 0x0101020100000202, 0x0101020100010001, 0x0101020100010100, + 0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101, + 0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102, + 0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002, + 0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200, + 0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100, + 0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001, + 0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100, + 0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201, + 0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101, + 0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102, + 0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201, + 0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200, + 0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200, + 0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002, + 0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202, + 0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102, + 0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 0x0102000001010000, + 0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202, + 0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201, + 0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001, + 0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002, + 0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102, + 0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 0x0102000101010001, + 0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101, + 0x0102000101020102, 0x0102000101020202, 0x0102000102000100, 0x0102000102000202, + 0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102, + 0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201, + 0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101, + 0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101, + 0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001, + 0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202, + 0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000, + 0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202, + 0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102, + 0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002, + 0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201, + 0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101, + 0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001, + 0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200, + 0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102, + 0x0102010100010201, 0x0102010101000100, 0x0102010101000101, 0x0102010101000102, + 0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100, + 0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001, + 0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201, + 0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001, + 0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202, + 0x0102010102020002, 0x0102010102020101, 0x0102010102020102, 0x0102010102020200, + 0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000, + 0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000, + 0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001, + 0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200, + 0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200, + 0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202, + 0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201, + 0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202, + 0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001, + 0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001, + 0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200, + 0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000, + 0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102, + 0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101, + 0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100, + 0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000, + 0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 0x0102020102000100, + 0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100, + 0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102, + 0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201, + 0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202, + 0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102, + 0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 0x0102020202020102, + 0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202, + 0x0200000000020000, 0x0200000000020002, 0x0200000000020200, 0x0200000000020202, + 0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100, + 0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000, + 0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101, + 0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202, + 0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102, + 0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100, + 0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101, + 0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100, + 0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201, + 0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101, + 0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202, + 0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200, + 0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201, + 0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200, + 0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002, + 0x0200000202020200, 0x0200000202020202, 0x0200010000010100, 0x0200010000010201, + 0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101, + 0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201, + 0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201, + 0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102, + 0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101, + 0x0200010101000200, 0x0200010101010001, 0x0200010101010100, 0x0200010101010101, + 0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101, + 0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001, + 0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000, + 0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102, + 0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101, + 0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202, + 0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202, + 0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101, + 0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000, + 0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101, + 0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202, + 0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100, + 0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000, + 0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101, + 0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202, + 0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100, + 0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 0x0200020101000100, + 0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002, + 0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100, + 0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101, + 0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202, + 0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200, + 0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 0x0200020201010100, + 0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200, + 0x0200020202000202, 0x0200020202010101, 0x0200020202020000, 0x0200020202020002, + 0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001, + 0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101, + 0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101, + 0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202, + 0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102, + 0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100, + 0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101, + 0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100, + 0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101, + 0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101, + 0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101, + 0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101, + 0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102, + 0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100, + 0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102, + 0x0201000201000201, 0x0201000201010000, 0x0201000201010002, 0x0201000201010101, + 0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101, + 0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001, + 0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101, + 0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202, + 0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102, + 0x0201010001010200, 0x0201010001010201, 0x0201010001020000, 0x0201010001020001, + 0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102, + 0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200, + 0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101, + 0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001, + 0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201, + 0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202, + 0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102, + 0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002, + 0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200, + 0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100, + 0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001, + 0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002, + 0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201, + 0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101, + 0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100, + 0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000, + 0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 0x0201010201000200, + 0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101, + 0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200, + 0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202, + 0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100, + 0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102, + 0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 0x0201020001000102, + 0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102, + 0x0201020001010202, 0x0201020001020100, 0x0201020001020101, 0x0201020002000101, + 0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101, + 0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000, + 0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202, + 0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102, + 0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200, + 0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101, + 0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101, + 0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100, + 0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202, + 0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101, + 0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201, + 0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001, + 0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101, + 0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200, + 0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002, + 0x0202000000020200, 0x0202000000020202, 0x0202000001000101, 0x0202000001010001, + 0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000, + 0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101, + 0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202, + 0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100, + 0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102, + 0x0202000101000201, 0x0202000101010001, 0x0202000101010101, 0x0202000101010200, + 0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101, + 0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201, + 0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000, + 0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202, + 0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201, + 0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200, + 0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002, + 0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101, + 0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100, + 0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001, + 0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201, + 0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000, + 0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102, + 0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001, + 0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201, + 0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100, + 0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 0x0202010102010002, + 0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001, + 0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101, + 0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002, + 0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000, + 0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101, + 0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 0x0202010202020100, + 0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200, + 0x0202020000000202, 0x0202020000020000, 0x0202020000020002, 0x0202020000020200, + 0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102, + 0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200, + 0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002, + 0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100, + 0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001, + 0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001, + 0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102, + 0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202, + 0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202, + 0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000, + 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101, + 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202, +}; +#endif + #ifndef HAVE_FANCY_SIMD const uint64_t keven_signs[128] = { 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, @@ -2745,6 +3264,92 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D } } +template +static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nb = n / 32; + GGML_ASSERT(nb%4 == 0); + __m256i qx[4]; + __m256 acc[nrc_y] = {}; + auto m1 = _mm256_set1_epi16(1); + auto ms = _mm_set1_epi16(-32768); + float d8[8*nrc_y]; + union { __m256i vec; uint16_t val[16]; } helper; + struct aux_iq1_s_r4 { + uint8_t qs[16]; + uint64_t qh; + }; + for (int ix= 0; ix < nrc_x; ix += 4) { + auto dptr = (const ggml_half *)((const char *)vx + ix*bx); + auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)); + auto x = (const aux_iq1_s_r4 *)(dptr + 4); + for (int ib = 0; ib < nb/4; ++ib) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8 + 8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib].d))); + } + for (int k = 0; k < 4; ++k) { + auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh); + auto sas = _mm256_castsi256_si128(idxh); + auto scales4 = _mm_and_si128(_mm_srli_epi16(sas, 12), _mm_set1_epi16(7)); + scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1)); + auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1)); + signs = _mm_add_epi16(_mm_set1_epi16(-8), signs); + auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32( + _mm_mullo_epi16(scales4, signs)))); + auto delta = _mm256_set_m128(delta4, delta4); + scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3 + auto scales = MM256_SET_M128I(scales4, scales4); + auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs)); + idxh = _mm256_sllv_epi64(idxh, _mm256_set_epi64x(0, 2, 5, 8)); + idxh = _mm256_srlv_epi64(idxh, _mm256_set_epi64x(1, 0, 0, 0)); + helper.vec = _mm256_or_si256(idxl, _mm256_and_si256(_mm256_set1_epi16(0x0700), idxh)); + qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]], + iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]); + qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]], + iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]); + qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]], + iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]); + qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]], + iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k); +#ifdef HAVE_FANCY_SIMD + // 0,0, 1,1, 0,0, 1,1 as int32_t + auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), + qx[0], _mm256_shuffle_epi32(y, 0x44)), qx[1], _mm256_shuffle_epi32(y, 0xee)); + // 2,2, 3,3, 2,2, 3,3 as int32_t + auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), + qx[2], _mm256_shuffle_epi32(y, 0x44)), qx[3], _mm256_shuffle_epi32(y, 0xee)); + auto sumi = _mm256_packs_epi32(sumi1, sumi2); +#else + // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1 + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x44)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0xee))); + // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3 + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0x44)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xee))); + // 0,0, 1,1, 0,0, 1,1 as int32_t + sumi1 = _mm256_madd_epi16(m1, sumi1); + // 2,2, 3,3, 2,2, 3,3 as int32_t + sumi2 = _mm256_madd_epi16(m1, sumi2); + // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t + auto sumi = _mm256_packs_epi32(sumi1, sumi2); +#endif + sumi = _mm256_madd_epi16(scales, sumi); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+0]), _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+4]), delta, acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(d1, sumf)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + #ifdef HAVE_FANCY_SIMD template static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -7042,14 +7647,14 @@ struct Q8_0_x4_Unpacker_512 { auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); for (int j = 0; j < 4; ++j) { qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); - qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(0x80)); + qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(-128)); } return _mm256_set_m128(_mm_mul_ps(scales, min), scales); } inline auto set_block(int i) { auto q8 = (const block_q8_0 *)(x + i); qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); - qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(0x80)); + qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(-128)); float d = GGML_FP16_TO_FP32(q8->d); return std::make_pair(d, -128.f*d); } @@ -8202,6 +8807,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_q8_0_r4_q8_1<8>; expected_typeB = GGML_TYPE_Q8_1_X4; break; + case GGML_TYPE_IQ1_S_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_iq1_s_r4_q8_1<1>; + mm.funcs[1] = mul_mat_iq1_s_r4_q8_1<2>; + mm.funcs[2] = mul_mat_iq1_s_r4_q8_1<3>; + mm.funcs[3] = mul_mat_iq1_s_r4_q8_1<4>; + mm.funcs[4] = mul_mat_iq1_s_r4_q8_1<5>; + mm.funcs[5] = mul_mat_iq1_s_r4_q8_1<6>; + mm.funcs[6] = mul_mat_iq1_s_r4_q8_1<7>; + mm.funcs[7] = mul_mat_iq1_s_r4_q8_1<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_iq1_s_r4_q8_1<16>; +#endif + expected_typeB = GGML_TYPE_Q8_1_X4; + break; default: return false; @@ -11078,6 +11698,78 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data } } +template +static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nb = n / 32; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[8]; + int32x4_t acc[nrc_y] = {}; + auto ms = vdup_n_u16(0x8000); + float d8[8*nrc_y]; + for (int ix= 0; ix < nrc_x; ix += 4) { + auto dptr = (const ggml_half *)((const char *)vx + ix*bx); + auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); + auto x = (const block_iq1_s_r4 *)(dptr + 4); + for (int ib = 0; ib < nb/4; ++ib) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = vld1q_f16((const float16_t *)q8.y[iy][ib].d); + vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vget_low_f16(scales))); + vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vget_high_f16(scales))); + } + for (int k = 0; k < 4; ++k) { + auto sas = vld1_u16(x[4*ib+k].qh); + auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7)); + scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1)); + auto signs = vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)); + auto delta4 = vmulq_f32(vdupq_n_f32(IQ1S_DELTA), vcvtq_f32_s32(vmull_s16(signs, scales4))); + qx[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)]}); + qx[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)]}); + qx[2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)]}); + qx[3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)]}); + qx[4] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); + qx[5] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]}); + qx[6] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); + qx[7] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); + auto scales = vmovl_u16(scales4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi1 = vdupq_n_s32(0); + auto sumi2 = vdupq_n_s32(0); + auto y = vld1_s8_x4(q8.y[iy][ib].qs + 32*k); + auto y1 = vcombine_s8(y.val[0], y.val[0]); + auto y2 = vcombine_s8(y.val[1], y.val[1]); + sumi1 = ggml_vdotq_s32(sumi1, qx[0], y1); + sumi2 = ggml_vdotq_s32(sumi2, qx[4], y1); + sumi1 = ggml_vdotq_s32(sumi1, qx[2], y2); + sumi2 = ggml_vdotq_s32(sumi2, qx[6], y2); + y1 = vcombine_s8(y.val[2], y.val[2]); + y2 = vcombine_s8(y.val[3], y.val[3]); + sumi1 = ggml_vdotq_s32(sumi1, qx[1], y1); + sumi2 = ggml_vdotq_s32(sumi2, qx[5], y1); + sumi1 = ggml_vdotq_s32(sumi1, qx[3], y2); + sumi2 = ggml_vdotq_s32(sumi2, qx[7], y2); + auto sumi = vmulq_s32(scales, vpaddq_s32(sumi1, sumi2)); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+0]), vcvtq_f32_s32(sumi)); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+4]), delta4); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(d1, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + template static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -12697,6 +13389,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { m.func16 = mul_mat_iq2_s_r4_q8_k<16>; expected_Btype = GGML_TYPE_Q8_K; break; + case GGML_TYPE_IQ1_S_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1); + m.func16 = mul_mat_iq1_s_r4_q8_1<16>; + expected_Btype = GGML_TYPE_Q8_1_X4; + break; case GGML_TYPE_IQ3_XXS_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k); m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>; @@ -12995,7 +13692,7 @@ struct F16 { using Data = float16x8_t; constexpr static int block_size = 8; //constexpr static int num_registers = 32; - constexpr static int q_step = 8; + //constexpr static int q_step = 8; static inline Data zero() { return vdupq_n_f16(0); } static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); } static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index c1e7771f..a8553b43 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -6087,6 +6087,112 @@ void vec_dot_iq3_s_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +void quantize_row_iq1_s_r4_ref(const float * x, block_iq1_s_r4 * y, int64_t k) { + quantize_iq1_s_r4(x, y, 4, k/4, nullptr); +} + +void quantize_row_iq1_s_r4(const float * x, void * y, int64_t k) { + quantize_iq1_s_r4(x, y, 4, k/4, nullptr); +} + +size_t quantize_iq1_s_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + constexpr int kBlockSize = 32; + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%kBlockSize == 0); + int nblock = n_per_row/kBlockSize; + float weight[kBlockSize]; + int8_t L[kBlockSize]; + float pairs[2*kBlockSize]; + float sumx[kBlockSize+1], sumw[kBlockSize+1]; + float max[4]; + uint16_t index[4]; + int shift; + float invd[4]; + std::vector scales(4*nblock); + auto row_size = ggml_row_size(GGML_TYPE_IQ1_S_R4, n_per_row); + char * cy = (char *)dst; + for (int row = 0; row < nrows; row += 4) { + ggml_half * dptr = (ggml_half *)cy; + auto y = (block_iq1_s_r4 *)(dptr + 4); + for (int k = 0; k < 4; ++k) max[k] = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + if (imatrix) { + for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]; + } + for (int k = 0; k < 4; ++k) { + auto xb = src + k*n_per_row + kBlockSize*ibl; + float sumx2 = 0; + for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; + float sigma2 = 1.5f*sumx2/kBlockSize; + if (imatrix) { + for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]*sqrt(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); + } + iq1s_process_1block(kBlockSize, xb, weight, L, scales.data() + 4*ibl + k, index, &shift, pairs, sumx, sumw); + max[k] = std::max(max[k], scales[4*ibl+k]); + uint16_t h = 0; + for (int i = 0; i < 4; ++i) { + y[ibl].qs[4*i + k] = index[i] & 255; + h |= (index[i] >> 8) << 3*i; + } + if (shift < 0) h |= 0x8000; + y[ibl].qh[k] = h; + } + } + for (int k = 0; k < 4; ++k) { + dptr[k] = GGML_FP32_TO_FP16(1.0625f*max[k]/15);; + invd[k] = max[k] ? 15/max[k] : 0.f; + } + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + int ls = nearest_int(0.5f*(scales[4*ibl+k]*invd[k] - 1)); + ls = std::max(0, std::min(7, ls)); + y[ibl].qh[k] |= (ls << 12); + } + } + cy += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq1_s_r4(const block_iq1_s_r4 * x, float * y, int64_t n) { + auto dptr = (const ggml_half *)x; + x = (const block_iq1_s_r4 *)(dptr + 4); + float d[4]; + for (int k = 0; k < 4; ++k) d[k] = GGML_FP16_TO_FP32(dptr[k]); + int n_per_row = n/4; + GGML_ASSERT(n_per_row%32 == 0); + int nblock = n_per_row/32; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 4; ++k) { + float shift = x[ib].qh[k] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA; + float dl = d[k]*(2*((x[ib].qh[k] >> 12) & 7) + 1); + for (int i = 0; i < 4; ++i) { + auto idx = x[ib].qs[4*i+k] | (((x[ib].qh[k] >> 3*i) & 7) << 8); + auto grid = (const int8_t *)(iq1s_grid + idx); + for (int j = 0; j < 8; ++j) yk[k][32*ib + 8*i + j] = dl*(grid[j] + shift); + } + } + } +} + +void vec_dot_iq1_s_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_S_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + //================================================ namespace { diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 1a991787..9a3c5dc6 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -199,6 +199,12 @@ size_t quantize_iq3_s_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT d void dequantize_row_iq3_s_r4(const block_iq3_s_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq3_s_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq1_s_r4_ref(const float * GGML_RESTRICT x, block_iq1_s_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_s_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq1_s_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq1_s_r4(const block_iq1_s_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq1_s_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_q8_k_r8_ref(const float * GGML_RESTRICT x, block_q8_k_r8 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_k_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/include/llama.h b/include/llama.h index c21671c6..0f6d15ac 100644 --- a/include/llama.h +++ b/include/llama.h @@ -192,6 +192,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 = 219, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 = 220, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 = 223, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ1_S_R4 = 224, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 = 225, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_S_R4 = 226, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_M_R4 = 229, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp index 570c056c..943b945a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3954,6 +3954,7 @@ struct llama_model_loader { case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; + case GGML_TYPE_IQ1_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_S_R4;break; case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; @@ -4688,6 +4689,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: return "IQ3_XXS_R4 - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S_R4: return "IQ1_S_R4 - 1.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL_R4:return "IQ4_NL_R4 - 4.5 bpw"; @@ -15966,7 +15968,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4) { + ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) { new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) { @@ -15987,7 +15990,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) { new_type = GGML_TYPE_Q2_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4) { @@ -16064,6 +16068,41 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_BF16; } } + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) { + if (name.find("attn_v.weight") != std::string::npos) { + if (qs.model.hparams.n_expert >= 4 || qs.model.hparams.n_gqa() >= 4) new_type = GGML_TYPE_IQ4_K_R4; + else if (qs.model.hparams.n_gqa() >= 2) new_type = GGML_TYPE_IQ3_K_R4; + else new_type = GGML_TYPE_Q2_K_R4; + ++qs.i_attention_wv; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k") != std::string::npos) { + new_type = GGML_TYPE_Q4_K_R4; + } + else if (qs.model.hparams.n_expert >= 8 && (name.find("blk.0.ffn_down") != std::string::npos || + name.find("blk.0.ffn_gate") != std::string::npos || + name.find("blk.0.ffn_up") != std::string::npos)) { + new_type = GGML_TYPE_IQ3_K_R4; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q") != std::string::npos) { + new_type = GGML_TYPE_Q4_K_R4; + } + else if (name.find("attn_qkv.weight") != std::string::npos) { + new_type = GGML_TYPE_IQ2_K_R4; + } + else if (name.find("_shexp.weight") != std::string::npos) { + new_type = GGML_TYPE_IQ4_K_R4; + } + else if (name.find("ffn_down") != std::string::npos) { + auto [i_layer, n_layer] = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); + if (qs.params->ffn_down_type < GGML_TYPE_COUNT) new_type = qs.params->ffn_down_type; + else if (i_layer < n_layer/8) { + new_type = GGML_TYPE_Q2_K_R4; + } + ++qs.i_ffn_down; + } + else if (name.find("attn_output.weight") != std::string::npos) { + new_type = qs.model.hparams.n_expert >= 4 ? GGML_TYPE_Q5_K_R4 : GGML_TYPE_IQ2_K_R4; + } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || @@ -16095,6 +16134,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_Q5_K; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) new_type = GGML_TYPE_IQ2_K_R4; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m) new_type = GGML_TYPE_IQ3_S; } } @@ -16539,6 +16579,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: default_type = GGML_TYPE_IQ3_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; + case LLAMA_FTYPE_MOSTLY_IQ1_S_R4:default_type = GGML_TYPE_IQ1_S_R4;break; case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break; case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break; @@ -16892,6 +16933,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ1_S || + new_type == GGML_TYPE_IQ1_S_R4|| (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { LLAMA_LOG_ERROR("\n\n============================================================\n"); @@ -17011,6 +17053,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_S; else chunk_size_multiplier = 4; } + else if (new_type == GGML_TYPE_IQ1_S_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ1_S; + else chunk_size_multiplier = 4; + } else if (new_type == GGML_TYPE_BF16_R16) { if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16; else chunk_size_multiplier = 16; From a6f9f2ec9af92b5a13f035db054aac2fd2efaee7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 5 Feb 2025 14:45:51 +0200 Subject: [PATCH 07/14] iq1_s_r4: slightly faster NEON gemm/gemv (#186) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 393 +++++++++++++++++++++++++++++++---- 1 file changed, 358 insertions(+), 35 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 559cff05..ea8e8274 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -892,6 +892,265 @@ static const uint64_t iq1s_grid_us[2048] = { 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101, 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202, }; +#else +static const uint32_t iq1s_grid_us[2048] = { + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +}; #endif #ifndef HAVE_FANCY_SIMD @@ -11698,15 +11957,78 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data } } +static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<1, block_q8_1_x4> q8(info); + int nb = n / 32; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[8]; + float32x4_t acc[2] = {}; + int32x4_t isum[8]; + auto ms = vdup_n_u16(0x8000); + for (int ix= 0; ix < nrc_x; ix += 4) { + auto dptr = (const ggml_half *)((const char *)vx + ix*bx); + auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); + auto x = (const block_iq1_s_r4 *)(dptr + 4); + for (int ib = 0; ib < nb/4; ++ib) { + auto scale_yd = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+0)); + auto scale_ym = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+4)); + for (int k = 0; k < 4; ++k) { + auto sas = vld1_u16(x[4*ib+k].qh); + auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7)); + scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1)); + auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1))); + isum[k+4] = vmull_s16(signs, scales4); + qx[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)]}); + qx[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)]}); + qx[2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)]}); + qx[3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)]}); + qx[4] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)]}); + qx[5] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)]}); + qx[6] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); + qx[7] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)], + iq1s_grid[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); + auto scales = vmovl_u16(scales4); + auto y = vld1q_s8_x2(q8.y[0][ib].qs + 32*k); + auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]); + auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]); + auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]); + auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]); + sumi1 = vpaddq_s32(sumi1, sumi2); + sumi3 = vpaddq_s32(sumi3, sumi4); + isum[k] = vmulq_s32(scales, vpaddq_s32(sumi1, sumi3)); + } + acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[0]), scale_yd, 0); + acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[1]), scale_yd, 1); + acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[2]), scale_yd, 2); + acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[3]), scale_yd, 3); + acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[4]), scale_ym, 0); + acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[5]), scale_ym, 1); + acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[6]), scale_ym, 2); + acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[7]), scale_ym, 3); + } + info.store(ix, 0, vmulq_f32(d1, vfmaq_f32(acc[0], acc[1], vdupq_n_f32(IQ1S_DELTA)))); + acc[0] = acc[1] = vdupq_n_f32(0.f); + } +} + template static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8 q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); - int8x16_t qx[8]; + uint8x16_t qx[8]; int32x4_t acc[nrc_y] = {}; auto ms = vdup_n_u16(0x8000); + auto mask = vdupq_n_s8(0x03); float d8[8*nrc_y]; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); @@ -11722,42 +12044,42 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI auto sas = vld1_u16(x[4*ib+k].qh); auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7)); scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1)); - auto signs = vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)); - auto delta4 = vmulq_f32(vdupq_n_f32(IQ1S_DELTA), vcvtq_f32_s32(vmull_s16(signs, scales4))); - qx[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)]}); - qx[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)]}); - qx[2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)]}); - qx[3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)]}); - qx[4] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); - qx[5] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]}); - qx[6] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); - qx[7] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], - iq1s_grid[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); + auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1))); + signs = vadd_s16(vdup_n_s16(-8), signs); + auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4))); + qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); + qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); + qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]}); + qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); + qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask); + qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask); + qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask); + qx[7] = vandq_u8(vshrq_n_u8(qx[6], 4), mask); qx[6] = vandq_u8(qx[6], mask); auto scales = vmovl_u16(scales4); for (int iy = 0; iy < nrc_y; ++iy) { - auto sumi1 = vdupq_n_s32(0); - auto sumi2 = vdupq_n_s32(0); - auto y = vld1_s8_x4(q8.y[iy][ib].qs + 32*k); - auto y1 = vcombine_s8(y.val[0], y.val[0]); - auto y2 = vcombine_s8(y.val[1], y.val[1]); - sumi1 = ggml_vdotq_s32(sumi1, qx[0], y1); - sumi2 = ggml_vdotq_s32(sumi2, qx[4], y1); - sumi1 = ggml_vdotq_s32(sumi1, qx[2], y2); - sumi2 = ggml_vdotq_s32(sumi2, qx[6], y2); - y1 = vcombine_s8(y.val[2], y.val[2]); - y2 = vcombine_s8(y.val[3], y.val[3]); - sumi1 = ggml_vdotq_s32(sumi1, qx[1], y1); - sumi2 = ggml_vdotq_s32(sumi2, qx[5], y1); - sumi1 = ggml_vdotq_s32(sumi1, qx[3], y2); - sumi2 = ggml_vdotq_s32(sumi2, qx[7], y2); - auto sumi = vmulq_s32(scales, vpaddq_s32(sumi1, sumi2)); + auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[0]), y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[1]), y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[2]), y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[3]), y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[4]), y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[5]), y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); + sumi = vmulq_s32(scales, sumi); acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+0]), vcvtq_f32_s32(sumi)); acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+4]), delta4); } @@ -13391,6 +13713,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { break; case GGML_TYPE_IQ1_S_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1); + m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1; m.func16 = mul_mat_iq1_s_r4_q8_1<16>; expected_Btype = GGML_TYPE_Q8_1_X4; break; From 7f61b3068e18728e5e7e2b95546ff03dd2fd41ac Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 6 Feb 2025 14:08:52 +0200 Subject: [PATCH 08/14] IQ1_M_R4: better 1.75 bpw quants (#187) * iq1_m_r4: basics (quantize/dequantize) * iq1_m_r4: Zen4 gemm * iq1_m_r4: neon gemm * iq1_m_r4: switch to q8_0_x4 also on AVX2/Zen4 With the deltas being per group of 8, we cannot make use of the q8 sums stored in q8_1, so we get a tiny gain by using q8_0_x4. * iq1_m_r4: rename mul_mat_iq1_m_r4_q8_1 to mul_mat_iq1_m_r4_q8_0 --------- Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 2 + ggml/include/ggml.h | 2 + ggml/src/ggml-common.h | 8 + ggml/src/ggml-quants.c | 403 +++++++++++++++------------------ ggml/src/ggml-quants.h | 4 + ggml/src/ggml.c | 27 ++- ggml/src/iqk/iqk_mul_mat.cpp | 197 ++++++++++++++++ ggml/src/iqk/iqk_quantize.cpp | 117 ++++++++++ ggml/src/iqk/iqk_quantize.h | 6 + include/llama.h | 1 + src/llama.cpp | 16 +- 11 files changed, 553 insertions(+), 230 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1c847e6b..7bdd8597 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -30,6 +30,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M_R4", LLAMA_FTYPE_MOSTLY_IQ2_M_R4, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_S_R4", LLAMA_FTYPE_MOSTLY_IQ1_S_R4, " 1.5 bpw quantization", }, + { "IQ1_M_R4", LLAMA_FTYPE_MOSTLY_IQ1_M_R4, " 1.75 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, @@ -512,6 +513,7 @@ int main(int argc, char ** argv) { params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || + params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4 || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M)) { fprintf(stderr, "\n==========================================================================================================\n"); fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 9668dc32..77ac33a9 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -432,6 +432,7 @@ extern "C" { GGML_TYPE_IQ3_S_R4 = 221, GGML_TYPE_IQ2_S_R4 = 222, GGML_TYPE_IQ4_XS_R4 = 223, + GGML_TYPE_IQ1_M_R4 = 229, GGML_TYPE_BF16_R16 = 230, GGML_TYPE_Q6_0_R4 = 233, GGML_TYPE_IQ2_BN_R4 = 335, @@ -516,6 +517,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ3_S_R4 = 220, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_S_R4 = 221, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ1_M_R4 = 223, // except 1d tensors GGML_FTYPE_MOSTLY_BF16_R16 = 224, // except 1d tensors GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_BN_R4 = 329, // except 1d tensors diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 14813161..679353be 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -499,6 +499,14 @@ typedef struct { } block_iq1_m; static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); +// 1.75 bpw - blocks of 32 with 4 interleaved rows = 128 quants +typedef struct { + uint8_t qs[16]; // grid index, low 8 bits + uint8_t qh[ 8]; // grid index, high 3 bits + grid shift bits (for two groups of 8) + uint8_t scales[4]; // 4-bit block scales +} block_iq1_m_r4; +static_assert(sizeof(block_iq1_m_r4) == 28, "wrong iq1_m_r4 block size/padding"); + // // Bitnet and TriLM - implemented as 1.625 bpw // diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 3c4711f3..d32a583f 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14145,85 +14145,6 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy int best_shift; iq1s_process_1block(block_size, xb, weight, L, &scales[ib], index, &best_shift, pairs, sumx, sumw); -// float max = fabsf(xb[0]); -// for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); -// if (max < GROUP_MAX_EPS_IQ1_S) { -// scales[ib] = 0; -// memset(L, 1, block_size); -// continue; -// } -// // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. -// // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two -// // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights -// // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and -// // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale -// // for each possible and score for each split. -// for (int j = 0; j < block_size; ++j) { -// pairs[2*j] = xb[j]; -// idx[2*j] = j; -// } -// qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); -// { -// sumx[0] = sumw[0] = 0; -// for (int j = 0; j < block_size; ++j) { -// int i = idx[2*j]; -// sumx[j+1] = sumx[j] + weight[i]*xb[i]; -// sumw[j+1] = sumw[j] + weight[i]; -// } -// } -// float best_score = -FLT_MIN, scale = max; -// int besti1 = -1, besti2 = -1, best_shift = 0; -// for (int i1 = 0; i1 <= block_size; ++i1) { -// for (int i2 = i1; i2 <= block_size; ++i2) { -// float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2]; -// float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2]; -// if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { -// scale = sumqx/sumq2; best_score = scale*sumqx; -// besti1 = i1; besti2 = i2; best_shift = 1; -// } -// sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2]; -// sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2]; -// if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { -// scale = sumqx/sumq2; best_score = scale*sumqx; -// besti1 = i1; besti2 = i2; best_shift = -1; -// } -// } -// } -// GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); -// for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; -// for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; -// for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; -// if (scale < 0) { -// for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; -// scale = -scale; best_shift = -best_shift; -// } -// bool all_on_grid = true; -// const float * xx = best_shift == 1 ? x_p : x_m; -// for (int k = 0; k < block_size/8; ++k) { -// uint16_t u = 0; -// for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); -// int grid_index = kmap_q2xs[u]; -// if (grid_index < 0) { -// all_on_grid = false; -// const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; -// grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); -// GGML_ASSERT(grid_index >= 0); -// } -// index[k] = grid_index; -// } -// if (!all_on_grid) { -// float sumqx = 0, sumq2 = 0; -// for (int k = 0; k < block_size/8; ++k) { -// const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); -// for (int j = 0; j < 8; ++j) { -// float w = weight[8*k + j]; -// float q = xx[(pg[j] - 1)/2]; -// sumqx += w*q*xb[8*k+j]; -// sumq2 += w*q*q; -// } -// } -// if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; -// } uint16_t h = 0; for (int k = 0; k < block_size/8; ++k) { y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255; @@ -14232,10 +14153,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy y[ibl].qh[ib] = h; GGML_ASSERT(scales[ib] >= 0); max_scale = MAX(max_scale, scales[ib]); - //GGML_ASSERT(scale >= 0); - //scales[ib] = scale; shifts[ib] = best_shift; - //max_scale = MAX(max_scale, scale); } if (!max_scale) { @@ -14287,6 +14205,166 @@ void quantize_row_iq1_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, quantize_row_iq1_s_ref(x, (block_iq1_s *)y, k); } +void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, float * the_scale, uint16_t * the_index, int * the_shift, + float * pairs) { + + const int block_size = IQ1M_BLOCK_SIZE; + + const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA}; + const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; + + float sumqx[4], sumq2[4]; + + const int gindex = iq2_data_index(GGML_TYPE_IQ1_M); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + int * idx = (int *)(pairs + 1); + for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + float best_score = -FLT_MIN, scale = 0.f; + int besti1 = -1, besti2 = -1, best_k = -1; + // 0: +, + + // 1: +, - + // 2: -, + + // 3: -, - + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + memset(sumqx, 0, 4*sizeof(float)); + memset(sumq2, 0, 4*sizeof(float)); + for (int j = 0; j < i1; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } else { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } + } + for (int j = i1; j < i2; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } else { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } + } + for (int j = i2; j < block_size; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } else { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } + } + for (int k = 0; k < 4; ++k) { + if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { + scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; + besti1 = i1; besti2 = i2; best_k = k; + } + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; + best_k = 3 - best_k; + } + bool all_on_grid = true; + const float * xx; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? x_p : x_m; + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + the_index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx_f = 0, sumq2_f = 0; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? x_p : x_m; + const int8_t * pg = (const int8_t *)(kgrid_q2xs + the_index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx_f += w*q*xb[8*k+j]; + sumq2_f += w*q*q; + } + } + if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f; + } + *the_scale = scale; + *the_shift = best_k; +} + static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, float * scales, float * weight, @@ -14301,7 +14379,6 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy const int * kmap_q2xs = iq2_data[gindex].map; const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - //GGML_ASSERT(quant_weights && "missing quantization weights"); GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); @@ -14317,10 +14394,6 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88}; - int * idx = (int *)(pairs + 1); - - float sumqx[4], sumq2[4]; - iq1m_scale_t s; const float * xx; @@ -14351,147 +14424,15 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy memset(L, 1, block_size); continue; } - // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. - // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two - // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights - // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and - // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale - // for each possible and score for each split. - for (int j = 0; j < block_size; ++j) { - pairs[2*j] = xb[j]; - idx[2*j] = j; - } - qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - float best_score = -FLT_MIN, scale = max; - int besti1 = -1, besti2 = -1, best_k = -1; - // 0: +, + - // 1: +, - - // 2: -, + - // 3: -, - - for (int i1 = 0; i1 <= block_size; ++i1) { - for (int i2 = i1; i2 <= block_size; ++i2) { - memset(sumqx, 0, 4*sizeof(float)); - memset(sumq2, 0, 4*sizeof(float)); - for (int j = 0; j < i1; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[0]*xb[i]; - sumqx[1] += weight[i]*x_p[0]*xb[i]; - sumqx[2] += weight[i]*x_m[0]*xb[i]; - sumqx[3] += weight[i]*x_m[0]*xb[i]; - sumq2[0] += weight[i]*x_p[0]*x_p[0]; - sumq2[1] += weight[i]*x_p[0]*x_p[0]; - sumq2[2] += weight[i]*x_m[0]*x_m[0]; - sumq2[3] += weight[i]*x_m[0]*x_m[0]; - } else { - sumqx[0] += weight[i]*x_p[0]*xb[i]; - sumqx[2] += weight[i]*x_p[0]*xb[i]; - sumqx[1] += weight[i]*x_m[0]*xb[i]; - sumqx[3] += weight[i]*x_m[0]*xb[i]; - sumq2[0] += weight[i]*x_p[0]*x_p[0]; - sumq2[2] += weight[i]*x_p[0]*x_p[0]; - sumq2[1] += weight[i]*x_m[0]*x_m[0]; - sumq2[3] += weight[i]*x_m[0]*x_m[0]; - } - } - for (int j = i1; j < i2; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[1]*xb[i]; - sumqx[1] += weight[i]*x_p[1]*xb[i]; - sumqx[2] += weight[i]*x_m[1]*xb[i]; - sumqx[3] += weight[i]*x_m[1]*xb[i]; - sumq2[0] += weight[i]*x_p[1]*x_p[1]; - sumq2[1] += weight[i]*x_p[1]*x_p[1]; - sumq2[2] += weight[i]*x_m[1]*x_m[1]; - sumq2[3] += weight[i]*x_m[1]*x_m[1]; - } else { - sumqx[0] += weight[i]*x_p[1]*xb[i]; - sumqx[2] += weight[i]*x_p[1]*xb[i]; - sumqx[1] += weight[i]*x_m[1]*xb[i]; - sumqx[3] += weight[i]*x_m[1]*xb[i]; - sumq2[0] += weight[i]*x_p[1]*x_p[1]; - sumq2[2] += weight[i]*x_p[1]*x_p[1]; - sumq2[1] += weight[i]*x_m[1]*x_m[1]; - sumq2[3] += weight[i]*x_m[1]*x_m[1]; - } - } - for (int j = i2; j < block_size; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[2]*xb[i]; - sumqx[1] += weight[i]*x_p[2]*xb[i]; - sumqx[2] += weight[i]*x_m[2]*xb[i]; - sumqx[3] += weight[i]*x_m[2]*xb[i]; - sumq2[0] += weight[i]*x_p[2]*x_p[2]; - sumq2[1] += weight[i]*x_p[2]*x_p[2]; - sumq2[2] += weight[i]*x_m[2]*x_m[2]; - sumq2[3] += weight[i]*x_m[2]*x_m[2]; - } else { - sumqx[0] += weight[i]*x_p[2]*xb[i]; - sumqx[2] += weight[i]*x_p[2]*xb[i]; - sumqx[1] += weight[i]*x_m[2]*xb[i]; - sumqx[3] += weight[i]*x_m[2]*xb[i]; - sumq2[0] += weight[i]*x_p[2]*x_p[2]; - sumq2[2] += weight[i]*x_p[2]*x_p[2]; - sumq2[1] += weight[i]*x_m[2]*x_m[2]; - sumq2[3] += weight[i]*x_m[2]*x_m[2]; - } - } - for (int k = 0; k < 4; ++k) { - if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { - scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; - besti1 = i1; besti2 = i2; best_k = k; - } - } - } - } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); - for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; - for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; - for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; - if (scale < 0) { - for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; - scale = -scale; - best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0; - } - bool all_on_grid = true; - for (int k = 0; k < block_size/8; ++k) { - if (k == 0) xx = best_k < 2 ? x_p : x_m; - else xx = best_k%2 == 0 ? x_p : x_m; - uint16_t u = 0; - for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); - int grid_index = kmap_q2xs[u]; - if (grid_index < 0) { - all_on_grid = false; - const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; - grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); - GGML_ASSERT(grid_index >= 0); - } - index[k] = grid_index; - } - if (!all_on_grid) { - float sumqx_f = 0, sumq2_f = 0; - for (int k = 0; k < block_size/8; ++k) { - if (k == 0) xx = best_k < 2 ? x_p : x_m; - else xx = best_k%2 == 0 ? x_p : x_m; - const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); - for (int j = 0; j < 8; ++j) { - float w = weight[8*k + j]; - float q = xx[(pg[j] - 1)/2]; - sumqx_f += w*q*xb[8*k+j]; - sumq2_f += w*q*q; - } - } - if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f; - } + + int best_k = -1; + iq1m_process_1block(xb, weight, L, &scales[ib], index, &best_k, pairs); + y[ibl].qs[2*ib + 0] = index[0] & 255; y[ibl].qs[2*ib + 1] = index[1] & 255; y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4); - GGML_ASSERT(scale >= 0); - scales[ib] = scale; shifts[ib] = best_k; - max_scale = MAX(max_scale, scale); + max_scale = MAX(max_scale, scales[ib]); } if (!max_scale) { @@ -14553,6 +14494,19 @@ size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq1_m); } +void quantize_row_iq1_m_ref (const float * GGML_RESTRICT x, block_iq1_m * GGML_RESTRICT y, int64_t k) { + int nblock = k/QK_K; + float qw[QK_K]; + for (int j = 0; j < QK_K; ++j) qw[j] = 1; + for (int ibl = 0; ibl < nblock; ++ibl) { + quantize_iq1_m(x + ibl*QK_K, &y[ibl], 1, QK_K, qw); + } +} + +void quantize_row_iq1_m (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_iq1_m_ref(x, (block_iq1_m *)y, k); +} + // ============================ 4-bit non-linear quants static const int8_t iq4nl_index[241] = { @@ -15246,6 +15200,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ3_S_R4: break; case GGML_TYPE_IQ2_S_R4: break; case GGML_TYPE_IQ1_S_R4: break; + case GGML_TYPE_IQ1_M_R4: break; case GGML_TYPE_Q4_0_R4: break; case GGML_TYPE_Q5_0_R4: break; case GGML_TYPE_Q6_0_R4: break; diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 4753f342..7c8e2110 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -43,6 +43,7 @@ void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGM void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_bn_ref (const float * GGML_RESTRICT x, block_iq1_bn * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_s_ref (const float * GGML_RESTRICT x, block_iq1_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_m_ref (const float * GGML_RESTRICT x, block_iq1_m * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -68,6 +69,7 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_m (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -152,6 +154,8 @@ void iq3xs_free_impl(int grid_size); void iq1s_process_1block(int block_size, const float * xb, const float * weight, int8_t * L, float * the_scale, uint16_t * the_index, int * the_shift, float * pairs, float * sumx, float * sumw); +void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, + float * the_scale, uint16_t * the_index, int * the_shift, float * pairs); #if defined(__ARM_FEATURE_SVE) extern int ggml_sve_cnt_b; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 64b7d3ce..4199a282 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1202,13 +1202,26 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq1_m), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_m, - .from_float = NULL, - .from_float_ref = NULL, + .from_float = quantize_row_iq1_m, + .from_float_ref = (ggml_from_float_t)quantize_row_iq1_m_ref, .vec_dot = ggml_vec_dot_iq1_m_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ1_M_R4] = { + .type_name = "iq1_m_r4", + .blck_size = 32, + .type_size = sizeof(block_iq1_m_r4)/4, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq1_m_r4, + .from_float = quantize_row_iq1_m_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq1_m_r4_ref, + .vec_dot = vec_dot_iq1_m_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_0_X4, + .nrows = 1, + .row_meta_size = 2, + }, [GGML_TYPE_IQ1_BN] = { .type_name = "iq1_bn", .blck_size = QK_IQ1BN, @@ -4401,6 +4414,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_IQ2_S_R4: wtype = GGML_TYPE_IQ2_S_R4; break; case GGML_FTYPE_MOSTLY_IQ1_S_R4: wtype = GGML_TYPE_IQ1_S_R4; break; + case GGML_FTYPE_MOSTLY_IQ1_M_R4: wtype = GGML_TYPE_IQ1_M_R4; break; case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break; case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; @@ -10949,6 +10963,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -11418,6 +11433,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -11584,6 +11600,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -14823,6 +14840,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -15229,6 +15247,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -15529,6 +15548,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -16158,6 +16178,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K64: case GGML_TYPE_Q8_K16: @@ -22914,6 +22935,7 @@ void ggml_quantize_init(enum ggml_type type) { case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break; + case GGML_TYPE_IQ1_M_R4:iq2xs_init_impl(GGML_TYPE_IQ1_M); break; case GGML_TYPE_IQ1_S_R4:iq2xs_init_impl(GGML_TYPE_IQ1_S); break; case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; @@ -22998,6 +23020,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_S_R4:result = quantize_iq2_s_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_S_R4:result = quantize_iq1_s_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ1_M_R4:result = quantize_iq1_m_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ea8e8274..57024602 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -260,6 +260,7 @@ struct MulMat { case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_IQ3_S_R4: return 4; case GGML_TYPE_IQ4_NL_R4: case GGML_TYPE_Q5_0_R4: @@ -295,6 +296,7 @@ struct MulMat { case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_IQ2_BN_R4: return 4; case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_Q4_0_R4: @@ -3609,6 +3611,102 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +template +static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nb = n / 32; + GGML_ASSERT(nb%4 == 0); + auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000); + auto step = _mm256_set1_epi8(2); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + __m256i qx[4]; + __m256 acc[nrc_y] = {}; + auto ms = _mm_set1_epi8(0x08); + float d8[4*nrc_y]; + union { __m256i vec; uint16_t val[16]; } helper; + for (int ix= 0; ix < nrc_x; ix += 4) { + auto dptr = (const ggml_half *)((const char *)vx + ix*bx); + auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr))); + auto x = (const block_iq1_m_r4 *)(dptr + 4); + for (int ib = 0; ib < nb/4; ++ib) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm_storeu_ps(d8 + 4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].d))); + } + for (int k = 0; k < 4; ++k) { + auto qh = (const uint32_t *)x[4*ib+k].qh; + auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]); + auto scales4 = _mm_set1_epi32(((const uint32_t *)x[4*ib+k].scales)[0]); + scales4 = _mm_and_si128(_mm_srlv_epi32(scales4, _mm_set_epi32(4, 0, 4, 0)), _mm_set1_epi8(0xf)); + scales4 = _mm_cvtepu8_epi16(scales4); + auto scales = MM256_SET_M128I(_mm_unpackhi_epi16(scales4, scales4), _mm_unpacklo_epi16(scales4, scales4)); + + auto signs128 = _mm_or_si128(_mm_cmpeq_epi8(_mm_and_si128(idxh, ms), ms), _mm_set1_epi8(1)); + signs128 = _mm_add_epi8(_mm_set1_epi8(-8), signs128); + auto signs = MM256_SET_M128I(signs128, signs128); + auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs)); + idxh = _mm_and_si128(idxh, _mm_set1_epi8(0x07)); + helper.vec = _mm256_or_si256(idxl, _mm256_slli_epi16(_mm256_cvtepu8_epi16(idxh), 8)); + qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]], + iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]); + qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]], + iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]); + qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]], + iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]); + qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]], + iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]); + qx[0] = _mm256_add_epi8(_mm256_slli_epi16(qx[0], 3), _mm256_shuffle_epi8(signs, shuffle0)); + auto shuffle = _mm256_add_epi8(shuffle0, step); + qx[2] = _mm256_add_epi8(_mm256_slli_epi16(qx[2], 3), _mm256_shuffle_epi8(signs, shuffle)); + shuffle = _mm256_add_epi8(shuffle, step); + qx[1] = _mm256_add_epi8(_mm256_slli_epi16(qx[1], 3), _mm256_shuffle_epi8(signs, shuffle)); + shuffle = _mm256_add_epi8(shuffle, step); + qx[3] = _mm256_add_epi8(_mm256_slli_epi16(qx[3], 3), _mm256_shuffle_epi8(signs, shuffle)); + auto s0 = _mm256_sign_epi8(qx[0], qx[0]); + auto s1 = _mm256_sign_epi8(qx[1], qx[1]); + auto s2 = _mm256_sign_epi8(qx[2], qx[2]); + auto s3 = _mm256_sign_epi8(qx[3], qx[3]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k); + auto y1 = _mm256_shuffle_epi32(y, 0x44); + auto y2 = _mm256_shuffle_epi32(y, 0xee); +#ifdef HAVE_FANCY_SIMD + // 0,0, 1,1, 0,0, 1,1 as int32_t + auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), + s0, _mm256_sign_epi8(y1, qx[0])), s1, _mm256_sign_epi8(y2, qx[1])); + // 2,2, 3,3, 2,2, 3,3 as int32_t + auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), + s2, _mm256_sign_epi8(y1, qx[2])), s3, _mm256_sign_epi8(y2, qx[3])); + auto sumi = _mm256_packs_epi32(sumi1, sumi2); +#else + // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1 + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(s0, _mm256_sign_epi8(y1, qx[0])), + _mm256_maddubs_epi16(s1, _mm256_sign_epi8(y2, qx[1]))); + // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3 + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(s2, _mm256_sign_epi8(y1, qx[2])), + _mm256_maddubs_epi16(s3, _mm256_sign_epi8(y2, qx[3]))); + // 0,0, 1,1, 0,0, 1,1 as int32_t + sumi1 = _mm256_madd_epi16(m1, sumi1); + // 2,2, 3,3, 2,2, 3,3 as int32_t + sumi2 = _mm256_madd_epi16(m1, sumi2); + // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t + auto sumi = _mm256_packs_epi32(sumi1, sumi2); +#endif + sumi = _mm256_madd_epi16(scales, sumi); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(d1, sumf)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + #ifdef HAVE_FANCY_SIMD template static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -9081,6 +9179,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #endif expected_typeB = GGML_TYPE_Q8_1_X4; break; + case GGML_TYPE_IQ1_M_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_iq1_m_r4_q8_0<1>; + mm.funcs[1] = mul_mat_iq1_m_r4_q8_0<2>; + mm.funcs[2] = mul_mat_iq1_m_r4_q8_0<3>; + mm.funcs[3] = mul_mat_iq1_m_r4_q8_0<4>; + mm.funcs[4] = mul_mat_iq1_m_r4_q8_0<5>; + mm.funcs[5] = mul_mat_iq1_m_r4_q8_0<6>; + mm.funcs[6] = mul_mat_iq1_m_r4_q8_0<7>; + mm.funcs[7] = mul_mat_iq1_m_r4_q8_0<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_iq1_m_r4_q8_0<16>; +#endif + expected_typeB = GGML_TYPE_Q8_0_X4; + break; default: return false; @@ -12092,6 +12205,85 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +template +static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nb = n / 32; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[8]; + int32x4_t acc[nrc_y] = {}; + auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303}; + auto step = vdupq_n_u8(4); + auto ms = vdupq_n_u8(0x08); + auto mask = vdupq_n_s8(0x18); + float d8[4*nrc_y]; + for (int ix= 0; ix < nrc_x; ix += 4) { + auto dptr = (const ggml_half *)((const char *)vx + ix*bx); + auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr))); + auto x = (const block_iq1_m_r4 *)(dptr + 4); + for (int ib = 0; ib < nb/4; ++ib) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = vld1_f16((const float16_t *)q8.y[iy][ib].d); + vst1q_f32(d8+4*iy, vcvt_f32_f16(scales)); + } + for (int k = 0; k < 4; ++k) { + auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]); + scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf)); + auto scales16 = vmovl_u8(scales4); + auto scales1 = vmovl_u16(vget_low_u16(scales16)); + auto scales2 = vmovl_u16(vget_high_u16(scales16)); + auto qh = (const uint32_t *)x[4*ib+k].qh; + auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4}; + auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1))); + signs = vaddq_s8(signs, vdupq_n_s8(-8)); + qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); + qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]}); + qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]}); + qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]}); + auto shuffle = shuffle0; + for (int j = 0; j < 4; ++j) { + auto s = vqtbl1q_s8(signs, shuffle); + qx[2*j+1] = vaddq_s8(s, vandq_s8(vshrq_n_s8(qx[2*j+0], 1), mask)); + qx[2*j+0] = vaddq_s8(s, vandq_s8(vshlq_n_s8(qx[2*j+0], 3), mask)); + shuffle = vaddq_u8(shuffle, step); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k); + auto sumi1 = vdupq_n_s32(0); + auto sumi2 = vdupq_n_s32(0); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[0]), y.val[0], 0); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[1]), y.val[0], 1); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[2]), y.val[0], 2); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[3]), y.val[0], 3); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[4]), y.val[1], 0); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); + auto sumi = vmlaq_s32(vmlaq_s32(vdupq_n_s32(0), sumi1, scales1), sumi2, scales2); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), vcvtq_f32_s32(sumi)); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(d1, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + template static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -13717,6 +13909,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { m.func16 = mul_mat_iq1_s_r4_q8_1<16>; expected_Btype = GGML_TYPE_Q8_1_X4; break; + case GGML_TYPE_IQ1_M_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0); + m.func16 = mul_mat_iq1_m_r4_q8_0<16>; + expected_Btype = GGML_TYPE_Q8_0_X4; + break; case GGML_TYPE_IQ3_XXS_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k); m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index a8553b43..e741a8ea 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -6193,6 +6193,123 @@ void vec_dot_iq1_s_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +void quantize_row_iq1_m_r4_ref(const float * x, block_iq1_m_r4 * y, int64_t k) { + quantize_iq1_m_r4(x, y, 4, k/4, nullptr); +} + +void quantize_row_iq1_m_r4(const float * x, void * y, int64_t k) { + quantize_iq1_m_r4(x, y, 4, k/4, nullptr); +} + +size_t quantize_iq1_m_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + constexpr int kBlockSize = 32; + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%kBlockSize == 0); + int nblock = n_per_row/kBlockSize; + float weight[kBlockSize]; + int8_t L[kBlockSize]; + float pairs[2*kBlockSize]; + float max[4]; + uint16_t index[4]; + int shift1, shift2; + float invd[4]; + const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88}; + std::vector scales(8*nblock); + auto row_size = ggml_row_size(GGML_TYPE_IQ1_M_R4, n_per_row); + char * cy = (char *)dst; + for (int row = 0; row < nrows; row += 4) { + ggml_half * dptr = (ggml_half *)cy; + auto y = (block_iq1_m_r4 *)(dptr + 4); + for (int k = 0; k < 4; ++k) max[k] = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + auto xb = src + k*n_per_row + kBlockSize*ibl; + float sumx2 = 0; + for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; + if (!sumx2) { + scales[8*ibl+2*k+0] = scales[8*ibl+2*k+1] = 0; + continue; + } + float sigma2 = 1.5f*sumx2/kBlockSize; + if (imatrix) { + for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]*sqrt(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); + } + iq1m_process_1block(xb+ 0, weight+ 0, L, scales.data() + 8*ibl + 2*k+0, index+0, &shift1, pairs); + iq1m_process_1block(xb+16, weight+16, L, scales.data() + 8*ibl + 2*k+1, index+2, &shift2, pairs); + max[k] = std::max(max[k], std::max(scales[8*ibl+2*k+0], scales[8*ibl+2*k+1])); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[4*i + k] = index[i] & 255; + } + for (int i = 0; i < 2; ++i) { + y[ibl].qh[4*i+k] = (index[2*i+0] >> 8) | ((index[2*i+1] >> 8) << 4); + } + y[ibl].qh[0+k] |= masks[shift1]; + y[ibl].qh[4+k] |= masks[shift2]; + } + } + for (int k = 0; k < 4; ++k) { + dptr[k] = GGML_FP32_TO_FP16(1.0625f*max[k]/15);; + invd[k] = max[k] ? 15/max[k] : 0.f; + } + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + int ls1 = nearest_int(scales[8*ibl+2*k+0]*invd[k]); + int ls2 = nearest_int(scales[8*ibl+2*k+1]*invd[k]); + ls1 = std::max(0, std::min(15, ls1)); + ls2 = std::max(0, std::min(15, ls2)); + y[ibl].scales[k] = ls1 | (ls2 << 4); + } + } + cy += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq1_m_r4(const block_iq1_m_r4 * x, float * y, int64_t n) { + auto dptr = (const ggml_half *)x; + x = (const block_iq1_m_r4 *)(dptr + 4); + float d[4]; + for (int k = 0; k < 4; ++k) d[k] = GGML_FP16_TO_FP32(dptr[k]); + int n_per_row = n/4; + GGML_ASSERT(n_per_row%32 == 0); + int nblock = n_per_row/32; + float dl[2]; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 4; ++k) { + dl[0] = d[k]*(x[ib].scales[k] & 0xf); + dl[1] = d[k]*(x[ib].scales[k] >> 4); + for (int i = 0; i < 2; ++i) { + auto idx1 = x[ib].qs[8*i+k+0] | ((x[ib].qh[4*i+k] & 0x07) << 8); + auto idx2 = x[ib].qs[8*i+k+4] | ((x[ib].qh[4*i+k] & 0x70) << 4); + auto grid1 = (const int8_t *)(iq1s_grid + idx1); + auto grid2 = (const int8_t *)(iq1s_grid + idx2); + auto delta1 = x[ib].qh[4*i+k] & 0x08 ? -IQ1M_DELTA : IQ1M_DELTA; + auto delta2 = x[ib].qh[4*i+k] & 0x80 ? -IQ1M_DELTA : IQ1M_DELTA; + for (int j = 0; j < 8; ++j) yk[k][32*ib + 16*i + j + 0] = dl[i]*(grid1[j] + delta1); + for (int j = 0; j < 8; ++j) yk[k][32*ib + 16*i + j + 8] = dl[i]*(grid2[j] + delta2); + } + } + } +} + +void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_M_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + //================================================ namespace { diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 9a3c5dc6..0dbb88bd 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -205,6 +205,12 @@ size_t quantize_iq1_s_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT d void dequantize_row_iq1_s_r4(const block_iq1_s_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq1_s_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq1_m_r4_ref(const float * GGML_RESTRICT x, block_iq1_m_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_m_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq1_m_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq1_m_r4(const block_iq1_m_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq1_m_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_q8_k_r8_ref(const float * GGML_RESTRICT x, block_q8_k_r8 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_k_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/include/llama.h b/include/llama.h index 0f6d15ac..3f25b296 100644 --- a/include/llama.h +++ b/include/llama.h @@ -197,6 +197,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ3_S_R4 = 226, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_M_R4 = 229, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ1_M_R4 = 231, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 335, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16_R16 = 232, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_BN_R4 = 337, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp index 943b945a..117f59be 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3955,6 +3955,7 @@ struct llama_model_loader { case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ1_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_S_R4;break; + case GGML_TYPE_IQ1_M_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_M_R4;break; case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; @@ -4690,6 +4691,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: return "IQ3_XXS_R4 - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S_R4: return "IQ1_S_R4 - 1.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M_R4: return "IQ1_M_R4 - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL_R4:return "IQ4_NL_R4 - 4.5 bpw"; @@ -15969,7 +15971,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4) { new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) { @@ -15991,7 +15994,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) { + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4) { new_type = GGML_TYPE_Q2_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4) { @@ -16068,7 +16071,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_BF16; } } - } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) { + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4) { if (name.find("attn_v.weight") != std::string::npos) { if (qs.model.hparams.n_expert >= 4 || qs.model.hparams.n_gqa() >= 4) new_type = GGML_TYPE_IQ4_K_R4; else if (qs.model.hparams.n_gqa() >= 2) new_type = GGML_TYPE_IQ3_K_R4; @@ -16134,7 +16137,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_Q5_K; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4) new_type = GGML_TYPE_IQ2_K_R4; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m) new_type = GGML_TYPE_IQ3_S; } } @@ -16580,6 +16582,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: default_type = GGML_TYPE_IQ3_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ1_S_R4:default_type = GGML_TYPE_IQ1_S_R4;break; + case LLAMA_FTYPE_MOSTLY_IQ1_M_R4:default_type = GGML_TYPE_IQ1_M_R4;break; case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break; case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break; @@ -16934,6 +16937,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ1_S_R4|| + new_type == GGML_TYPE_IQ1_M_R4|| (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { LLAMA_LOG_ERROR("\n\n============================================================\n"); @@ -17057,6 +17061,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ1_S; else chunk_size_multiplier = 4; } + else if (new_type == GGML_TYPE_IQ1_M_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ1_M; + else chunk_size_multiplier = 4; + } else if (new_type == GGML_TYPE_BF16_R16) { if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16; else chunk_size_multiplier = 16; From a08501ee5216402458d3d3e9b9af5763705eaffe Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 6 Feb 2025 18:45:28 +0200 Subject: [PATCH 09/14] Rename q4_0_r4, q8_0_r4 and iq4_xs_r4 to _r8 (#189) * Rename q4_0_r4 to q4_0_r8 to reflect actual row interleaving * Rename q8_0_r4 to q8_0_r8 to reflect actual row interleaving * Rename iq4_xs_r4 to iq4_xs_r8 to reflect actual row interleaving --------- Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 6 +- ggml/include/ggml.h | 12 ++-- ggml/src/ggml-common.h | 4 +- ggml/src/ggml-quants.c | 6 +- ggml/src/ggml.c | 90 ++++++++++++------------- ggml/src/iqk/iqk_mul_mat.cpp | 120 ++++++++++++++++----------------- ggml/src/iqk/iqk_quantize.cpp | 72 ++++++++++---------- ggml/src/iqk/iqk_quantize.h | 30 ++++----- include/llama.h | 6 +- src/llama.cpp | 46 ++++++------- 10 files changed, 196 insertions(+), 196 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7bdd8597..7ceee208 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -51,11 +51,11 @@ static const std::vector QUANT_OPTIONS = { { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_NL_R4",LLAMA_FTYPE_MOSTLY_IQ4_NL_R4," 4.50 bpw non-linear quantization", }, - { "IQ4_XS_R4",LLAMA_FTYPE_MOSTLY_IQ4_XS_R4," 4.25 bpw non-linear quantization", }, - { "Q4_0_R4", LLAMA_FTYPE_MOSTLY_Q4_0_R4, " 4.50 bpw quantization", }, + { "IQ4_XS_R8",LLAMA_FTYPE_MOSTLY_IQ4_XS_R8," 4.25 bpw non-linear quantization", }, + { "Q4_0_R8", LLAMA_FTYPE_MOSTLY_Q4_0_R8, " 4.50 bpw quantization", }, { "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", }, { "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", }, - { "Q8_0_R4", LLAMA_FTYPE_MOSTLY_Q8_0_R4, " 8.50 bpw quantization", }, + { "Q8_0_R8", LLAMA_FTYPE_MOSTLY_Q8_0_R8, " 8.50 bpw quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 77ac33a9..b6bebd60 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -416,9 +416,9 @@ extern "C" { GGML_TYPE_Q8_K32 = 148, GGML_TYPE_Q8_KR8 = 149, - GGML_TYPE_Q4_0_R4 = 202, + GGML_TYPE_Q4_0_R8 = 202, GGML_TYPE_Q5_0_R4 = 206, - GGML_TYPE_Q8_0_R4 = 208, + GGML_TYPE_Q8_0_R8 = 208, GGML_TYPE_Q2_K_R4 = 210, GGML_TYPE_Q3_K_R4 = 211, GGML_TYPE_Q4_K_R4 = 212, @@ -431,7 +431,7 @@ extern "C" { GGML_TYPE_IQ4_NL_R4 = 220, GGML_TYPE_IQ3_S_R4 = 221, GGML_TYPE_IQ2_S_R4 = 222, - GGML_TYPE_IQ4_XS_R4 = 223, + GGML_TYPE_IQ4_XS_R8 = 223, GGML_TYPE_IQ1_M_R4 = 229, GGML_TYPE_BF16_R16 = 230, GGML_TYPE_Q6_0_R4 = 233, @@ -501,8 +501,8 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors // - GGML_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors - GGML_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors GGML_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors GGML_FTYPE_MOSTLY_Q2_K_R4 = 210, // except 1d tensors GGML_FTYPE_MOSTLY_Q3_K_R4 = 211, // except 1d tensors @@ -516,7 +516,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_NL_R4 = 219, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_S_R4 = 220, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_S_R4 = 221, // except 1d tensors - GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_XS_R8 = 222, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M_R4 = 223, // except 1d tensors GGML_FTYPE_MOSTLY_BF16_R16 = 224, // except 1d tensors GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 679353be..0d014c23 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -562,8 +562,8 @@ typedef struct { uint8_t scales_h[QK_K/16]; uint8_t scales_l[QK_K/ 8]; uint8_t qs[QK_K*4]; -} block_iq4_xs_r4; -static_assert(sizeof(block_iq4_xs_r4) == 8*sizeof(block_iq4_xs), "wrong iq4_xs_rs block size/padding"); +} block_iq4_xs_r8; +static_assert(sizeof(block_iq4_xs_r8) == 8*sizeof(block_iq4_xs), "wrong iq4_xs_rs block size/padding"); typedef struct { uint8_t scales[QK_K/32]; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index d32a583f..fe7de167 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15193,7 +15193,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ4_KS: break; case GGML_TYPE_IQ4_KSS: break; case GGML_TYPE_IQ4_NL_R4: break; - case GGML_TYPE_IQ4_XS_R4: break; + case GGML_TYPE_IQ4_XS_R8: break; case GGML_TYPE_IQ2_XXS_R4: break; case GGML_TYPE_IQ2_XS_R4: break; case GGML_TYPE_IQ3_XXS_R4: break; @@ -15201,10 +15201,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ2_S_R4: break; case GGML_TYPE_IQ1_S_R4: break; case GGML_TYPE_IQ1_M_R4: break; - case GGML_TYPE_Q4_0_R4: break; + case GGML_TYPE_Q4_0_R8: break; case GGML_TYPE_Q5_0_R4: break; case GGML_TYPE_Q6_0_R4: break; - case GGML_TYPE_Q8_0_R4: break; + case GGML_TYPE_Q8_0_R8: break; case GGML_TYPE_Q2_K_R4: break; case GGML_TYPE_Q3_K_R4: break; case GGML_TYPE_Q4_K_R4: break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4199a282..68525906 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1606,28 +1606,28 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, - [GGML_TYPE_IQ4_XS_R4] = { - .type_name = "iq4_xs_r4", + [GGML_TYPE_IQ4_XS_R8] = { + .type_name = "iq4_xs_r8", .blck_size = QK_K, .type_size = sizeof(block_iq4_xs), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_iq4_xs_r4, - .from_float = quantize_row_iq4_xs_r4, - .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_r4_ref, - .vec_dot = vec_dot_iq4_xs_r4_q8_k, + .to_float = (ggml_to_float_t) dequantize_row_iq4_xs_r8, + .from_float = quantize_row_iq4_xs_r8, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_r8_ref, + .vec_dot = vec_dot_iq4_xs_r8_q8_k, .vec_dot_type = GGML_TYPE_Q8_K32, .nrows = 1, .row_meta_size = 0, }, - [GGML_TYPE_Q4_0_R4] = { - .type_name = "q4_0_r4", + [GGML_TYPE_Q4_0_R8] = { + .type_name = "q4_0_r8", .blck_size = QK4_NL, .type_size = sizeof(block_iq4_nl), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q4_0_r4, - .from_float = quantize_row_q4_0_r4, - .from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r4_ref, - .vec_dot = vec_dot_q4_0_r4_q8_0, + .to_float = (ggml_to_float_t) dequantize_row_q4_0_r8, + .from_float = quantize_row_q4_0_r8, + .from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r8_ref, + .vec_dot = vec_dot_q4_0_r8_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ .vec_dot_type = GGML_TYPE_Q8_1_X4, @@ -1640,15 +1640,15 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, - [GGML_TYPE_Q8_0_R4] = { - .type_name = "q8_0_r4", + [GGML_TYPE_Q8_0_R8] = { + .type_name = "q8_0_r8", .blck_size = QK8_0, .type_size = sizeof(block_q8_0), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q8_0_r4, - .from_float = quantize_row_q8_0_r4, - .from_float_ref = (ggml_from_float_t)quantize_row_q8_0_r4_ref, - .vec_dot = vec_dot_q8_0_r4_q8_0, + .to_float = (ggml_to_float_t) dequantize_row_q8_0_r8, + .from_float = quantize_row_q8_0_r8, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_0_r8_ref, + .vec_dot = vec_dot_q8_0_r8_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ .vec_dot_type = GGML_TYPE_Q8_1_X4, @@ -4390,11 +4390,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_BN_R4: wtype = GGML_TYPE_IQ2_BN_R4;break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_NL_R4: wtype = GGML_TYPE_IQ4_NL_R4;break; - case GGML_FTYPE_MOSTLY_IQ4_XS_R4: wtype = GGML_TYPE_IQ4_XS_R4;break; - case GGML_FTYPE_MOSTLY_Q4_0_R4: wtype = GGML_TYPE_Q4_0_R4; break; + case GGML_FTYPE_MOSTLY_IQ4_XS_R8: wtype = GGML_TYPE_IQ4_XS_R8;break; + case GGML_FTYPE_MOSTLY_Q4_0_R8: wtype = GGML_TYPE_Q4_0_R8; break; case GGML_FTYPE_MOSTLY_Q5_0_R4: wtype = GGML_TYPE_Q5_0_R4; break; case GGML_FTYPE_MOSTLY_Q6_0_R4: wtype = GGML_TYPE_Q6_0_R4; break; - case GGML_FTYPE_MOSTLY_Q8_0_R4: wtype = GGML_TYPE_Q8_0_R4; break; + case GGML_FTYPE_MOSTLY_Q8_0_R8: wtype = GGML_TYPE_Q8_0_R8; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break; case GGML_FTYPE_MOSTLY_IQ4_KS_R4: wtype = GGML_TYPE_IQ4_KS_R4;break; @@ -10938,12 +10938,12 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -11408,12 +11408,12 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -11575,12 +11575,12 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -14815,12 +14815,12 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -15222,12 +15222,12 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -15523,12 +15523,12 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -16153,12 +16153,12 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -23028,11 +23028,11 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_BN_R4:result = quantize_iq2_bn_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL_R4: result = quantize_iq4_nl_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_IQ4_XS_R4: result = quantize_iq4_xs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_0_R4: result = quantize_q4_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_XS_R8: result = quantize_iq4_xs_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_R8: result = quantize_q4_0_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0_R4: result = quantize_q5_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_0_R4: result = quantize_q6_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q8_0_R4: result = quantize_q8_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0_R8: result = quantize_q8_0_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS_R4:result = quantize_iq4_ks_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 57024602..c561ca2b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -266,12 +266,12 @@ struct MulMat { case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: case GGML_TYPE_IQ2_BN_R4: - case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q8_K_R8: return 8; - case GGML_TYPE_Q4_0_R4: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q4_0_R8: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_BF16_R16: return 16; default: return 1; } @@ -298,9 +298,9 @@ struct MulMat { case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_IQ2_BN_R4: return 4; - case GGML_TYPE_IQ4_XS_R4: - case GGML_TYPE_Q4_0_R4: - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_Q4_0_R8: + case GGML_TYPE_Q8_0_R8: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; default: return 1; @@ -3435,7 +3435,7 @@ inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) { } template -static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); @@ -3709,9 +3709,9 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI #ifdef HAVE_FANCY_SIMD template -static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { if constexpr (nrc_y == 1) { - mul_mat_q4_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_q4_0_r8_q8_1_avx2<1>(n, vx, bx, info, nrc_x); return; } GGML_ASSERT(nrc_x%16 == 0); @@ -3787,8 +3787,8 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q4_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q4_0_r8_q8_1_avx2(n, vx, bx, info, nrc_x); } #endif @@ -4177,7 +4177,7 @@ inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i return qx_r8_q8_dot_product(qx, y); } template -static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%16 == 0); Q8 q8(info); int nb = n / QK8_0; @@ -4263,7 +4263,7 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m1 = _mm256_set1_epi16(1); @@ -4345,7 +4345,7 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn #endif template -static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); @@ -4364,7 +4364,7 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const __m256 acc[nrc_y] = {}; __m256i qx[4]; for (int ix = 0; ix < nrc_x; ix += 8) { - const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx); + const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx); for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d)); auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l); @@ -4465,11 +4465,11 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const #ifdef HAVE_FANCY_SIMD template -static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_iq4_xs_r4_q8_k_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_iq4_xs_r8_q8_k_avx2(n, vx, bx, info, nrc_x); return; if constexpr (nrc_y == 1){ - mul_mat_iq4_xs_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_iq4_xs_r8_q8_k_avx2<1>(n, vx, bx, info, nrc_x); } else { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); @@ -4482,8 +4482,8 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data __m512i isum[nrc_y] = {}; __m512i qx[4]; for (int ix = 0; ix < nrc_x; ix += 8) { - const block_iq4_xs_r4 * iq4l = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx); - const block_iq4_xs_r4 * iq4h = (const block_iq4_xs_r4 *)((const char *)vx + (ix+4)*bx); + const block_iq4_xs_r8 * iq4l = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx); + const block_iq4_xs_r8 * iq4h = (const block_iq4_xs_r8 *)((const char *)vx + (ix+4)*bx); for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d)); auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d)); @@ -4544,8 +4544,8 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data } #else template -static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_iq4_xs_r4_q8_k_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_iq4_xs_r8_q8_k_avx2(n, vx, bx, info, nrc_x); } #endif @@ -8889,16 +8889,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_iq4_nl_r4_q8_1<8>; expected_typeB = GGML_TYPE_Q8_1_X4; break; - case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_IQ4_XS_R8: assert (ne00 % QK_K == 0); - mm.funcs[0] = mul_mat_iq4_xs_r4_q8_k<1>; - mm.funcs[1] = mul_mat_iq4_xs_r4_q8_k<2>; - mm.funcs[2] = mul_mat_iq4_xs_r4_q8_k<3>; - mm.funcs[3] = mul_mat_iq4_xs_r4_q8_k<4>; - mm.funcs[4] = mul_mat_iq4_xs_r4_q8_k<5>; - mm.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>; - mm.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>; - mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>; + mm.funcs[0] = mul_mat_iq4_xs_r8_q8_k<1>; + mm.funcs[1] = mul_mat_iq4_xs_r8_q8_k<2>; + mm.funcs[2] = mul_mat_iq4_xs_r8_q8_k<3>; + mm.funcs[3] = mul_mat_iq4_xs_r8_q8_k<4>; + mm.funcs[4] = mul_mat_iq4_xs_r8_q8_k<5>; + mm.funcs[5] = mul_mat_iq4_xs_r8_q8_k<6>; + mm.funcs[6] = mul_mat_iq4_xs_r8_q8_k<7>; + mm.funcs[7] = mul_mat_iq4_xs_r8_q8_k<8>; expected_typeB = GGML_TYPE_Q8_K32; break; case GGML_TYPE_IQ4_KS_R4: @@ -9113,18 +9113,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #endif expected_typeB = GGML_TYPE_Q8_K; break; - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q4_0_R8: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q4_0_r4_q8_1<1>; - mm.funcs[1] = mul_mat_q4_0_r4_q8_1<2>; - mm.funcs[2] = mul_mat_q4_0_r4_q8_1<3>; - mm.funcs[3] = mul_mat_q4_0_r4_q8_1<4>; - mm.funcs[4] = mul_mat_q4_0_r4_q8_1<5>; - mm.funcs[5] = mul_mat_q4_0_r4_q8_1<6>; - mm.funcs[6] = mul_mat_q4_0_r4_q8_1<7>; - mm.funcs[7] = mul_mat_q4_0_r4_q8_1<8>; + mm.funcs[0] = mul_mat_q4_0_r8_q8_1<1>; + mm.funcs[1] = mul_mat_q4_0_r8_q8_1<2>; + mm.funcs[2] = mul_mat_q4_0_r8_q8_1<3>; + mm.funcs[3] = mul_mat_q4_0_r8_q8_1<4>; + mm.funcs[4] = mul_mat_q4_0_r8_q8_1<5>; + mm.funcs[5] = mul_mat_q4_0_r8_q8_1<6>; + mm.funcs[6] = mul_mat_q4_0_r8_q8_1<7>; + mm.funcs[7] = mul_mat_q4_0_r8_q8_1<8>; #ifdef HAVE_FANCY_SIMD - mm.func16 = mul_mat_q4_0_r4_q8_1<16>; + mm.func16 = mul_mat_q4_0_r8_q8_1<16>; #endif expected_typeB = GGML_TYPE_Q8_1_X4; break; @@ -9152,16 +9152,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_q6_0_r4_q8_1<8>; expected_typeB = GGML_TYPE_Q8_1_X4; break; - case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_Q8_0_R8: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q8_0_r4_q8_1<1>; - mm.funcs[1] = mul_mat_q8_0_r4_q8_1<2>; - mm.funcs[2] = mul_mat_q8_0_r4_q8_1<3>; - mm.funcs[3] = mul_mat_q8_0_r4_q8_1<4>; - mm.funcs[4] = mul_mat_q8_0_r4_q8_1<5>; - mm.funcs[5] = mul_mat_q8_0_r4_q8_1<6>; - mm.funcs[6] = mul_mat_q8_0_r4_q8_1<7>; - mm.funcs[7] = mul_mat_q8_0_r4_q8_1<8>; + mm.funcs[0] = mul_mat_q8_0_r8_q8_1<1>; + mm.funcs[1] = mul_mat_q8_0_r8_q8_1<2>; + mm.funcs[2] = mul_mat_q8_0_r8_q8_1<3>; + mm.funcs[3] = mul_mat_q8_0_r8_q8_1<4>; + mm.funcs[4] = mul_mat_q8_0_r8_q8_1<5>; + mm.funcs[5] = mul_mat_q8_0_r8_q8_1<6>; + mm.funcs[6] = mul_mat_q8_0_r8_q8_1<7>; + mm.funcs[7] = mul_mat_q8_0_r8_q8_1<8>; expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_IQ1_S_R4: @@ -11779,7 +11779,7 @@ IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const u } template -void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8 q8(info); auto m4 = vdupq_n_u8(0xf); @@ -11792,7 +11792,7 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i int32x4x2_t scales; float32x4_t acc[2*nrc_y] = {}; for (int ix = 0; ix < nrc_x; ix += 8) { - const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx); + const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + ix*bx); for (int ibl = 0; ibl < nbl; ++ibl) { auto d4_f16 = vld1q_f16((const float16_t *)iq4[ibl].d); auto d4l = vcvt_f32_f16(vget_low_f16 (d4_f16)); @@ -13662,7 +13662,7 @@ inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& su } template -void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); int nb = n / QK8_0; @@ -13880,8 +13880,8 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer); expected_Btype = GGML_TYPE_Q8_0_X4; break; - case GGML_TYPE_IQ4_XS_R4: - SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k); + case GGML_TYPE_IQ4_XS_R8: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r8_q8_k); expected_Btype = GGML_TYPE_Q8_K32; break; case GGML_TYPE_IQ4_KS_R4: @@ -13964,7 +13964,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; break; - case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q4_0_R8: SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer); expected_Btype = GGML_TYPE_Q8_0_X4; break; @@ -13976,8 +13976,8 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer); expected_Btype = GGML_TYPE_Q8_0_X4; break; - case GGML_TYPE_Q8_0_R4: - SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0); + case GGML_TYPE_Q8_0_R8: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r8_q8_0); expected_Btype = GGML_TYPE_Q8_0_X4; break; default: @@ -15260,9 +15260,9 @@ struct FlashQKfp32 { } else if constexpr (std::is_same_v>) { #ifdef __aarch64__ - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq); #else - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq); #endif } else if constexpr (std::is_same_v>) { diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index e741a8ea..9ce5731d 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3622,16 +3622,16 @@ void vec_dot_iq4_nl_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t } // -// ========================================= q4_0_r4 +// ========================================= q4_0_r8 // -void quantize_row_q4_0_r4_ref(const float * x, block_iq4_nl_r8 * y, int64_t k) { +void quantize_row_q4_0_r8_ref(const float * x, block_iq4_nl_r8 * y, int64_t k) { // we assume we are called with 8 rows - quantize_q4_0_r4(x, (void *)y, 8, k/8, nullptr); + quantize_q4_0_r8(x, (void *)y, 8, k/8, nullptr); } -void quantize_row_q4_0_r4(const float * x, void * y, int64_t k) { +void quantize_row_q4_0_r8(const float * x, void * y, int64_t k) { // we assume we are called with 8 rows - quantize_q4_0_r4(x, y, 8, k/8, nullptr); + quantize_q4_0_r8(x, y, 8, k/8, nullptr); } static void repack_q4_0(int nrows, int n_per_row, const block_q4_0 * x, block_iq4_nl_r8 * y, [[maybe_unused]] bool online) { @@ -3664,7 +3664,7 @@ static void repack_q4_0(int nrows, int n_per_row, const block_q4_0 * x, block_iq } } #ifdef __ARM_NEON -static void modify_q4_0_r4(int64_t k, char * cy) { +static void modify_q4_0_r8(int64_t k, char * cy) { auto y = (block_iq4_nl_r8 *)cy; int nb = k/(32*8); for (int ib = 0; ib < nb; ++ib) { @@ -3680,7 +3680,7 @@ static void modify_q4_0_r4(int64_t k, char * cy) { } #endif -size_t quantize_q4_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { +size_t quantize_q4_0_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { GGML_ASSERT(nrows%8 == 0); auto row_size_nl = ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row); std::vector qtmp(8*row_size_nl); @@ -3694,7 +3694,7 @@ size_t quantize_q4_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_ return nrows*row_size_nl; } -void dequantize_row_q4_0_r4(const block_iq4_nl_r8 * x, float * y, int64_t k) { +void dequantize_row_q4_0_r8(const block_iq4_nl_r8 * x, float * y, int64_t k) { // we assume we are called with 8 rows int n_per_row = k/8; int nb = n_per_row/QK4_0; @@ -3713,9 +3713,9 @@ void dequantize_row_q4_0_r4(const block_iq4_nl_r8 * x, float * y, int64_t k) { } } -void vec_dot_q4_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +void vec_dot_q4_0_r8_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q4_0_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q4_0_R8, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { return; } #endif @@ -3728,16 +3728,16 @@ void vec_dot_q4_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b // -// ========================================= q8_0_r4 +// ========================================= q8_0_r8 // -void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_r8 * y, int64_t k) { +void quantize_row_q8_0_r8_ref(const float * x, block_q8_0_r8 * y, int64_t k) { // we assume we are called with 4 rows - quantize_q8_0_r4(x, (void *)y, 8, k/8, nullptr); + quantize_q8_0_r8(x, (void *)y, 8, k/8, nullptr); } -void quantize_row_q8_0_r4(const float * x, void * y, int64_t k) { +void quantize_row_q8_0_r8(const float * x, void * y, int64_t k) { // we assume we are called with 4 rows - quantize_q8_0_r4(x, y, 8, k/8, nullptr); + quantize_q8_0_r8(x, y, 8, k/8, nullptr); } static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_r8 * y, [[maybe_unused]] bool online) { @@ -3770,7 +3770,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8 } #ifdef HAVE_FANCY_SIMD -static void modify_q8_0_r4(int64_t k, char * cy) { +static void modify_q8_0_r8(int64_t k, char * cy) { auto y = (block_iq4_nl_r8 *)cy; int nb = k/(32*8); for (int ib = 0; ib < nb; ++ib) { @@ -3782,7 +3782,7 @@ static void modify_q8_0_r4(int64_t k, char * cy) { } #endif -size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { +size_t quantize_q8_0_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { GGML_ASSERT(nrows%8 == 0); auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); std::vector qtmp(8*row_size_0); @@ -3796,7 +3796,7 @@ size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_ return nrows*row_size_0; } -void dequantize_row_q8_0_r4(const block_q8_0_r8 * x, float * y, int64_t k) { +void dequantize_row_q8_0_r8(const block_q8_0_r8 * x, float * y, int64_t k) { // we assume we are called with 4 rows int n_per_row = k/8; int nb = n_per_row/QK8_0; @@ -3813,9 +3813,9 @@ void dequantize_row_q8_0_r4(const block_q8_0_r8 * x, float * y, int64_t k) { } } -void vec_dot_q8_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +void vec_dot_q8_0_r8_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_0_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_0_R8, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { return; } #endif @@ -4025,18 +4025,18 @@ void vec_dot_q6_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b } // -// ========================================= iq4_xs_r4 +// ========================================= iq4_xs_r8 // -void quantize_row_iq4_xs_r4_ref(const float * x, block_iq4_xs_r4 * y, int64_t k) { - quantize_iq4_xs_r4(x, (void *)y, 8, k/8, nullptr); +void quantize_row_iq4_xs_r8_ref(const float * x, block_iq4_xs_r8 * y, int64_t k) { + quantize_iq4_xs_r8(x, (void *)y, 8, k/8, nullptr); } -void quantize_row_iq4_xs_r4(const float * x, void * y, int64_t k) { - quantize_iq4_xs_r4(x, y, 8, k/8, nullptr); +void quantize_row_iq4_xs_r8(const float * x, void * y, int64_t k) { + quantize_iq4_xs_r8(x, y, 8, k/8, nullptr); } -static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y, [[maybe_unused]] bool online) { +static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r8 * y, [[maybe_unused]] bool online) { GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; @@ -4068,7 +4068,7 @@ static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, bloc } } -size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { +size_t quantize_iq4_xs_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); char * qcur = (char *)dst; @@ -4076,14 +4076,14 @@ size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t std::vector qtmp(8*row_size); for (int row = 0; row < nrows; row += 8) { quantize_iq4_xs(src, (void *)qtmp.data(), 8, n_per_row, imatrix); - repack_iq4_xs(8, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur, false); + repack_iq4_xs(8, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r8 *)qcur, false); qcur += 8*row_size; src += 8*n_per_row; } return nrows*row_size; } -void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * x, float * y, int64_t k) { +void dequantize_row_iq4_xs_r8(const block_iq4_xs_r8 * x, float * y, int64_t k) { auto n_per_row = k/8; float * y8[8]; for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; @@ -4103,9 +4103,9 @@ void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * x, float * y, int64_t k) { } } -void vec_dot_iq4_xs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +void vec_dot_iq4_xs_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_XS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_XS_R8, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { return; } #endif @@ -6329,10 +6329,10 @@ struct Modify { bool iqk_modify_tensor(struct ggml_tensor * tensor) { static const std::unordered_map k_mod_map = { #ifdef __ARM_NEON - { GGML_TYPE_Q4_0_R4, {modify_q4_0_r4, 8} }, + { GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} }, #endif #ifdef HAVE_FANCY_SIMD - { GGML_TYPE_Q8_0_R4, {modify_q8_0_r4, 8} }, + { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} }, { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, #endif }; @@ -6373,7 +6373,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, { GGML_TYPE_IQ4_K, { GGML_TYPE_IQ4_K_R4, 4, (Repack::repack_func)repack_iq4_k} }, { GGML_TYPE_IQ5_K, { GGML_TYPE_IQ5_K_R4, 4, (Repack::repack_func)repack_iq5_k} }, - { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 8, (Repack::repack_func)repack_iq4_xs} }, + { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R8, 8, (Repack::repack_func)repack_iq4_xs} }, { GGML_TYPE_IQ4_KS, { GGML_TYPE_IQ4_KS_R4, 4, (Repack::repack_func)repack_iq4_ks} }, { GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} }, { GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} }, @@ -6387,10 +6387,10 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q4_K, { GGML_TYPE_Q4_K_R4, 4, (Repack::repack_func)repack_q4_k} }, { GGML_TYPE_Q5_K, { GGML_TYPE_Q5_K_R4, 4, (Repack::repack_func)repack_q5_k} }, { GGML_TYPE_Q6_K, { GGML_TYPE_Q6_K_R4, 4, (Repack::repack_func)repack_q6_k} }, - { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 8, (Repack::repack_func)repack_q4_0} }, + { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R8, 8, (Repack::repack_func)repack_q4_0} }, { GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} }, { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, - { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 8, (Repack::repack_func)repack_q8_0} }, + { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} }, { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, #ifdef __AVX512BF16__ { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16}}, diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 0dbb88bd..ff553ae7 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -67,17 +67,17 @@ size_t quantize_iq4_nl_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT void dequantize_row_iq4_nl_r4(const block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_nl_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void quantize_row_q4_0_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_r8 * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -size_t quantize_q4_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -void dequantize_row_q4_0_r4(const block_iq4_nl_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void vec_dot_q4_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_q4_0_r8_ref(const float * GGML_RESTRICT x, block_iq4_nl_r8 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q4_0_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q4_0_r8(const block_iq4_nl_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q4_0_r8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_r8 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -size_t quantize_q8_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -void dequantize_row_q8_0_r4(const block_q8_0_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void vec_dot_q8_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_q8_0_r8_ref(const float * GGML_RESTRICT x, block_q8_0_r8 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_0_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_0_r8(const block_q8_0_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_0_r8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void quantize_row_q5_0_r4_ref(const float * GGML_RESTRICT x, block_q5_0_r4 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -91,11 +91,11 @@ size_t quantize_q6_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_q6_0_r4(const block_q6_0_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q6_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void quantize_row_iq4_xs_r4_ref(const float * GGML_RESTRICT x, block_iq4_xs_r4 * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -size_t quantize_iq4_xs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void vec_dot_iq4_xs_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq4_xs_r8_ref(const float * GGML_RESTRICT x, block_iq4_xs_r8 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_xs_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_xs_r8(const block_iq4_xs_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_xs_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void quantize_row_iq2_bn_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); diff --git a/include/llama.h b/include/llama.h index 3f25b296..730c087a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -181,8 +181,8 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors // - LLAMA_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q2_K_R4 = 210, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q3_K_R4 = 211, // except 1d tensors @@ -196,7 +196,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 = 225, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_S_R4 = 226, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_M_R4 = 229, // except 1d tensors - LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 = 230, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M_R4 = 231, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 335, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16_R16 = 232, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp index 117f59be..00a3c9b1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3962,11 +3962,11 @@ struct llama_model_loader { case GGML_TYPE_IQ2_BN_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN_R4;break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_NL_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_R4;break; - case GGML_TYPE_IQ4_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R4;break; - case GGML_TYPE_Q4_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R4; break; + case GGML_TYPE_IQ4_XS_R8:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R8;break; + case GGML_TYPE_Q4_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R8; break; case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break; case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break; - case GGML_TYPE_Q8_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R4; break; + case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break; @@ -4695,11 +4695,11 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL_R4:return "IQ4_NL_R4 - 4.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_XS_R4:return "IQ4_XS_R4 - 4.25 bpw"; - case LLAMA_FTYPE_MOSTLY_Q4_0_R4: return "Q4_0_R4 - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS_R8:return "IQ4_XS_R8 - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_Q4_0_R8: return "Q4_0_R8 - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_Q5_0_R4: return "Q5_0_R4 - 5.5 bpw"; case LLAMA_FTYPE_MOSTLY_Q6_0_R4: return "Q6_0_R4 - 6.5 bpw"; - case LLAMA_FTYPE_MOSTLY_Q8_0_R4: return "Q8_0_R4 - 8.5 bpw"; + case LLAMA_FTYPE_MOSTLY_Q8_0_R8: return "Q8_0_R8 - 8.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:return "IQ4_KS_R4 - 4.25 bpw"; @@ -15982,7 +15982,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4) && !qs.has_output) { new_type = GGML_TYPE_IQ5_K; } - else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R4 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 && + else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R8 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 && new_type != GGML_TYPE_Q8_K_R8) { new_type = GGML_TYPE_Q6_K; } @@ -16016,7 +16016,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_IQ4_NL_R4) { new_type = GGML_TYPE_IQ4_NL; } - else if (new_type == GGML_TYPE_IQ4_XS_R4) { + else if (new_type == GGML_TYPE_IQ4_XS_R8) { new_type = GGML_TYPE_IQ4_XS; } else if (new_type == GGML_TYPE_Q2_K_R4) { @@ -16055,7 +16055,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_IQ4_KS_R4) { new_type = GGML_TYPE_IQ4_KS; } - else if (new_type == GGML_TYPE_Q4_0_R4) { + else if (new_type == GGML_TYPE_Q4_0_R8) { new_type = GGML_TYPE_Q4_0; } else if (new_type == GGML_TYPE_Q5_0_R4) { @@ -16064,7 +16064,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q6_0_R4) { new_type = GGML_TYPE_Q6_0; } - else if (new_type == GGML_TYPE_Q8_0_R4) { + else if (new_type == GGML_TYPE_Q8_0_R8) { new_type = GGML_TYPE_Q8_0; } else if (new_type == GGML_TYPE_BF16_R16) { @@ -16188,7 +16188,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ5_K; } @@ -16229,7 +16229,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_IQ4_XS) new_type = GGML_TYPE_Q5_K; else if (new_type == GGML_TYPE_IQ4_NL) new_type = GGML_TYPE_Q5_K; else if (new_type == GGML_TYPE_IQ4_NL_R4) new_type = GGML_TYPE_Q5_K; - else if (new_type == GGML_TYPE_IQ4_XS_R4) new_type = GGML_TYPE_Q5_K; + else if (new_type == GGML_TYPE_IQ4_XS_R8) new_type = GGML_TYPE_Q5_K; else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K; } ++qs.i_attention_wv; @@ -16306,7 +16306,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (i_layer < n_layer/8 && !qs.has_imatrix && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4)) { + ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R8)) { new_type = GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 && i_layer < n_layer/8 && !qs.has_imatrix) { @@ -16326,7 +16326,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix. new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_0_R4 && qs.has_imatrix && i_layer < n_layer/8) { + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_0_R8 && qs.has_imatrix && i_layer < n_layer/8) { new_type = GGML_TYPE_IQ4_NL_R4; } ++qs.i_ffn_down; @@ -16339,7 +16339,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4|| ftype == LLAMA_FTYPE_MOSTLY_IQ4_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S_R4) { new_type = GGML_TYPE_Q5_K; @@ -16411,7 +16411,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 || - new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R4 || + new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R8 || new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 || new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 || new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| @@ -16459,7 +16459,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: - case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ4_K_R4: @@ -16589,11 +16589,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_BN_R4:default_type = GGML_TYPE_IQ2_BN_R4;break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL_R4:default_type = GGML_TYPE_IQ4_NL_R4;break; - case LLAMA_FTYPE_MOSTLY_IQ4_XS_R4:default_type = GGML_TYPE_IQ4_XS_R4;break; - case LLAMA_FTYPE_MOSTLY_Q4_0_R4: default_type = GGML_TYPE_Q4_0_R4; break; + case LLAMA_FTYPE_MOSTLY_IQ4_XS_R8:default_type = GGML_TYPE_IQ4_XS_R8;break; + case LLAMA_FTYPE_MOSTLY_Q4_0_R8: default_type = GGML_TYPE_Q4_0_R8; break; case LLAMA_FTYPE_MOSTLY_Q5_0_R4: default_type = GGML_TYPE_Q5_0_R4; break; case LLAMA_FTYPE_MOSTLY_Q6_0_R4: default_type = GGML_TYPE_Q6_0_R4; break; - case LLAMA_FTYPE_MOSTLY_Q8_0_R4: default_type = GGML_TYPE_Q8_0_R4; break; + case LLAMA_FTYPE_MOSTLY_Q8_0_R8: default_type = GGML_TYPE_Q8_0_R8; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:default_type = GGML_TYPE_IQ4_KS_R4;break; @@ -16969,11 +16969,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_NL; else chunk_size_multiplier = 4; } - else if (new_type == GGML_TYPE_IQ4_XS_R4) { + else if (new_type == GGML_TYPE_IQ4_XS_R8) { if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_IQ4_XS; else chunk_size_multiplier = 8; } - else if (new_type == GGML_TYPE_Q4_0_R4) { + else if (new_type == GGML_TYPE_Q4_0_R8) { if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q4_0; else chunk_size_multiplier = 8; } @@ -16985,7 +16985,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_0; else chunk_size_multiplier = 4; } - else if (new_type == GGML_TYPE_Q8_0_R4) { + else if (new_type == GGML_TYPE_Q8_0_R8) { if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; else chunk_size_multiplier = 8; } From b08a2e9dfc0e721f7f190c25f37794390966e326 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 7 Feb 2025 08:33:28 +0200 Subject: [PATCH 10/14] Add additional checks for iq1_s_r4 quantization (#191) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_quantize.cpp | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 9ce5731d..a01ed109 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -6116,23 +6116,48 @@ size_t quantize_iq1_s_r4(const float * src, void * dst, int64_t nrows, int64_t n auto y = (block_iq1_s_r4 *)(dptr + 4); for (int k = 0; k < 4; ++k) max[k] = 0; for (int ibl = 0; ibl < nblock; ++ibl) { - if (imatrix) { - for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]; - } for (int k = 0; k < 4; ++k) { auto xb = src + k*n_per_row + kBlockSize*ibl; float sumx2 = 0; for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; + if (!sumx2) { + printf("Found block with all zeros\n"); + // all zero + int ind = 1029; // this is the grid entry with all zeros + scales[4*ibl+k] = 0; + uint16_t h = 0; + for (int i = 0; i < 4; ++i) { + y[ibl].qs[4*i + k] = ind & 255; + h |= (ind >> 8) << 3*i; + } + y[ibl].qh[k] = h; + continue; + } float sigma2 = 1.5f*sumx2/kBlockSize; + bool have_imatrix = false; if (imatrix) { - for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]*sqrt(sigma2 + xb[j]*xb[j]); - } else { + have_imatrix = true; + float sumwx = 0; + for (int j = 0; j < kBlockSize; ++j) { + weight[j] = imatrix[kBlockSize*ibl + j]*sqrt(sigma2 + xb[j]*xb[j]); + sumwx += weight[j]*std::abs(xb[j]); + } + if (!sumwx) { + printf("Found block with mismatching importance/model weights\n"); + // Either all weights are zero, or xb is zero where weight is not zero. + // In both of these cases it is better to simply ignore the imatrix + have_imatrix = false; + } + } + if (!have_imatrix) { for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); } iq1s_process_1block(kBlockSize, xb, weight, L, scales.data() + 4*ibl + k, index, &shift, pairs, sumx, sumw); + GGML_ASSERT(scales[4*ibl+k] >= 0); max[k] = std::max(max[k], scales[4*ibl+k]); uint16_t h = 0; for (int i = 0; i < 4; ++i) { + GGML_ASSERT(index[i] >= 0 && index[i] < 2048); y[ibl].qs[4*i + k] = index[i] & 255; h |= (index[i] >> 8) << 3*i; } From 4601a8c3735d8e47c46e0927712d77c4f422be6c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 7 Feb 2025 08:33:42 +0200 Subject: [PATCH 11/14] cuda: non-contiguous rms norm (#190) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/norm.cu | 153 ++++++++++++++++++++++++++++++++++--- src/llama.cpp | 6 +- 2 files changed, 144 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 7e670912..9e4931a3 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -131,6 +131,51 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } +template +static __global__ void rms_norm_f32_nc( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * x[col]; + } +} + template static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -165,6 +210,51 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa } } +template +static __global__ void fused_rms_norm_f32_nc( + const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * y[col] * x[col]; + } +} + static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { @@ -197,6 +287,19 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } +static void rms_norm_f32_nc_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32_nc<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32_nc<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); @@ -209,6 +312,19 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds } } +static void fused_rms_norm_f32_nc_cuda( + const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + fused_rms_norm_f32_nc<<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + fused_rms_norm_f32_nc<1024><<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -255,18 +371,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); - float eps; memcpy(&eps, dst->op_params, sizeof(float)); - rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + const int64_t ne00 = src0->ne[0]; + if (ggml_is_contiguous(src0)) { + const int64_t nrows = ggml_nrows(src0); + rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + } else { + auto ts0 = ggml_type_size(src0->type); + GGML_ASSERT(src0->nb[0] == ts0); + auto s01 = src0->nb[1] / ts0; + auto s02 = src0->nb[2] / ts0; + auto s03 = src0->nb[3] / ts0; + rms_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream); + } } void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -281,19 +403,26 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->ne[0] == src1->ne[0]); GGML_ASSERT(ggml_nrows(src1) == 1); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); - float eps; memcpy(&eps, dst->op_params, sizeof(float)); - fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream); + const int64_t ne00 = src0->ne[0]; + + if (ggml_is_contiguous(src0)) { + const int64_t nrows = ggml_nrows(src0); + fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream); + } else { + auto ts0 = ggml_type_size(src0->type); + GGML_ASSERT(src0->nb[0] == ts0); + auto s01 = src0->nb[1] / ts0; + auto s02 = src0->nb[2] / ts0; + auto s03 = src0->nb[3] / ts0; + fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream); + } } diff --git a/src/llama.cpp b/src/llama.cpp index 00a3c9b1..29926a94 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13390,7 +13390,7 @@ struct llm_build_context { ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, cb, il); @@ -13422,7 +13422,7 @@ struct llm_build_context { 0); cb(v_states, "v_states", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -13431,7 +13431,7 @@ struct llm_build_context { cb(q_pe, "q_pe", il); // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE k_pe = ggml_rope_ext( ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, From 6d7b58eade37e45e3d8286a2353658047539d2b2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 8 Feb 2025 09:48:59 +0200 Subject: [PATCH 12/14] Revert #79 (#192) * Revert "Do not quantize activations if not necessary (#79)" This reverts commit 0bf4d99774aa3b6d00ef564acbc4dc211e45db33. * Fixed compilation after revert --------- Co-authored-by: Iwan Kawrakow --- ggml/include/ggml.h | 1 - ggml/src/ggml.c | 63 ++++++++++++--------------------------------- 2 files changed, 16 insertions(+), 48 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b6bebd60..c307d42e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -724,7 +724,6 @@ extern "C" { // since https://github.com/ggerganov/ggml/issues/287 struct ggml_cplan { size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` - size_t q_size; uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` int n_threads; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 68525906..b19fb006 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2586,7 +2586,6 @@ struct ggml_compute_params { // work buffer for all threads size_t wsize; - size_t qsize; void * wdata; struct ggml_compute_state_shared * shared; @@ -13940,7 +13939,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( return; } - const void * wdata = (src1->type == vec_dot_type) ? src1->data : (char *)params->wdata + params->wsize - params->qsize + GGML_MAX_NAME; + const void * wdata = (src1->type == vec_dot_type) ? src1->data : (char *)params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); assert(ne12 % ne02 == 0); @@ -14110,12 +14109,7 @@ UseGgmlGemm1:; #endif if (src1->type != vec_dot_type) { - char * wdata = (char *)params->wdata + params->wsize - params->qsize; - - if (strncmp(src1->name, wdata - GGML_MAX_NAME, GGML_MAX_NAME) == 0) { - goto AlreadyQuantized; - } - wdata += GGML_MAX_NAME; + char * wdata = params->wdata; #if IK_PRINT_TIMING int64_t t1 = ggml_time_us(); @@ -14125,7 +14119,7 @@ UseGgmlGemm1:; const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; - assert(params->qsize >= ne13*nbw3); + assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); for (int64_t i13 = 0; i13 < ne13; ++i13) { @@ -14157,17 +14151,14 @@ UseGgmlGemm1:; #endif if (ith == 0) { - wdata -= GGML_MAX_NAME; - memcpy(wdata, src1->name, GGML_MAX_NAME); // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. //atomic_store(¶ms->shared->current_chunk, nth); } -AlreadyQuantized:; + ggml_barrier(params->shared); } - const void * wdata = (src1->type == vec_dot_type) ? src1->data - : (const void *)((const char *)params->wdata + params->wsize - params->qsize + GGML_MAX_NAME); + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; #if GGML_USE_IQK_MULMAT if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) { @@ -14354,10 +14345,9 @@ static void ggml_compute_forward_mul_mat_id( const int n_ids = ids->ne[0]; // n_expert_used const int n_as = ne02; // n_expert - char * qdata = (char *)params->wdata + params->wsize - params->qsize; - - char * wdata_src1_end = (src1->type == vec_dot_type) ? qdata : - qdata + GGML_PAD(GGML_MAX_NAME + ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + char * wdata_src1_end = (src1->type == vec_dot_type) ? + (char *) params->wdata : + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); struct mmid_row_mapping { int32_t i1; @@ -14367,19 +14357,14 @@ static void ggml_compute_forward_mul_mat_id( int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] - bool store_name = false; if (src1->type != vec_dot_type) { - if (strncmp(src1->name, qdata, GGML_MAX_NAME) == 0) { - goto QuantizationAlreadyDone; - } - store_name = true; - char * wdata = qdata + GGML_MAX_NAME; + char * wdata = params->wdata; const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; - assert(params->qsize >= ne13*nbw3); + assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); for (int64_t i13 = 0; i13 < ne13; ++i13) { @@ -14395,12 +14380,7 @@ static void ggml_compute_forward_mul_mat_id( #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] -QuantizationAlreadyDone:; if (ith == 0) { - if (store_name) { - memcpy(qdata, src1->name, GGML_MAX_NAME); - } - // initialize matrix_row_counts memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); @@ -14429,7 +14409,7 @@ QuantizationAlreadyDone:; const char * src0_cur = (const char *) src0->data + cur_a*nb02; - const void * wdata = (src1->type == vec_dot_type) ? src1->data : qdata + GGML_MAX_NAME; + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); const int64_t nr0 = ne01; // src0 rows @@ -21017,7 +20997,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } size_t work_size = 0; - size_t q_size = 0; struct ggml_cplan cplan; memset(&cplan, 0, sizeof(struct ggml_cplan)); @@ -21033,7 +21012,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa max_tasks = MAX(max_tasks, n_tasks); size_t cur = 0; - size_t cur_q = 0; switch (node->op) { case GGML_OP_CPY: @@ -21064,8 +21042,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; if (node->src[1]->type != vec_dot_type) { - cur_q = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); - //cur_q = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); + cur = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); } } break; case GGML_OP_MUL_MAT_ID: @@ -21075,13 +21052,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const struct ggml_tensor * src1 = node->src[1]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { - cur_q += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); - //cur_q += ggml_row_size(vec_dot_type, ggml_nelements(src1)); + cur += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); } const int n_as = src0->ne[2]; - cur_q += GGML_PAD(cur, sizeof(int64_t)); // align - cur_q += n_as * sizeof(int64_t); // matrix_row_counts - cur_q += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + cur += GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; case GGML_OP_OUT_PROD: { @@ -21170,20 +21146,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } work_size = MAX(work_size, cur); - q_size = MAX(q_size, cur_q); } if (work_size > 0) { work_size += CACHE_LINE_SIZE*(n_threads - 1); } - if (q_size > 0) { - q_size += GGML_MAX_NAME; - } - work_size += q_size; cplan.n_threads = MIN(max_tasks, n_threads); cplan.work_size = work_size; - cplan.q_size = q_size; cplan.work_data = NULL; return cplan; @@ -21201,7 +21171,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /*.ith =*/ state->ith, /*.nth =*/ state->shared->n_threads, /*.wsize =*/ cplan->work_size, - /*.qsize =*/ cplan->q_size, /*.wdata =*/ cplan->work_data, /*.shared=*/ state->shared, }; From 33390c4b74fa52875d6028c5c9aaf84f17288c25 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 9 Feb 2025 09:14:52 +0200 Subject: [PATCH 13/14] Use Q8_K_128 for IQ1_S_R4 and IQ1_M_R4 matrix multiplications (#194) * iq1_s_r4: Use Q8_K_128 instead of Q8_1_X4 for gemm (AVX2/Zen4) * iq1_m_r4: Use Q8_K_128 instead of Q8_1_X4 for gemm (AVX2/Zen4) * iq1_s_r4: Use Q8_K_128 instead of Q8_1_X4 for gemm (Neon) * iq1_m_r4: Use Q8_K_128 instead of Q8_0_X4 for gemm (Neon) * Simdify q8_K128 quantization also on Neon * Cleanup --------- Co-authored-by: Iwan Kawrakow --- ggml/include/ggml.h | 1 + ggml/src/ggml-common.h | 5 +- ggml/src/ggml.c | 13 +++- ggml/src/iqk/iqk_mul_mat.cpp | 75 +++++++++++----------- ggml/src/iqk/iqk_quantize.cpp | 117 +++++++++++++++++++++++++++++++++- ggml/src/iqk/iqk_quantize.h | 1 + 6 files changed, 169 insertions(+), 43 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c307d42e..66bcb25a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -415,6 +415,7 @@ extern "C" { GGML_TYPE_Q8_K16 = 147, GGML_TYPE_Q8_K32 = 148, GGML_TYPE_Q8_KR8 = 149, + GGML_TYPE_Q8_K128 = 150, GGML_TYPE_Q4_0_R8 = 202, GGML_TYPE_Q5_0_R4 = 206, diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 0d014c23..4308f0b9 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -377,15 +377,16 @@ typedef struct { } block_q8_K; static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); typedef struct { - float d; // delta + float d; // delta int8_t qs[64]; // quants } block_q8_K64; static_assert(sizeof(block_q8_K64) == sizeof(float) + 64, "wrong q8_K64 block size/padding"); typedef struct { float d; // delta + int16_t bsums[4]; // quant sums for blocks of 32 int8_t qs[128]; // quants } block_q8_K128; -static_assert(sizeof(block_q8_K128) == sizeof(float) + 128, "wrong q8_K128 block size/padding"); +static_assert(sizeof(block_q8_K128) == sizeof(float) + 4*sizeof(int16_t) + 128, "wrong q8_K128 block size/padding"); typedef struct { ggml_half d[8]; // delta diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b19fb006..e07dd547 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1192,7 +1192,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq1_s_r4, .from_float_ref = (ggml_from_float_t)quantize_row_iq1_s_r4_ref, .vec_dot = vec_dot_iq1_s_r4_q8_k, - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_K128, .nrows = 1, .row_meta_size = 2, }, @@ -1218,7 +1218,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq1_m_r4, .from_float_ref = (ggml_from_float_t)quantize_row_iq1_m_r4_ref, .vec_dot = vec_dot_iq1_m_r4_q8_k, - .vec_dot_type = GGML_TYPE_Q8_0_X4, + .vec_dot_type = GGML_TYPE_Q8_K128, .nrows = 1, .row_meta_size = 2, }, @@ -1354,6 +1354,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_K64, .row_meta_size = 0, }, + [GGML_TYPE_Q8_K128] = { + .type_name = "q8_K128", + .blck_size = 128, + .type_size = sizeof(block_q8_K128), + .is_quantized = true, + .from_float = quantize_row_q8_K128, + .row_meta_size = 0, + }, [GGML_TYPE_Q8_K16] = { .type_name = "q8_K16", .blck_size = 64, @@ -16161,6 +16169,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K64: + case GGML_TYPE_Q8_K128: case GGML_TYPE_Q8_K16: case GGML_TYPE_Q8_K32: case GGML_TYPE_Q4_0_4_4: diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index c561ca2b..aeba2c59 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3528,26 +3528,27 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D template static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); __m256i qx[4]; __m256 acc[nrc_y] = {}; auto m1 = _mm256_set1_epi16(1); auto ms = _mm_set1_epi16(-32768); - float d8[8*nrc_y]; + float d8[4*nrc_y]; union { __m256i vec; uint16_t val[16]; } helper; struct aux_iq1_s_r4 { uint8_t qs[16]; uint64_t qh; }; - for (int ix= 0; ix < nrc_x; ix += 4) { + for (int ix = 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)); auto x = (const aux_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8 + 8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib].d))); + auto bsums = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].bsums)); + _mm_storeu_ps(d8 + 4*iy, _mm_mul_ps(_mm_set1_ps(q8.y[iy][ib].d), _mm_cvtepi32_ps(bsums))); } for (int k = 0; k < 4; ++k) { auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh); @@ -3556,8 +3557,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1)); auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1)); signs = _mm_add_epi16(_mm_set1_epi16(-8), signs); - auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32( - _mm_mullo_epi16(scales4, signs)))); + signs = _mm_mullo_epi16(signs, scales4); + auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32(signs))); auto delta = _mm256_set_m128(delta4, delta4); scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3 auto scales = MM256_SET_M128I(scales4, scales4); @@ -3598,8 +3599,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI auto sumi = _mm256_packs_epi32(sumi1, sumi2); #endif sumi = _mm256_madd_epi16(scales, sumi); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+0]), _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+4]), delta, acc[iy]); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), delta, acc[iy]); } } } @@ -3614,7 +3615,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI template static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000); @@ -3624,17 +3625,14 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI #endif __m256i qx[4]; __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; auto ms = _mm_set1_epi8(0x08); - float d8[4*nrc_y]; union { __m256i vec; uint16_t val[16]; } helper; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr))); auto x = (const block_iq1_m_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { - for (int iy = 0; iy < nrc_y; ++iy) { - _mm_storeu_ps(d8 + 4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].d))); - } for (int k = 0; k < 4; ++k) { auto qh = (const uint32_t *)x[4*ib+k].qh; auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]); @@ -3694,10 +3692,13 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t auto sumi = _mm256_packs_epi32(sumi1, sumi2); #endif - sumi = _mm256_madd_epi16(scales, sumi); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi)); } } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } } for (int iy = 0; iy < nrc_y; ++iy) { auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); @@ -9177,7 +9178,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #ifdef HAVE_FANCY_SIMD mm.func16 = mul_mat_iq1_s_r4_q8_1<16>; #endif - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_K128; break; case GGML_TYPE_IQ1_M_R4: assert (ne00 % QK4_NL == 0); @@ -9192,7 +9193,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #ifdef HAVE_FANCY_SIMD mm.func16 = mul_mat_iq1_m_r4_q8_0<16>; #endif - expected_typeB = GGML_TYPE_Q8_0_X4; + expected_typeB = GGML_TYPE_Q8_K128; break; default: @@ -12072,7 +12073,7 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8<1, block_q8_1_x4> q8(info); + Q8<1, block_q8_K128> q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; @@ -12084,8 +12085,8 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); auto x = (const block_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { - auto scale_yd = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+0)); - auto scale_ym = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+4)); + auto scale_yd = vdupq_n_f32(q8.y[0][ib].d); + auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums)))); for (int k = 0; k < 4; ++k) { auto sas = vld1_u16(x[4*ib+k].qh); auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7)); @@ -12135,23 +12136,22 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat template static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); uint8x16_t qx[8]; int32x4_t acc[nrc_y] = {}; auto ms = vdup_n_u16(0x8000); auto mask = vdupq_n_s8(0x03); - float d8[8*nrc_y]; + float d8[4*nrc_y]; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); auto x = (const block_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = vld1q_f16((const float16_t *)q8.y[iy][ib].d); - vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vget_low_f16(scales))); - vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vget_high_f16(scales))); + auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums))); + vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales)); } for (int k = 0; k < 4; ++k) { auto sas = vld1_u16(x[4*ib+k].qh); @@ -12193,8 +12193,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); sumi = vmulq_s32(scales, sumi); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+0]), vcvtq_f32_s32(sumi)); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+4]), delta4); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi)); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4); } } } @@ -12208,25 +12208,21 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI template static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; - int32x4_t acc[nrc_y] = {}; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303}; auto step = vdupq_n_u8(4); auto ms = vdupq_n_u8(0x08); auto mask = vdupq_n_s8(0x18); - float d8[4*nrc_y]; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr))); auto x = (const block_iq1_m_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { - for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = vld1_f16((const float16_t *)q8.y[iy][ib].d); - vst1q_f32(d8+4*iy, vcvt_f32_f16(scales)); - } for (int k = 0; k < 4; ++k) { auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]); scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf)); @@ -12272,10 +12268,13 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1); sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); - auto sumi = vmlaq_s32(vmlaq_s32(vdupq_n_s32(0), sumi1, scales1), sumi2, scales2); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), vcvtq_f32_s32(sumi)); + isum[iy] = vmlaq_s32(vmlaq_s32(isum[iy], sumi1, scales1), sumi2, scales2); } } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vmulq_f32(d1, acc[iy])); @@ -13907,12 +13906,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1); m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1; m.func16 = mul_mat_iq1_s_r4_q8_1<16>; - expected_Btype = GGML_TYPE_Q8_1_X4; + expected_Btype = GGML_TYPE_Q8_K128; break; case GGML_TYPE_IQ1_M_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0); m.func16 = mul_mat_iq1_m_r4_q8_0<16>; - expected_Btype = GGML_TYPE_Q8_0_X4; + expected_Btype = GGML_TYPE_Q8_K128; break; case GGML_TYPE_IQ3_XXS_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index a01ed109..f33fc183 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2733,6 +2733,7 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe return nrows * nblock * sizeof(block_iq6_k); } +namespace { template void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { assert(k % QK_K == 0); @@ -2843,7 +2844,7 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { x += QK_K; } #endif - +} } void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { @@ -2858,6 +2859,120 @@ void quantize_row_q8_KR8(const float * x, void * vy, int64_t k) { iqk_quantize_row_q8_K_T<2>(x, vy, k); } +namespace { +// TODO: merge this with the above template +void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) { + constexpr int kBlockSize = 128; + assert(k % kBlockSize == 0); + const int nb = k / kBlockSize; + auto y = (block_q8_K128 *)vy; +#ifdef __AVX2__ + const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + for (int i = 0; i < nb; i++) { + const float * xb = x + i*kBlockSize; + __m256 maxAbs = _mm256_setzero_ps(); + const float * xx = xb; + for (int ib = 0; ib < kBlockSize/8; ++ib) { + const __m256 v = _mm256_loadu_ps(xx); xx += 8; + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v)); + } + const float maxScalar = hmax_f32_8(maxAbs); + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + xx = xb; + int8_t * q8 = y[i].qs; + for (int ib = 0; ib < kBlockSize/32; ++ib) { + __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST); + v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST); + __m256i i0 = _mm256_cvtps_epi32(v0); + __m256i i1 = _mm256_cvtps_epi32(v1); + __m256i i2 = _mm256_cvtps_epi32(v2); + __m256i i3 = _mm256_cvtps_epi32(v3); + y[i].bsums[ib] = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + i0 = _mm256_packs_epi16( i0, i2 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + _mm256_storeu_si256((__m256i *)q8, i0); + q8 += 32; + } + } +#elif defined __ARM_NEON + int32x4_t ival[8]; + for (int i = 0; i < nb; i++) { + const float * xb = x + i*kBlockSize; + auto vmax = vdupq_n_f32(0.f); + for (int j = 0; j < kBlockSize; j += 4) { + vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(xb + j))); + } + auto smax = vmaxvq_f32(vmax); + if (!smax) { + std::memset(&y[i], 0, sizeof(y[i])); + continue; + } + y[i].d = smax/127; + auto vid = vdupq_n_f32(127/smax); + for (int ib = 0; ib < kBlockSize/32; ++ib) { + auto isum = vdupq_n_s32(0); + for (int k = 0; k < 8; ++k) { + auto val = vld1q_f32(xb + 32*ib + 4*k); + ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid)); + isum = vaddq_s32(isum, ival[k]); + } + y[i].bsums[ib] = vaddvq_s32(isum); + for (int k = 0; k < 4; ++k) { + auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1])); + vst1_s8(y[i].qs + 32*ib + 8*k, vmovn_s16(i16)); + } + } + } +#else + for (int i = 0; i < nb; i++) { + + float amax = 0; + for (int j = 0; j < kBlockSize; ++j) { + float ax = std::abs(x[j]); + amax = std::max(amax, ax); + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, kBlockSize); + memset(y[i].bsums, 0, kBlockSize/32*(sizeof(int16_t))); + x += kBlockSize; + continue; + } + const float iscale = 127.f/amax; + for (int j = 0; j < kBlockSize; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = v; + } + for (int j = 0; j < kBlockSize/32; ++j) { + int sum = 0; + for (int ii = 0; ii < 32; ++ii) { + sum += y[i].qs[j*32 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += kBlockSize; + } +#endif +} +} + +void quantize_row_q8_K128(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_K128(x, vy, k); +} + namespace { static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size, int n_per_row, const float * x, char * cy, diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index ff553ae7..97719361 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -220,6 +220,7 @@ void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const voi void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K128(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); From cae2b81155fdad75b7beab3a835c438120412969 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 9 Feb 2025 18:59:33 +0200 Subject: [PATCH 14/14] FA: Add option to build all FA kernels (#197) Similar to the CUDA situation. It is OFF by default. If OFF, only F16, Q8_0, Q6_0, and, if the CPU provides native BF16 support, BF16 FA kernels will be included. To enable all, cmake -DGGML_IQK_FA_ALL_QUANTS=1 ... This cuts compilation time for iqk_mul_mat.cpp by almost half (45 seconds vs 81 seconds on my Ryzen-7950X). Co-authored-by: Iwan Kawrakow --- ggml/CMakeLists.txt | 2 ++ ggml/src/CMakeLists.txt | 4 +++ ggml/src/iqk/iqk_mul_mat.cpp | 66 ++++++++++++++++++------------------ 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 90b37d5b..6775fdcb 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -130,6 +130,8 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF) +option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF) + option(GGML_CURL "ggml: use libcurl to download model from an URL" OFF) option(GGML_HIPBLAS "ggml: use hipBLAS" OFF) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index da1746c8..3d1a2970 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -259,6 +259,10 @@ if (GGML_IQK_MUL_MAT) add_compile_definitions(GGML_USE_IQK_MULMAT) set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp) set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h) + if (GGML_IQK_FA_ALL_QUANTS) + message(STATUS "Including all IQK FA kernels") + add_compile_definitions(GGML_IQK_FA_ALL_QUANTS) + endif() endif() if (GGML_LLAMAFILE) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index aeba2c59..ee0af7e9 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15239,14 +15239,7 @@ struct FlashQKfp32 { case 7: return std::make_pair(mul_mat<7>, 7);\ }\ } - if constexpr (std::is_same_v>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0>) { + if constexpr (std::is_same_v>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0>) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_0_q8_0>) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_0_q8_0>) { @@ -15278,13 +15286,7 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_1_q8_1_T>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0 void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { -// if constexpr (std::is_same_v> || std::is_same_v> || -// std::is_same_v> || -// std::is_same_v> || -// std::is_same_v> || -// std::is_same_v>) { -// compute_helper_q>( -// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); -// } else { -// compute_helper>( -// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); -// } if constexpr (std::is_same_v> || std::is_same_v> || std::is_same_v> || std::is_same_v>) { @@ -16027,6 +16018,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ80 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -16039,10 +16035,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperIQ4nl vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; - case GGML_TYPE_Q6_0: { - HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); - } break; +#endif default: break; } } @@ -16062,6 +16055,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -16074,10 +16072,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperIQ4nl kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; - case GGML_TYPE_Q6_0: { - HelperQ60 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); - } break; +#endif default: break; } @@ -16087,8 +16082,12 @@ inline bool flash_attn_is_supported(ggml_type type) { #ifdef __AVX512BF16__ if (type == GGML_TYPE_BF16) return true; #endif +#if GGML_IQK_FA_ALL_QUANTS if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; +#else + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true; +#endif return false; } } @@ -16115,6 +16114,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k auto type_v = ggml_type(int_type_v); if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 + if (D != 64 && D != 96 && D != 128 && D != 256) return false; auto ck = (const char *)k; auto cv = (const char *)v;