mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-01 17:40:25 +00:00
iq4_knn: Zen4
We get PP-512(LLaMA-3.1-8B) = 225 t/s TG-128(LLaMA-3.1-8B) = 15.2 t/s We could do slightly better if we arranged the bits in blocks of 128 instead of 32. Thus saves 4 permutes per 256 weights and results in PP-512 = 230 t/s, TG-128 = 15.65 t/s. But for now we leave it the way it is.
This commit is contained in:
@@ -1344,6 +1344,74 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_knn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y> q8(info);
|
||||
|
||||
const char * qrow = (const char *)vx;
|
||||
|
||||
__m512 accd[nrc_y];
|
||||
__m512i qx[4];
|
||||
|
||||
const __m512i ml = _mm512_set1_epi8(0xf);
|
||||
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
|
||||
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
|
||||
const __m256i m127 = _mm256_set1_epi16(-127);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
const float * dptr = (const float *)qrow;
|
||||
const float d = *dptr;
|
||||
const int8_t * int_values = (const int8_t *)(dptr + 1);
|
||||
auto val128 = _mm_loadu_si128((const __m128i *)int_values);
|
||||
val128 = _mm_add_epi8(val128, _mm_set1_epi8(127));
|
||||
auto val256 = MM256_SET_M128I(val128, val128);
|
||||
auto values = _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
|
||||
|
||||
const block_iq4_knn * x = (const block_iq4_knn *)(int_values + 16);
|
||||
|
||||
qrow += bx;
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
auto bits1 = _mm512_loadu_si512((const __m512i *)x[i].qs+0);
|
||||
auto bits2 = _mm512_loadu_si512((const __m512i *)x[i].qs+1);
|
||||
qx[0] = _mm512_and_si512(bits1, ml);
|
||||
qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), ml);
|
||||
qx[2] = _mm512_and_si512(bits2, ml);
|
||||
qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), ml);
|
||||
|
||||
auto tmp = _mm512_permutex2var_epi64(qx[0], permute1, qx[1]);
|
||||
qx[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(qx[0], permute2, qx[1]));
|
||||
qx[0] = _mm512_shuffle_epi8(values, tmp);
|
||||
tmp = _mm512_permutex2var_epi64(qx[2], permute1, qx[3]);
|
||||
qx[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(qx[2], permute2, qx[3]));
|
||||
qx[2] = _mm512_shuffle_epi8(values, tmp);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto bs = _mm256_madd_epi16(m127, q8.load_bsums(iy, i));
|
||||
auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), bs, 0);
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[0], q8.load_quants64(iy, i, 0));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[1], q8.load_quants64(iy, i, 1));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[2], q8.load_quants64(iy, i, 2));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[3], q8.load_quants64(iy, i, 3));
|
||||
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, _mm512_reduce_add_ps(accd[iy]));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -4124,6 +4192,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ4KSS>(mm);
|
||||
break;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
case GGML_TYPE_IQ4_KNN:
|
||||
assert (ne00 % QK_K == 0);
|
||||
mm.funcs[0] = mul_mat_iq4_knn_q8_K_AVX512<1>;
|
||||
mm.funcs[1] = mul_mat_iq4_knn_q8_K_AVX512<2>;
|
||||
mm.funcs[2] = mul_mat_iq4_knn_q8_K_AVX512<3>;
|
||||
mm.funcs[3] = mul_mat_iq4_knn_q8_K_AVX512<4>;
|
||||
mm.funcs[4] = mul_mat_iq4_knn_q8_K_AVX512<5>;
|
||||
mm.funcs[5] = mul_mat_iq4_knn_q8_K_AVX512<6>;
|
||||
mm.funcs[6] = mul_mat_iq4_knn_q8_K_AVX512<7>;
|
||||
mm.funcs[7] = mul_mat_iq4_knn_q8_K_AVX512<8>;
|
||||
break;
|
||||
#endif
|
||||
case GGML_TYPE_IQ2_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ2K>(mm);
|
||||
|
||||
@@ -3499,22 +3499,18 @@ void quantize_row_iq4_knn_impl(const float * x, char * qrow, const float * imatr
|
||||
}
|
||||
}
|
||||
|
||||
float max = 0, amax = 0;
|
||||
float amax = 0;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
float ax = fabsf(work.values[i]);
|
||||
if (ax > amax) {
|
||||
amax = ax; max = work.values[i];
|
||||
}
|
||||
amax = std::max(ax, amax);
|
||||
}
|
||||
|
||||
float d = -max/128;
|
||||
//printf("amax = %g, max = %g d = %g\n", amax, max, d);
|
||||
float d = amax/127;
|
||||
*dptr = d;
|
||||
id = d ? 1/d : 0.f;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
int l = nearest_int(id*work.values[i]);
|
||||
int_values[i] = std::max(-128, std::min(127, l));
|
||||
//printf("int_values[%d] = %d\n", i, int_values[i]);
|
||||
int_values[i] = std::max(-127, std::min(127, l));
|
||||
}
|
||||
|
||||
int nb32 = n_per_row/32;
|
||||
|
||||
Reference in New Issue
Block a user