mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Unroll for loop for repacked BF16 MATMUL (#1047)
see https://github.com/ikawrakow/ik_llama.cpp/discussions/1028 for detail
This commit is contained in:
@@ -923,6 +923,15 @@ static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values,
|
||||
|
||||
#endif
|
||||
|
||||
// static unrool for:
|
||||
template<int N, typename T>
|
||||
inline void static_for(T&&f) {
|
||||
if constexpr(N>0) {
|
||||
static_for<N-1>(f);
|
||||
f(N-1);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#include <intrin.h>
|
||||
|
||||
@@ -333,7 +333,8 @@ template <int nrc_y>
|
||||
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%16 == 0);
|
||||
const ggml_bf16_t * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
static_for<nrc_y>([&](const int iy) { y[iy] = (const ggml_bf16_t *)info.src1_row(iy); });
|
||||
|
||||
for (int ix = 0; ix < nrc_x/32; ++ix) {
|
||||
__m512 acc[2*nrc_y] = {};
|
||||
__m512bh qx[8];
|
||||
@@ -348,7 +349,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
|
||||
qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
|
||||
qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
|
||||
qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
static_for<nrc_y>([&](const int iy) {
|
||||
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
|
||||
//auto y = _mm512_broadcast_i32x4(y128);
|
||||
auto y256 = MM256_SET_M128I(y128, y128);
|
||||
@@ -361,12 +362,12 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
|
||||
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
}
|
||||
});
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
static_for<nrc_y>([&](const int iy) {
|
||||
info.store(32*ix+ 0, iy, acc[2*iy+0]);
|
||||
info.store(32*ix+16, iy, acc[2*iy+1]);
|
||||
}
|
||||
});
|
||||
}
|
||||
for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
|
||||
__m512 acc[nrc_y] = {};
|
||||
@@ -377,7 +378,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
|
||||
qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
|
||||
qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
|
||||
qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
static_for<nrc_y>([&](const int iy) {
|
||||
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
|
||||
auto y256 = MM256_SET_M128I(y128, y128);
|
||||
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
|
||||
@@ -385,11 +386,11 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
|
||||
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
}
|
||||
});
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
static_for<nrc_y>([&](const int iy) {
|
||||
info.store(ix, iy, acc[iy]);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user