Improve gemv for bf16_r16

It is better to process one "row" at a time and to have
4 accumulators. I guess this allows better interleaving of
load and fmadd instructions. We get ~10% better performance
for 1 thread, and fully saturate memory bandwidth at 2 threads
with a ~3.5% better performance (4.4 vs 4.25 t/s for L3-8B).
This commit is contained in:
Iwan Kawrakow
2025-01-23 08:29:48 +02:00
parent 6d23495b9b
commit 4941c043bb

View File

@@ -5033,6 +5033,32 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
}
}
}
static void mul_mat_bf16_r16_bf16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%16 == 0);
const ggml_bf16_t * y = (const ggml_bf16_t *)info.src1_row(0);
for (int ix = 0; ix < nrc_x; ix += 16) {
__m512 acc[4] = {};
__m512bh qx[4];
const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + ix*bx);
for (int ib = 0; ib < n/8; ++ib) {
auto y128 = _mm_loadu_si128((const __m128i*)y+ib);
auto y256 = MM256_SET_M128I(y128, y128);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
acc[0] = _mm512_dpbf16_ps(acc[0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
acc[1] = _mm512_dpbf16_ps(acc[1], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
acc[2] = _mm512_dpbf16_ps(acc[2], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
acc[3] = _mm512_dpbf16_ps(acc[3], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
}
acc[0] = _mm512_add_ps(acc[0], acc[1]);
acc[2] = _mm512_add_ps(acc[2], acc[3]);
acc[0] = _mm512_add_ps(acc[0], acc[2]);
info.store(ix, 0, acc[0]);
}
}
#endif
template <int nrc_y>
@@ -7436,7 +7462,7 @@ void set_mul_mat_bf16(MulMat& mm) {
}
void set_mul_mat_bf16_r16(MulMat& mm) {
for (auto& f : mm.funcs) f = nullptr;
mm.funcs[0] = mul_mat_bf16_r16_bf16<1>;
mm.funcs[0] = mul_mat_bf16_r16_bf16_1;
mm.funcs[1] = mul_mat_bf16_r16_bf16<2>;
mm.funcs[2] = mul_mat_bf16_r16_bf16<3>;
mm.funcs[3] = mul_mat_bf16_r16_bf16<4>;