diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index 2f677853..5503bdf9 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -923,6 +923,15 @@ static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values,
 
 #endif
 
+// static unroll for: for (int i = 0; i < N; ++i) f(i)
+template <int N, typename T>
+inline void static_for(T&& f) {
+    if constexpr (N > 0) {
+        static_for<N-1>(f);
+        f(N-1);
+    }
+}
+
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #include <intrin.h>
diff --git a/ggml/src/iqk/iqk_gemm_floats.cpp b/ggml/src/iqk/iqk_gemm_floats.cpp
index 5165eb98..56ce52a4 100644
--- a/ggml/src/iqk/iqk_gemm_floats.cpp
+++ b/ggml/src/iqk/iqk_gemm_floats.cpp
@@ -333,7 +333,8 @@ template <int nrc_y>
 static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%16 == 0);
     const ggml_bf16_t * y[nrc_y];
-    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+    static_for<nrc_y>([&](const int iy) { y[iy] = (const ggml_bf16_t *)info.src1_row(iy); });
+
     for (int ix = 0; ix < nrc_x/32; ++ix) {
         __m512 acc[2*nrc_y] = {};
         __m512bh qx[8];
@@ -348,7 +349,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
             qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
             qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
             qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
-            for (int iy = 0; iy < nrc_y; ++iy) {
+            static_for<nrc_y>([&](const int iy) {
                 auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
                 //auto y = _mm512_broadcast_i32x4(y128);
                 auto y256 = MM256_SET_M128I(y128, y128);
@@ -361,12 +362,12 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
                 acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                 acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                 acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
-            }
+            });
         }
-        for (int iy = 0; iy < nrc_y; ++iy) {
+        static_for<nrc_y>([&](const int iy) {
             info.store(32*ix+ 0, iy, acc[2*iy+0]);
             info.store(32*ix+16, iy, acc[2*iy+1]);
-        }
+        });
     }
     for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
         __m512 acc[nrc_y] = {};
@@ -377,7 +378,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
             qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
             qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
             qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
-            for (int iy = 0; iy < nrc_y; ++iy) {
+            static_for<nrc_y>([&](const int iy) {
                 auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
                 auto y256 = MM256_SET_M128I(y128, y128);
                 auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
@@ -385,11 +386,11 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
                 acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                 acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                 acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
-            }
+            });
         }
-        for (int iy = 0; iy < nrc_y; ++iy) {
+        static_for<nrc_y>([&](const int iy) {
             info.store(ix, iy, acc[iy]);
-        }
+        });
     }
 }
 
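Note on the new helper: `static_for<N>` replaces a runtime `for` loop whose trip
count (`nrc_y`, a template parameter of `mul_mat_bf16_r16_bf16`) is already a
compile-time constant. Because the count is a template argument and the
recursion is resolved by `if constexpr`, the compiler must emit the N lambda
invocations as straight-line code, independent of its unrolling heuristics;
that appears to be the point of the change. A minimal standalone sketch of the
pattern, needing C++17 for `if constexpr` (the `main` scaffolding and the `acc`
demo array below are illustrative, not part of the PR):

    // static_for<N>(f) invokes f(0), f(1), ..., f(N-1), fully unrolled at
    // compile time via template recursion.
    #include <cstdio>

    template <int N, typename T>
    inline void static_for(T&& f) {
        if constexpr (N > 0) {
            static_for<N-1>(f);   // recurse first so calls run in order 0..N-1
            f(N-1);
        }
    }

    int main() {
        float acc[4] = {};
        // Behaves like: for (int iy = 0; iy < 4; ++iy) acc[iy] += 0.5f*iy;
        // but the four calls are emitted as straight-line code (no loop,
        // no branch on iy), matching how the diff uses static_for<nrc_y>.
        static_for<4>([&](const int iy) { acc[iy] += 0.5f*iy; });
        for (int iy = 0; iy < 4; ++iy) printf("acc[%d] = %g\n", iy, acc[iy]);
        return 0;
    }

The lambda parameter `iy` is still an ordinary runtime `int` inside the body,
so the call sites in the diff need no changes beyond swapping the loop header
for `static_for<nrc_y>([&](const int iy) { ... });`.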