Unroll for loop for repacked BF16 MATMUL (#1047)

see https://github.com/ikawrakow/ik_llama.cpp/discussions/1028 for
detail
This commit is contained in:
Djip007
2025-12-08 06:09:45 +01:00
committed by GitHub
parent 2f645f2579
commit 5669d39036
2 changed files with 19 additions and 9 deletions

View File

@@ -923,6 +923,15 @@ static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values,
#endif
// static unrool for:
template<int N, typename T>
inline void static_for(T&&f) {
if constexpr(N>0) {
static_for<N-1>(f);
f(N-1);
}
}
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#include <intrin.h>

View File

@@ -333,7 +333,8 @@ template <int nrc_y>
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%16 == 0);
const ggml_bf16_t * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
static_for<nrc_y>([&](const int iy) { y[iy] = (const ggml_bf16_t *)info.src1_row(iy); });
for (int ix = 0; ix < nrc_x/32; ++ix) {
__m512 acc[2*nrc_y] = {};
__m512bh qx[8];
@@ -348,7 +349,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
for (int iy = 0; iy < nrc_y; ++iy) {
static_for<nrc_y>([&](const int iy) {
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
//auto y = _mm512_broadcast_i32x4(y128);
auto y256 = MM256_SET_M128I(y128, y128);
@@ -361,12 +362,12 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
}
});
}
for (int iy = 0; iy < nrc_y; ++iy) {
static_for<nrc_y>([&](const int iy) {
info.store(32*ix+ 0, iy, acc[2*iy+0]);
info.store(32*ix+16, iy, acc[2*iy+1]);
}
});
}
for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
__m512 acc[nrc_y] = {};
@@ -377,7 +378,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
for (int iy = 0; iy < nrc_y; ++iy) {
static_for<nrc_y>([&](const int iy) {
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
auto y256 = MM256_SET_M128I(y128, y128);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
@@ -385,11 +386,11 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
}
});
}
for (int iy = 0; iy < nrc_y; ++iy) {
static_for<nrc_y>([&](const int iy) {
info.store(ix, iy, acc[iy]);
}
});
}
}