If available, use bf16 for iq3_kt gemm/gemv

With that, we get PP-512 = 233 t/s.
This commit is contained in:
Iwan Kawrakow
2025-06-02 07:14:54 +03:00
parent 62d8dd932b
commit 9890618db4
3 changed files with 126 additions and 2 deletions

View File

@@ -1604,6 +1604,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq3_kt_q8_k,
#ifdef __ARM_NEON
.vec_dot_type = GGML_TYPE_F16,
#elif defined __AVX512BF16__
.vec_dot_type = GGML_TYPE_BF16,
#else
.vec_dot_type = GGML_TYPE_F32,
#endif

View File

@@ -64,12 +64,17 @@ struct Trellis1 {
v = _mm512_add_ps(_mm512_permutexvar_ps(shuf1, v), _mm512_permutexvar_ps(shuf2, v));
return _mm512_castps512_ps256(v);
}
// Produces 16 bf16 values from two 256-bit inputs.
// Each input is widened from 16 half-precision values to 16 floats; the two
// float vectors are then combined with two-source permutes driven by shuf1 and
// shuf2, the two permuted vectors are added, the sum is multiplied by `scale`,
// and the result is converted to bf16.
//
// Disabled sketch of an absolute-value variant, kept for reference:
//   template <bool is_abs = false>
//   if constexpr (is_abs) v = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), v);   // applied to the sum
//   v = _mm512_mul_ps(scale, v);                                            // scaling done afterwards
inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale) const {
    const __m512 f32_a = _mm512_cvtph_ps(i8_1);
    const __m512 f32_b = _mm512_cvtph_ps(i8_2);
    const __m512 summed = _mm512_add_ps(_mm512_permutex2var_ps(f32_a, shuf1, f32_b),
                                        _mm512_permutex2var_ps(f32_a, shuf2, f32_b));
    const __m512 scaled = _mm512_mul_ps(scale, summed);
    return __m256i(_mm512_cvtneps_pbh(scaled));
}
#endif
@@ -423,6 +428,116 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
#ifdef __AVX512BF16__
void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
const int nb = n/QK_K;
Trellis1 trellis;
union { __m256 vec; float val[8]; } s_helper;
auto shifts = _mm_set_epi32(0, 0, 4, 0);
__m256i all_signs[2];
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
if (dptr[0] < 0) {
printf("Oops: row scale is %g\n", dptr[0]);
GGML_ABORT("Fatal error");
}
auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f);
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
for (int i = 0; i < nb; ++i) {
const uint16_t * ql = (const uint16_t *)x[i].ql;
const uint8_t * qh = x[i].qh;
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
auto s32 = _mm256_cvtepi8_epi32(s8);
s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32));
for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j)));
auto mask = _mm256_set1_epi16(0x01);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scale = _mm512_set1_ps(s_helper.val[ib]);
auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale);
auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale);
auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000));
auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000));
auto x1 = _mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff)));
auto x2 = _mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff)));
_mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+ 0), x1);
_mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+16), x2);
mask = _mm256_slli_epi16(mask, 1);
}
}
y += stride_y;
}
}
template <int nrc_y>
void mul_mat_iq3_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
const int nb = n/QK_K;
Trellis1 trellis;
union { __m256 vec; float val[8]; } s_helper;
auto shifts = _mm_set_epi32(0, 0, 4, 0);
__m256i all_signs[2];
__m256 accd[2*nrc_y];
const ggml_bf16_t * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
if (dptr[0] < 0) {
printf("Oops: row scale is %g\n", dptr[0]);
GGML_ABORT("Fatal error");
}
auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f);
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const uint16_t * ql = (const uint16_t *)x[i].ql;
const uint8_t * qh = x[i].qh;
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
auto s32 = _mm256_cvtepi8_epi32(s8);
s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32));
for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j)));
auto mask = _mm256_set1_epi16(0x01);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scale = _mm512_set1_ps(s_helper.val[ib]);
auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale);
auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale);
auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000));
auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000));
mask = _mm256_slli_epi16(mask, 1);
auto x1 = __m256bh(_mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff))));
auto x2 = __m256bh(_mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y1 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+ 0)));
auto y2 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+16)));
accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], y1, x1);
accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], y2, x2);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])));
}
}
}
#endif
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
const int nb = n/QK_K;
@@ -564,6 +679,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_BF16_T, kernels);
return true;
}
if (typeA == GGML_TYPE_IQ3_KT) {
if (typeB != GGML_TYPE_BF16) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_BF16_T, kernels);
return true;
}
#endif
if (ggml_type(typeB) != GGML_TYPE_F32) {
@@ -592,10 +712,11 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void *
switch (type) {
#ifdef __AVX512BF16__
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
#else
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
#endif
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
#endif
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
default: return false;
}

View File

@@ -238,10 +238,11 @@ struct MulMat {
switch (type) {
#ifdef __AVX512BF16__
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
#else
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
#endif
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
#endif
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
default: break;
}