AVX2 quantization for Q8_K (#22)

It has been there for a while, but forgot to add here.

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2024-08-19 15:33:27 +03:00
committed by GitHub
parent 5652100afc
commit a73702d93b
3 changed files with 107 additions and 0 deletions

View File

@@ -12,6 +12,7 @@
#include "ggml-impl.h"
#if GGML_USE_IQK_MULMAT
#include "iqk/iqk_mul_mat.h"
#include "iqk/iqk_quantize.h"
#endif
@@ -3770,7 +3771,11 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
}
void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
#ifdef GGML_USE_IQK_MULMAT
iqk_quantize_row_q8_K(x, y, k);
#else
quantize_row_q8_K_ref(x, y, k);
#endif
}
//===================================== Dot ptoducts =================================

View File

@@ -1982,3 +1982,103 @@ void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t
*s = sumf;
}
#ifdef __AVX2__
namespace {
inline int hsum_i32_8(const __m256i a) {
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
inline float hmax_f32_8(__m256 x) {
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
return _mm_cvtss_f32(max4);
}
}
#endif
void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_K * y = (block_q8_K *)vy;
#ifdef __AVX2__
const __m256 signBit = _mm256_set1_ps(-0.0f);
const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
for (int i = 0; i < nb; i++) {
const float * xb = x + i*QK_K;
__m256 maxAbs = _mm256_setzero_ps();
const float * xx = xb;
for (int ib = 0; ib < QK_K/8; ++ib) {
const __m256 v = _mm256_loadu_ps(xx); xx += 8;
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
}
const float maxScalar = hmax_f32_8(maxAbs);
const float d = maxScalar / 127.f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
xx = xb;
int8_t * q8 = y[i].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
__m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
__m256i i0 = _mm256_cvtps_epi32(v0);
__m256i i1 = _mm256_cvtps_epi32(v1);
__m256i i2 = _mm256_cvtps_epi32(v2);
__m256i i3 = _mm256_cvtps_epi32(v3);
y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
i0 = _mm256_packs_epi16( i0, i2 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)q8, i0);
q8 += 32;
}
}
#else
for (int i = 0; i < nb; i++) {
float max = 0;
float amax = 0;
for (int j = 0; j < QK_K; ++j) {
float ax = fabsf(x[j]);
if (ax > amax) {
amax = ax; max = x[j];
}
}
if (!amax) {
y[i].d = 0;
memset(y[i].qs, 0, QK_K);
x += QK_K;
continue;
}
//const float iscale = -128.f/max;
// We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
const float iscale = -127.f/max;
for (int j = 0; j < QK_K; ++j) {
int v = nearest_int(iscale*x[j]);
y[i].qs[j] = MIN(127, v);
}
for (int j = 0; j < QK_K/16; ++j) {
int sum = 0;
for (int ii = 0; ii < 16; ++ii) {
sum += y[i].qs[j*16 + ii];
}
y[i].bsums[j] = sum;
}
y[i].d = 1/iscale;
x += QK_K;
}
#endif
}

View File

@@ -49,6 +49,8 @@ size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst
void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
#ifdef __cplusplus
}
#endif