mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-24 00:19:19 +00:00
If available, use bf16 for iq4_kt gemm/gemv
With that, we get PP-512 = 240 t/s.
This commit is contained in:
@@ -73,12 +73,23 @@ struct Trellis1 {
|
||||
auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2));
|
||||
return __m256i(_mm512_cvtneps_pbh(v));
|
||||
}
|
||||
inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale, __m512 offset) const {
|
||||
auto v1 = _mm512_cvtph_ps(i8_1);
|
||||
auto v2 = _mm512_cvtph_ps(i8_2);
|
||||
auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2);
|
||||
auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2);
|
||||
auto v = _mm512_fmadd_ps(scale, _mm512_add_ps(vs1, vs2), offset);
|
||||
return __m256i(_mm512_cvtneps_pbh(v));
|
||||
}
|
||||
inline __m256i gen8bh(uint32_t val1, uint32_t val2, __m512 scale) const {
|
||||
return gen8bh(next8(val1), next8(val2), scale);
|
||||
}
|
||||
inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale) const {
|
||||
return gen8bh(next8(val1, val2), next8(val3, val4), scale);
|
||||
}
|
||||
inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale, __m512 offset) const {
|
||||
return gen8bh(next8(val1, val2), next8(val3, val4), scale, offset);
|
||||
}
|
||||
#endif
|
||||
inline __m256i next8(uint32_t val) const {
|
||||
auto mval = _mm256_set1_epi32(val);
|
||||
@@ -673,10 +684,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
/*
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
@@ -685,67 +694,54 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
union { __m256i vec; uint32_t val[8]; } o_helper;
|
||||
|
||||
constexpr int k_acc = 2 * nrc_y;
|
||||
|
||||
__m256 accd[k_acc];
|
||||
const ggml_bf16_t * y[nrc_y];
|
||||
float row_sum[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
//auto sum = _mm256_setzero_ps();
|
||||
//for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i));
|
||||
row_sum[iy] = 0; //hsum_float_8(sum);
|
||||
}
|
||||
|
||||
uint32_t val[8];
|
||||
union { __m512i vec[2]; uint16_t val[64]; } h_helper;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
|
||||
auto dav = dptr[1];
|
||||
auto dav = _mm512_set1_ps(dptr[1]);
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||
|
||||
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
||||
const uint32_t * shb = x[i].qs;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0));
|
||||
auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1));
|
||||
auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups)));
|
||||
h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00)));
|
||||
h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00)));
|
||||
auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
|
||||
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
|
||||
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
auto scale1 = _mm512_set1_ps(s_helper.val[ib+0]);
|
||||
auto scale2 = _mm512_set1_ps(s_helper.val[ib+4]);
|
||||
for (int j = 0; j < 4; j += 2) {
|
||||
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
val[0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0];
|
||||
val[1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 << 9) & 0x7000) + o_helper.val[ib+0];
|
||||
val[2] = ql[8*ib+2*j+ 2] + ((qh[8*ib+2*j+2] << 8) & 0xf00) + ((sh1 << 6) & 0x7000) + o_helper.val[ib+0];
|
||||
val[3] = ql[8*ib+2*j+ 3] + ((qh[8*ib+2*j+3] << 8) & 0xf00) + ((sh1 << 3) & 0x7000) + o_helper.val[ib+0];
|
||||
auto xval1 = trellis.gen8bh(val[0], val[1], val[2], val[3], scale1);
|
||||
val[4] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4];
|
||||
val[5] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 << 9) & 0x7000) + o_helper.val[ib+4];
|
||||
val[6] = ql[8*ib+2*j+34] + ((qh[8*ib+2*j+2] << 4) & 0xf00) + ((sh2 << 6) & 0x7000) + o_helper.val[ib+4];
|
||||
val[7] = ql[8*ib+2*j+35] + ((qh[8*ib+2*j+3] << 4) & 0xf00) + ((sh2 << 3) & 0x7000) + o_helper.val[ib+4];
|
||||
auto xval2 = trellis.gen8bh(val[4], val[5], val[6], val[7], scale2);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y1 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+ 0));
|
||||
auto y2 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+128));
|
||||
accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], __m256bh(y1), __m256bh(xval1));
|
||||
accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], __m256bh(y2), __m256bh(xval2));
|
||||
}
|
||||
}
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scale = _mm512_set1_ps(s_helper.val[ib]);
|
||||
|
||||
uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib];
|
||||
uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib];
|
||||
uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib];
|
||||
uint32_t val4 = h_helper.val[8*ib+3] + ((shb[ib] >> 5) & 0x7000) + o_helper.val[ib];
|
||||
auto xval1 = trellis.gen8bh(val1, val2, val3, val4, scale, dav);
|
||||
|
||||
val1 = h_helper.val[8*ib+4] + ((shb[ib] >> 8) & 0x7000) + o_helper.val[ib];
|
||||
val2 = h_helper.val[8*ib+5] + ((shb[ib] >> 11) & 0x7000) + o_helper.val[ib];
|
||||
val3 = h_helper.val[8*ib+6] + ((shb[ib] >> 14) & 0x7000) + o_helper.val[ib];
|
||||
val4 = h_helper.val[8*ib+7] + ((shb[ib] >> 17) & 0x7000) + o_helper.val[ib];
|
||||
auto xval2 = trellis.gen8bh(val1, val2, val3, val4, scale, dav);
|
||||
|
||||
_mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+ 0), xval1);
|
||||
_mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+16), xval2);
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])) + dav*row_sum[iy]);
|
||||
}
|
||||
y += stride_y;
|
||||
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
@@ -759,7 +755,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
|
||||
float row_sum[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
//row_sum[iy] = 0;
|
||||
auto sum = _mm256_setzero_ps();
|
||||
auto one = _mm512_cvtneps_pbh(_mm512_set1_ps(1.f));
|
||||
for (int i = 0; i < n/16; ++i) sum = _mm256_dpbf16_ps(sum, one, __m256bh(_mm256_loadu_si256((const __m256i *)y[iy] + i)));
|
||||
@@ -782,7 +777,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0));
|
||||
auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1));
|
||||
//const uint32_t * qh = (const uint32_t *)(ql + kNumGroups);
|
||||
auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups)));
|
||||
h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00)));
|
||||
h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00)));
|
||||
@@ -792,15 +786,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scale = _mm512_set1_ps(s_helper.val[ib]);
|
||||
|
||||
// qh[(Q::kNg*ib + j)%(kNumGroups/2)] -> qh[(8*ib+j)%32], j = 0...7
|
||||
// ib = 0 -> 0....7 (uint8_t) -> 0, 1 (uint32_t), shift = 0
|
||||
// ib = 1 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 0
|
||||
// ib = 2 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 0
|
||||
// ib = 3 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 0
|
||||
// ib = 4 -> 0....7 (uint8_t) -> 1, 1 (uint32_t), shift = 4
|
||||
// ib = 5 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 4
|
||||
// ib = 6 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 4
|
||||
// ib = 7 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 4
|
||||
uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib];
|
||||
uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib];
|
||||
uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib];
|
||||
@@ -882,11 +867,12 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void *
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
|
||||
#else
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||
#endif
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||
#endif
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
|
||||
@@ -239,11 +239,12 @@ struct MulMat {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
|
||||
#else
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
#endif
|
||||
//case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
default: break;
|
||||
}
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user