mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-27 09:53:40 +00:00
Improve gemv for bf16_r16
It is better to process one "row" at a time and to have 4 accumulators. I guess, this allows better interleving of load and fmadd instructions. We get ~10% better performance for 1 thread, and fully saturate memory bandwidth at 2 threads with a ~3.5% better performance (4.4 vs 4.25 t/s for L3-8B).
This commit is contained in:
@@ -5033,6 +5033,32 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
|
||||
}
|
||||
}
|
||||
}
|
||||
static void mul_mat_bf16_r16_bf16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%16 == 0);
|
||||
const ggml_bf16_t * y = (const ggml_bf16_t *)info.src1_row(0);
|
||||
for (int ix = 0; ix < nrc_x; ix += 16) {
|
||||
__m512 acc[4] = {};
|
||||
__m512bh qx[4];
|
||||
const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + ix*bx);
|
||||
for (int ib = 0; ib < n/8; ++ib) {
|
||||
auto y128 = _mm_loadu_si128((const __m128i*)y+ib);
|
||||
auto y256 = MM256_SET_M128I(y128, y128);
|
||||
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
|
||||
qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
|
||||
qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
|
||||
qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
|
||||
qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
|
||||
acc[0] = _mm512_dpbf16_ps(acc[0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
|
||||
acc[1] = _mm512_dpbf16_ps(acc[1], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
acc[2] = _mm512_dpbf16_ps(acc[2], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
acc[3] = _mm512_dpbf16_ps(acc[3], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
}
|
||||
acc[0] = _mm512_add_ps(acc[0], acc[1]);
|
||||
acc[2] = _mm512_add_ps(acc[2], acc[3]);
|
||||
acc[0] = _mm512_add_ps(acc[0], acc[2]);
|
||||
info.store(ix, 0, acc[0]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int nrc_y>
|
||||
@@ -7436,7 +7462,7 @@ void set_mul_mat_bf16(MulMat& mm) {
|
||||
}
|
||||
void set_mul_mat_bf16_r16(MulMat& mm) {
|
||||
for (auto& f : mm.funcs) f = nullptr;
|
||||
mm.funcs[0] = mul_mat_bf16_r16_bf16<1>;
|
||||
mm.funcs[0] = mul_mat_bf16_r16_bf16_1;
|
||||
mm.funcs[1] = mul_mat_bf16_r16_bf16<2>;
|
||||
mm.funcs[2] = mul_mat_bf16_r16_bf16<3>;
|
||||
mm.funcs[3] = mul_mat_bf16_r16_bf16<4>;
|
||||
|
||||
Reference in New Issue
Block a user