diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 22c89911..3a0648ac 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2074,190 +2074,110 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
 #endif // Zen4 or vanilla AVX2
-
-template <int nrc_y>
-static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    //printf("%s: n = %d, nrc_x = %d, bx = %zu\n", __func__, n, nrc_x, bx);
-    Q8<nrc_y, block_q8_1_x4> q8(info);
-    auto m4 = _mm256_set1_epi8(0xf);
-    auto values = load_iq4nl_values_256();
-    int nb = n / QK4_NL;
-    //float dequant[128], stored[4], s[4*nrc_y];
-    //float dequant[128], s[4*nrc_y];
-    GGML_ASSERT(nb%4 == 0);
-    __m256 acc[nrc_y] = {};
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
-        //for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps();
-        //for (int j = 0; j < 4*nrc_y; ++j) s[j] = 0;
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int k = 0; k < 4; ++k) {
-                //dequantize_row_iq4_nl_x4(iq4+4*ib4+k, dequant, 128);
-                auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
-                auto scales = _mm256_set_m128(scales128, scales128);
-                auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
-                //auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
-                auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
-                auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
-                auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
-                auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
-                auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
-                auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
-                __m256i sumi;
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    //float d8 = GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]);
-                    //const int8_t * qs = q8.y[iy][ib4].qs+32*k;
-                    ////s[0] = s[1] = s[2] = s[3] = 0;
-                    //for (int j = 0; j < 32; ++j) {
-                    //    s[4*iy+0] += d8*dequant[j+ 0]*qs[j];
-                    //    s[4*iy+1] += d8*dequant[j+32]*qs[j];
-                    //    s[4*iy+2] += d8*dequant[j+64]*qs[j];
-                    //    s[4*iy+3] += d8*dequant[j+96]*qs[j];
-                    //}
-                    ////s[0] *= d8; s[1] *= d8; s[2] *= d8; s[3] *= d8;
-                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
-                    sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
-                    sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
-                    sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
-                    sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
-                    //auto d4d8 = _mm256_mul_ps(scales, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k])));
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
-                    //auto check = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), _mm256_setzero_ps());
-                    //check = _mm256_fmadd_ps(scales_m, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4])), check);
-                    //auto check128 = _mm_add_ps(_mm256_castps256_ps128(check), _mm256_extractf128_ps(check, 1));
-                    //_mm_storeu_ps(stored, check128);
-                    //bool is_good = true;
-                    //for (int j = 0; j < 4; ++j) {
-                    //    if (std::abs(stored[j] - s[j]) > 1e-1*(std::abs(s[j])+std::abs(stored[j]))) is_good = false;
-                    //}
-                    //if (!is_good) {
-                    //    static int ncount = 0;
-                    //    printf("Oops\n");
-                    //    for (int j = 0; j < 4; ++j) printf("%g vs %g, diff = %g, thresh = %g\n", s[j], stored[j], std::abs(stored[j] - s[j]), 1e-2f*std::abs(s[j]));
-                    //    _mm_storeu_ps(stored, scales128);
-                    //    printf("iq4_nl scales: %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
-                    //    printf("d8 = %g, %g\n", d8, GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4]));
-                    //    _mm_storeu_ps(stored, _mm256_castps256_ps128(d4d8));
-                    //    printf("iq4_nl scales x q8: %g, %g, %g, %g", stored[0], stored[1], stored[2], stored[3]);
-                    //    _mm_storeu_ps(stored, _mm256_extractf128_ps(d4d8, 1));
-                    //    printf(" %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
-                    //    check = _mm256_mul_ps(scales_m, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4])));
-                    //    check128 = _mm_add_ps(_mm256_castps256_ps128(check), _mm256_extractf128_ps(check, 1));
-                    //    _mm_storeu_ps(stored, check128);
-                    //    printf("minus: %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
-                    //    check = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), _mm256_setzero_ps());
-                    //    check128 = _mm_add_ps(_mm256_castps256_ps128(check), _mm256_extractf128_ps(check, 1));
-                    //    _mm_storeu_ps(stored, check128);
-                    //    printf("plus: %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
-                    //    auto sumi128 = _mm_add_epi32(_mm256_castsi256_si128(sumi), _mm256_extractf128_si256(sumi, 1));
-                    //    _mm_storeu_si128((__m128i *)stored, sumi128);
-                    //    const int32_t * aux32 = (const int32_t *)stored;
-                    //    printf("sumi: %d, %d, %d, %d\n", aux32[0], aux32[1], aux32[2], aux32[3]);
-                    //    //uint8_t aux[32];
-                    //    //_mm256_storeu_si256((__m256i *)aux, q1);
-                    //    //printf("=== q1\n");
-                    //    //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
-                    //    //_mm256_storeu_si256((__m256i *)aux, q2);
-                    //    //printf("=== q2\n");
-                    //    //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
-                    //    //_mm256_storeu_si256((__m256i *)aux, q3);
-                    //    //printf("=== q3\n");
-                    //    //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
-                    //    //_mm256_storeu_si256((__m256i *)aux, q4);
-                    //    //printf("=== q4\n");
-                    //    //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
-                    //    if (++ncount > 10) GGML_ABORT("fatal error");
-                    //}
-                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    //acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4])), acc[iy]);
-                    acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
-                }
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
-            //_mm_storeu_ps(stored, sum);
-            //bool is_good = true;
-            //for (int j = 0; j < 4; ++j) {
-            //    if (std::abs(stored[j] - s[j]) > 1e-4f) { printf("Oops: %g vs %g\n", stored[j], s[j]); is_good = false; }
-            //}
-            //if (!is_good) GGML_ABORT("fatal error");
-            info.store(ix, iy, sum);
-            //for (int k = 0; k < 4; ++k) info.store(ix+k, iy, s[4*iy+k]);
-            acc[iy] = _mm256_setzero_ps();
-        }
-    }
-}
-
 //template <int nrc_y>
 //static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-//    GGML_ASSERT(nrc_x%8 == 0);
-//    //printf("%s: n = %d, nrc_x = %d, bx = %zu\n", __func__, n, nrc_x, bx);
+//    GGML_ASSERT(nrc_x%4 == 0);
 //    Q8<nrc_y, block_q8_1_x4> q8(info);
 //    auto m4 = _mm256_set1_epi8(0xf);
 //    auto values = load_iq4nl_values_256();
 //    int nb = n / QK4_NL;
 //    GGML_ASSERT(nb%4 == 0);
-//    __m256 acc[2*nrc_y] = {};
-//    __m256i qx[8];
-//    for (int ix = 0; ix < nrc_x; ix += 8) {
-//        const block_iq4_nl_x4 * iq4l = (const block_iq4_nl_x4 *)((const char *)vx + (ix+0)*bx);
-//        const block_iq4_nl_x4 * iq4h = (const block_iq4_nl_x4 *)((const char *)vx + (ix+4)*bx);
+//    //__m256 acc[nrc_y] = {};
+//    __m256 acc[2*nrc_y] = {};
+//    for (int ix = 0; ix < nrc_x; ix += 4) {
+//        const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
 //        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
 //            for (int k = 0; k < 4; ++k) {
-//                auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
-//                auto scales1 = _mm256_set_m128(scales128, scales128);
-//                auto scales1_m = _mm256_mul_ps(scales1, _mm256_set1_ps(-64.f));
-//                auto bits1 = _mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0);
-//                auto bits2 = _mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1);
-//                qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
-//                qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
-//                qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
-//                qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
-//                scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
-//                auto scales2 = _mm256_set_m128(scales128, scales128);
-//                auto scales2_m = _mm256_mul_ps(scales2, _mm256_set1_ps(-64.f));
-//                bits1 = _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0);
-//                bits2 = _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1);
-//                qx[4] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
-//                qx[5] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
-//                qx[6] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
-//                qx[7] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
-//                __m256i sumi1, sumi2;
+//                auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
+//                auto scales = _mm256_set_m128(scales128, scales128);
+//                auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
+//                auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
+//                auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
+//                auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
+//                auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
+//                auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
+//                auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+//                __m256i sumi;
 //                for (int iy = 0; iy < nrc_y; ++iy) {
 //                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
-//                    auto sy = _mm256_shuffle_epi32(y, 0x00);
-//                    sumi1 = _mm256_dpbusd_epi32(qx[0], sy, _mm256_setzero_si256());
-//                    sumi2 = _mm256_dpbusd_epi32(qx[4], sy, _mm256_setzero_si256());
-//                    sy = _mm256_shuffle_epi32(y, 0x55);
-//                    sumi1 = _mm256_dpbusd_epi32(qx[1], sy, sumi1);
-//                    sumi2 = _mm256_dpbusd_epi32(qx[5], sy, sumi2);
-//                    sy = _mm256_shuffle_epi32(y, 0xaa);
-//                    sumi1 = _mm256_dpbusd_epi32(qx[2], sy, sumi1);
-//                    sumi2 = _mm256_dpbusd_epi32(qx[6], sy, sumi2);
-//                    sy = _mm256_shuffle_epi32(y, 0xff);
-//                    sumi1 = _mm256_dpbusd_epi32(qx[3], sy, sumi1);
-//                    sumi2 = _mm256_dpbusd_epi32(qx[7], sy, sumi2);
-//                    auto dy = _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k]));
-//                    acc[2*iy+0] = _mm256_fmadd_ps(_mm256_mul_ps(scales1, dy), _mm256_cvtepi32_ps(sumi1), acc[2*iy+0]);
-//                    acc[2*iy+1] = _mm256_fmadd_ps(_mm256_mul_ps(scales2, dy), _mm256_cvtepi32_ps(sumi2), acc[2*iy+1]);
-//                    dy = _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4]));
-//                    acc[2*iy+0] = _mm256_fmadd_ps(scales1_m, dy, acc[2*iy+0]);
-//                    acc[2*iy+1] = _mm256_fmadd_ps(scales2_m, dy, acc[2*iy+1]);
+//                    sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
+//                    sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
+//                    sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
+//                    sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
+//                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
+//                    //acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+//                    //acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
+//                    acc[2*iy+0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[2*iy+0]);
+//                    acc[2*iy+1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
 //                }
 //            }
 //        }
 //        for (int iy = 0; iy < nrc_y; ++iy) {
-//            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[2*iy+0]), _mm256_extractf128_ps(acc[2*iy+0], 1));
-//            info.store(ix+0, iy, sum);
-//            sum = _mm_add_ps(_mm256_castps256_ps128(acc[2*iy+1]), _mm256_extractf128_ps(acc[2*iy+1], 1));
-//            info.store(ix+4, iy, sum);
+//            auto sum256 = _mm256_add_ps(acc[2*iy+0], acc[2*iy+1]);
 //            acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_ps();
+//            auto sum = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1));
+//            info.store(ix, iy, sum);
+//            //auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+//            //info.store(ix, iy, sum);
+//            //acc[iy] = _mm256_setzero_ps();
 //        }
 //    }
 //}
+template <int nrc_y>
+static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_1_x4> q8(info);
+    auto m4 = _mm512_set1_epi8(0xf);
+    auto values = load_iq4nl_values_512();
+    int nb = n / QK4_NL;
+    GGML_ASSERT(nb%4 == 0);
+    __m512 acc[2*nrc_y] = {};
+    __m512i qx[4];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_iq4_nl_x4 * iq4l = (const block_iq4_nl_x4 *)((const char *)vx + (ix+0)*bx);
+        const block_iq4_nl_x4 * iq4h = (const block_iq4_nl_x4 *)((const char *)vx + (ix+4)*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int k = 0; k < 4; ++k) {
+                auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
+                auto scales1 = _mm256_set_m128(scales128, scales128);
+                scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
+                auto scales2 = _mm256_set_m128(scales128, scales128);
+                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+                auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f));
+                auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
+                        _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
+                auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
+                        _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
+                qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
+                qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
+                qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
+                qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
+                    auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+                    auto sumi = _mm512_setzero_si512();
+                    sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                    auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
+                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                    acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+                }
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+            acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+            auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+            auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+            info.store(ix+0, iy, sum1);
+            info.store(ix+4, iy, sum2);
+        }
+    }
+}
+
 template <typename Bits>
 inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
     if (j == 0) {
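
Note on the dpbusd usage in the new kernel (an explanatory sketch, not part of the patch): _mm512_dpbusd_epi32 multiplies unsigned bytes from its first operand with signed bytes from its second, but the iq4_nl codebook is signed. The kernel therefore shuffles quants out of a shifted, non-negative copy of the codebook and cancels the shift afterwards: scales_m folds the constant -64.f into the per-block scales, and the fp16 entries at q8.y[iy][ib4].d[k+4] carry the precomputed q8_1 block-sum terms, so the whole correction is a single _mm512_fmadd_ps per block, accumulated in acc[2*iy+1]. The scalar sketch below demonstrates the identity this relies on, sum((v+S)*y) - S*sum(y) == sum(v*y). The codebook is kvalues_iq4nl from ggml-common.h; the shift S = 128 and the helper names are illustrative assumptions, not the kernel's exact constants.

#include <cstdint>
#include <cstdio>

// iq4_nl codebook (kvalues_iq4nl in ggml-common.h).
static const int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

// Reference: signed dot product of one 32-weight block (16 bytes of nibbles)
// with 32 q8 quants.
static int dot_signed(const uint8_t * q4, const int8_t * y) {
    int sum = 0;
    for (int j = 0; j < 16; ++j) {
        sum += kvalues_iq4nl[q4[j] & 0xf] * y[j];       // low nibbles -> y[0..15]
        sum += kvalues_iq4nl[q4[j] >>  4] * y[j + 16];  // high nibbles -> y[16..31]
    }
    return sum;
}

// The dpbusd-friendly form: use a shifted, non-negative codebook so the left
// operand is a valid unsigned byte, then subtract shift*sum(y). sum(y) is
// exactly the per-block quantity q8_1 precomputes, so in the kernel the
// correction costs one fmadd instead of a second integer dot product.
static int dot_shifted(const uint8_t * q4, const int8_t * y) {
    const int shift = 128;  // hypothetical; any shift making the table non-negative works
    int sum = 0, sumy = 0;
    for (int j = 0; j < 16; ++j) {
        sum += (kvalues_iq4nl[q4[j] & 0xf] + shift) * y[j];
        sum += (kvalues_iq4nl[q4[j] >>  4] + shift) * y[j + 16];
    }
    for (int j = 0; j < 32; ++j) sumy += y[j];
    return sum - shift * sumy;  // == dot_signed(q4, y)
}

int main() {
    uint8_t q4[16]; int8_t y[32];
    for (int j = 0; j < 16; ++j) q4[j] = (uint8_t)(7*j + 3);
    for (int j = 0; j < 32; ++j) y[j]  = (int8_t)(5*j - 80);
    printf("%d %d\n", dot_signed(q4, y), dot_shifted(q4, y));  // prints the same value twice
    return 0;
}

Both functions return the same value for any input; moving the shift out of the inner product is what lets the hot loop stay entirely in 8-bit integer dot products, with the float scales d4*d8 applied once per 32-weight block.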