If available, use bf16 for iq4_kt gemm/gemv

With that, we get PP-512 = 240 t/s.
This commit is contained in:
Iwan Kawrakow
2025-06-02 11:59:20 +03:00
parent 0715919fc0
commit 061d064b21
2 changed files with 48 additions and 61 deletions

View File

@@ -73,12 +73,23 @@ struct Trellis1 {
auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2));
return __m256i(_mm512_cvtneps_pbh(v));
}
// Decode two packed 8-value groups into 16 bf16 outputs with a per-call
// affine transform: result = scale * (vs1 + vs2) + offset.
// NOTE(review): the parameters are named "i8_*" but their 16-bit lanes are
// consumed as fp16 bit patterns by _mm512_cvtph_ps — presumably the raw
// output of next8(); confirm against the trellis generator.
inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale, __m512 offset) const {
// Widen the 16-bit lanes (interpreted as fp16) to fp32.
auto v1 = _mm512_cvtph_ps(i8_1);
auto v2 = _mm512_cvtph_ps(i8_2);
// shuf1/shuf2 are member permute masks (declared outside this view); they
// pair lanes from v1/v2 so that vs1 + vs2 sums the two halves per output.
auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2);
auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2);
// Unlike the offset-less overload above, this one fuses in an additive bias.
auto v = _mm512_fmadd_ps(scale, _mm512_add_ps(vs1, vs2), offset);
// Round-convert the 16 fp32 results to bf16; returned as raw __m256i bits.
return __m256i(_mm512_cvtneps_pbh(v));
}
// Two-seed convenience overload: expand each 32-bit seed to a packed
// 8-value group with next8(), then decode the pair to bf16 (no offset).
inline __m256i gen8bh(uint32_t val1, uint32_t val2, __m512 scale) const {
auto packed1 = next8(val1);
auto packed2 = next8(val2);
return gen8bh(packed1, packed2, scale);
}
// Four-seed convenience overload without an additive offset: each next8()
// call consumes a pair of seeds, producing the low and high packed halves.
inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale) const {
auto lo = next8(val1, val2);
auto hi = next8(val3, val4);
return gen8bh(lo, hi, scale);
}
// Four-seed convenience overload with an additive offset: expands the seed
// pairs via next8() and forwards to the fused scale-and-offset decoder.
inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale, __m512 offset) const {
auto lo = next8(val1, val2);
auto hi = next8(val3, val4);
return gen8bh(lo, hi, scale, offset);
}
#endif
inline __m256i next8(uint32_t val) const {
auto mval = _mm256_set1_epi32(val);
@@ -673,10 +684,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
#ifdef __AVX512BF16__
template <int nrc_y>
void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
/*
const int nb = n/QK_K;
constexpr int kNumGroups = 64;
@@ -685,67 +694,54 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
union { __m256 vec; float val[8]; } s_helper;
union { __m256i vec; uint32_t val[8]; } o_helper;
constexpr int k_acc = 2 * nrc_y;
__m256 accd[k_acc];
const ggml_bf16_t * y[nrc_y];
float row_sum[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
//auto sum = _mm256_setzero_ps();
//for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i));
row_sum[iy] = 0; //hsum_float_8(sum);
}
uint32_t val[8];
union { __m512i vec[2]; uint16_t val[64]; } h_helper;
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
auto dav = dptr[1];
auto dav = _mm512_set1_ps(dptr[1]);
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
const uint32_t * shb = x[i].qs;
const uint8_t * ql = (const uint8_t *)(shb + 8);
const uint8_t * qh = ql + kNumGroups;
const uint8_t * ql = (const uint8_t *)(shb + 8);
auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0));
auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1));
auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups)));
h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00)));
h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00)));
auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
for (int ib = 0; ib < 4; ++ib) {
auto scale1 = _mm512_set1_ps(s_helper.val[ib+0]);
auto scale2 = _mm512_set1_ps(s_helper.val[ib+4]);
for (int j = 0; j < 4; j += 2) {
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
val[0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0];
val[1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 << 9) & 0x7000) + o_helper.val[ib+0];
val[2] = ql[8*ib+2*j+ 2] + ((qh[8*ib+2*j+2] << 8) & 0xf00) + ((sh1 << 6) & 0x7000) + o_helper.val[ib+0];
val[3] = ql[8*ib+2*j+ 3] + ((qh[8*ib+2*j+3] << 8) & 0xf00) + ((sh1 << 3) & 0x7000) + o_helper.val[ib+0];
auto xval1 = trellis.gen8bh(val[0], val[1], val[2], val[3], scale1);
val[4] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4];
val[5] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 << 9) & 0x7000) + o_helper.val[ib+4];
val[6] = ql[8*ib+2*j+34] + ((qh[8*ib+2*j+2] << 4) & 0xf00) + ((sh2 << 6) & 0x7000) + o_helper.val[ib+4];
val[7] = ql[8*ib+2*j+35] + ((qh[8*ib+2*j+3] << 4) & 0xf00) + ((sh2 << 3) & 0x7000) + o_helper.val[ib+4];
auto xval2 = trellis.gen8bh(val[4], val[5], val[6], val[7], scale2);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y1 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+ 0));
auto y2 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+128));
accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], __m256bh(y1), __m256bh(xval1));
accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], __m256bh(y2), __m256bh(xval2));
}
}
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scale = _mm512_set1_ps(s_helper.val[ib]);
uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib];
uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib];
uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib];
uint32_t val4 = h_helper.val[8*ib+3] + ((shb[ib] >> 5) & 0x7000) + o_helper.val[ib];
auto xval1 = trellis.gen8bh(val1, val2, val3, val4, scale, dav);
val1 = h_helper.val[8*ib+4] + ((shb[ib] >> 8) & 0x7000) + o_helper.val[ib];
val2 = h_helper.val[8*ib+5] + ((shb[ib] >> 11) & 0x7000) + o_helper.val[ib];
val3 = h_helper.val[8*ib+6] + ((shb[ib] >> 14) & 0x7000) + o_helper.val[ib];
val4 = h_helper.val[8*ib+7] + ((shb[ib] >> 17) & 0x7000) + o_helper.val[ib];
auto xval2 = trellis.gen8bh(val1, val2, val3, val4, scale, dav);
_mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+ 0), xval1);
_mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+16), xval2);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])) + dav*row_sum[iy]);
}
y += stride_y;
}
*/
}
template <int nrc_y>
void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
const int nb = n/QK_K;
constexpr int kNumGroups = 64;
@@ -759,7 +755,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
float row_sum[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
//row_sum[iy] = 0;
auto sum = _mm256_setzero_ps();
auto one = _mm512_cvtneps_pbh(_mm512_set1_ps(1.f));
for (int i = 0; i < n/16; ++i) sum = _mm256_dpbf16_ps(sum, one, __m256bh(_mm256_loadu_si256((const __m256i *)y[iy] + i)));
@@ -782,7 +777,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
const uint8_t * ql = (const uint8_t *)(shb + 8);
auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0));
auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1));
//const uint32_t * qh = (const uint32_t *)(ql + kNumGroups);
auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups)));
h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00)));
h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00)));
@@ -792,15 +786,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scale = _mm512_set1_ps(s_helper.val[ib]);
// qh[(Q::kNg*ib + j)%(kNumGroups/2)] -> qh[(8*ib+j)%32], j = 0...7
// ib = 0 -> 0....7 (uint8_t) -> 0, 1 (uint32_t), shift = 0
// ib = 1 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 0
// ib = 2 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 0
// ib = 3 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 0
// ib = 4 -> 0....7 (uint8_t) -> 1, 1 (uint32_t), shift = 4
// ib = 5 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 4
// ib = 6 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 4
// ib = 7 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 4
uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib];
uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib];
uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib];
@@ -882,11 +867,12 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void *
#ifdef __AVX512BF16__
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
#else
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
#endif
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
#endif
default: return false;
}
return true;

View File

@@ -239,11 +239,12 @@ struct MulMat {
#ifdef __AVX512BF16__
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
#else
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
#endif
//case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
default: break;
}
#else