mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-24 16:39:45 +00:00
If available, use bf16 for iq3_kt gemm/gemv
With that, we get PP-512 = 233 t/s.
This commit is contained in:
@@ -1604,6 +1604,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot = vec_dot_iq3_kt_q8_k,
|
||||
#ifdef __ARM_NEON
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
#elif defined __AVX512BF16__
|
||||
.vec_dot_type = GGML_TYPE_BF16,
|
||||
#else
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
#endif
|
||||
|
||||
@@ -64,12 +64,17 @@ struct Trellis1 {
|
||||
v = _mm512_add_ps(_mm512_permutexvar_ps(shuf1, v), _mm512_permutexvar_ps(shuf2, v));
|
||||
return _mm512_castps512_ps256(v);
|
||||
}
|
||||
//template <bool is_abs = false>
|
||||
inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale) const {
|
||||
auto v1 = _mm512_cvtph_ps(i8_1);
|
||||
auto v2 = _mm512_cvtph_ps(i8_2);
|
||||
auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2);
|
||||
auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2);
|
||||
auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2));
|
||||
//if constexpr (is_abs) {
|
||||
// v = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), v);
|
||||
//}
|
||||
//v = _mm512_mul_ps(scale, v);
|
||||
return __m256i(_mm512_cvtneps_pbh(v));
|
||||
}
|
||||
#endif
|
||||
@@ -423,6 +428,116 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
Trellis1 trellis;
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
|
||||
auto shifts = _mm_set_epi32(0, 0, 4, 0);
|
||||
|
||||
__m256i all_signs[2];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
if (dptr[0] < 0) {
|
||||
printf("Oops: row scale is %g\n", dptr[0]);
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f);
|
||||
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
|
||||
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
|
||||
auto s32 = _mm256_cvtepi8_epi32(s8);
|
||||
s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32));
|
||||
for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j)));
|
||||
auto mask = _mm256_set1_epi16(0x01);
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scale = _mm512_set1_ps(s_helper.val[ib]);
|
||||
auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale);
|
||||
auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale);
|
||||
auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000));
|
||||
auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000));
|
||||
auto x1 = _mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff)));
|
||||
auto x2 = _mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff)));
|
||||
_mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+ 0), x1);
|
||||
_mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+16), x2);
|
||||
mask = _mm256_slli_epi16(mask, 1);
|
||||
}
|
||||
}
|
||||
y += stride_y;
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq3_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
Trellis1 trellis;
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
|
||||
auto shifts = _mm_set_epi32(0, 0, 4, 0);
|
||||
|
||||
__m256i all_signs[2];
|
||||
|
||||
__m256 accd[2*nrc_y];
|
||||
const ggml_bf16_t * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
if (dptr[0] < 0) {
|
||||
printf("Oops: row scale is %g\n", dptr[0]);
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f);
|
||||
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
|
||||
|
||||
for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
|
||||
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
|
||||
auto s32 = _mm256_cvtepi8_epi32(s8);
|
||||
s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32));
|
||||
for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j)));
|
||||
auto mask = _mm256_set1_epi16(0x01);
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
auto scale = _mm512_set1_ps(s_helper.val[ib]);
|
||||
auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale);
|
||||
auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale);
|
||||
auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000));
|
||||
auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000));
|
||||
mask = _mm256_slli_epi16(mask, 1);
|
||||
auto x1 = __m256bh(_mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff))));
|
||||
auto x2 = __m256bh(_mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff))));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y1 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+ 0)));
|
||||
auto y2 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+16)));
|
||||
accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], y1, x1);
|
||||
accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], y2, x2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
@@ -564,6 +679,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_BF16_T, kernels);
|
||||
return true;
|
||||
}
|
||||
if (typeA == GGML_TYPE_IQ3_KT) {
|
||||
if (typeB != GGML_TYPE_BF16) return false;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_BF16_T, kernels);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (ggml_type(typeB) != GGML_TYPE_F32) {
|
||||
@@ -592,10 +712,11 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void *
|
||||
switch (type) {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break;
|
||||
#else
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||
#endif
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||
#endif
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||
default: return false;
|
||||
}
|
||||
|
||||
@@ -238,10 +238,11 @@ struct MulMat {
|
||||
switch (type) {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
|
||||
#else
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
#endif
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
#endif
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
default: break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user