mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-26 09:29:27 +00:00
Faster IQ4_XS_R4 on Zen4 (#128)
* Faster iq4_xs_r4 on Zen4 The trick is to simply prepare the Q8 block sums for blocks of 32 as floats. This brings PP-512 up to 254.6 t/s from 224 t/s. * Fix broken matrix x vector product on Zen4 --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -407,6 +407,7 @@ extern "C" {
|
||||
GGML_TYPE_IQ2_KS = 145,
|
||||
GGML_TYPE_IQ4_KSS = 146,
|
||||
GGML_TYPE_Q8_K16 = 147,
|
||||
GGML_TYPE_Q8_K32 = 148,
|
||||
|
||||
GGML_TYPE_Q4_0_R4 = 202,
|
||||
GGML_TYPE_Q5_0_R4 = 206,
|
||||
|
||||
@@ -1124,6 +1124,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_q8_K16,
|
||||
.row_meta_size = 20,
|
||||
},
|
||||
[GGML_TYPE_Q8_K32] = {
|
||||
.type_name = "q8_K32",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_q8_K),
|
||||
.is_quantized = true,
|
||||
.from_float = quantize_row_q8_K32,
|
||||
.row_meta_size = 0,
|
||||
},
|
||||
[GGML_TYPE_BF16] = {
|
||||
.type_name = "bf16",
|
||||
.blck_size = 1,
|
||||
@@ -1292,7 +1300,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_iq4_xs_r4,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_r4_ref,
|
||||
.vec_dot = vec_dot_iq4_xs_r4_q8_k,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K32,
|
||||
.nrows = 1,
|
||||
.row_meta_size = 0,
|
||||
},
|
||||
@@ -15633,6 +15641,7 @@ static void ggml_compute_forward_clamp(
|
||||
case GGML_TYPE_Q8_K:
|
||||
case GGML_TYPE_Q8_K64:
|
||||
case GGML_TYPE_Q8_K16:
|
||||
case GGML_TYPE_Q8_K32:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
|
||||
@@ -2917,15 +2917,16 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto m4 = _mm256_set1_epi8(0xf);
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto m1 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
|
||||
auto values = MM256_SET_M128I(values128, values128);
|
||||
//auto values = load_iq4nl_values_256();
|
||||
#else
|
||||
auto values = load_iq4nl_values_256();
|
||||
#endif
|
||||
int nbl = n / QK_K;
|
||||
using helper_t = union { __m256i vec; uint32_t val[8]; };
|
||||
helper_t h;
|
||||
@@ -2969,7 +2970,7 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
|
||||
float d8 = q8.scale(iy, ibl);
|
||||
float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
|
||||
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
|
||||
#else
|
||||
@@ -2979,15 +2980,6 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
|
||||
_mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
|
||||
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
//auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00))),
|
||||
// _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))));
|
||||
//auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa))),
|
||||
// _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))));
|
||||
//auto sumi = _mm256_add_epi32(sumi1, sumi2);
|
||||
//float d8 = q8.scale(iy, ibl);
|
||||
//float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
|
||||
//acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
//acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -3057,7 +3049,7 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
float d8 = q8.scale(iy, ibl);
|
||||
float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
|
||||
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
|
||||
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, _mm512_set1_ps(d8)), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
|
||||
acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[2*iy+1]);
|
||||
}
|
||||
@@ -5074,7 +5066,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
mm.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>;
|
||||
mm.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>;
|
||||
mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
|
||||
expected_typeB = GGML_TYPE_Q8_K;
|
||||
expected_typeB = GGML_TYPE_Q8_K32;
|
||||
break;
|
||||
case GGML_TYPE_Q4_0_R4:
|
||||
assert (ne00 % QK4_NL == 0);
|
||||
|
||||
@@ -2469,8 +2469,8 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
|
||||
return nrows * nblock * sizeof(block_iq6_k);
|
||||
}
|
||||
|
||||
|
||||
void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
|
||||
template <bool is_K32>
|
||||
void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
block_q8_K * y = (block_q8_K *)vy;
|
||||
@@ -2505,8 +2505,14 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
|
||||
__m256i i1 = _mm256_cvtps_epi32(v1);
|
||||
__m256i i2 = _mm256_cvtps_epi32(v2);
|
||||
__m256i i3 = _mm256_cvtps_epi32(v3);
|
||||
y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
|
||||
y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
|
||||
if constexpr (is_K32) {
|
||||
int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
||||
auto bs = (float *)y[i].bsums;
|
||||
bs[ib] = d*bsum;
|
||||
} else {
|
||||
y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
|
||||
y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
|
||||
}
|
||||
i0 = _mm256_packs_epi32( i0, i1 );
|
||||
i2 = _mm256_packs_epi32( i2, i3 );
|
||||
i0 = _mm256_packs_epi16( i0, i2 );
|
||||
@@ -2539,12 +2545,24 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
|
||||
int v = nearest_int(iscale*x[j]);
|
||||
y[i].qs[j] = MIN(127, v);
|
||||
}
|
||||
for (int j = 0; j < QK_K/16; ++j) {
|
||||
int sum = 0;
|
||||
for (int ii = 0; ii < 16; ++ii) {
|
||||
sum += y[i].qs[j*16 + ii];
|
||||
if constexpr (is_K32) {
|
||||
auto bs = (float *)y[i].bsums;
|
||||
float d = 1/iscale;
|
||||
for (int j = 0; j < QK_K/32; ++j) {
|
||||
int sum = 0;
|
||||
for (int ii = 0; ii < 32; ++ii) {
|
||||
sum += y[i].qs[j*32 + ii];
|
||||
}
|
||||
bs[j] = d*sum;
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < QK_K/16; ++j) {
|
||||
int sum = 0;
|
||||
for (int ii = 0; ii < 16; ++ii) {
|
||||
sum += y[i].qs[j*16 + ii];
|
||||
}
|
||||
y[i].bsums[j] = sum;
|
||||
}
|
||||
y[i].bsums[j] = sum;
|
||||
}
|
||||
y[i].d = 1/iscale;
|
||||
x += QK_K;
|
||||
@@ -2553,6 +2571,14 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
|
||||
|
||||
}
|
||||
|
||||
void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
|
||||
iqk_quantize_row_q8_K_T<false>(x, vy, k);
|
||||
}
|
||||
|
||||
void quantize_row_q8_K32(const float * x, void * vy, int64_t k) {
|
||||
iqk_quantize_row_q8_K_T<true>(x, vy, k);
|
||||
}
|
||||
|
||||
namespace {
|
||||
static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size,
|
||||
int n_per_row, const float * x, char * cy,
|
||||
|
||||
@@ -61,8 +61,6 @@ size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst
|
||||
void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
|
||||
void quantize_row_iq4_nl_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_r4 * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_iq4_nl_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
size_t quantize_iq4_nl_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
@@ -111,9 +109,11 @@ void dequantize_row_iq2_bn_r4(const block_iq2_bn * GGML_RESTRICT x, float * GG
|
||||
size_t quantize_iq2_bn_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
void vec_dot_iq2_bn_r4_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user