iq2_bn_r4: simdify q8_K16 quantization (AVX2)

PP-512 becomes 834 t/s and TG-128 now saturates to the same
performance as iq2_bn for 4 threads.
This commit is contained in:
Iwan Kawrakow
2024-12-06 08:41:54 +02:00
parent 4d730ebfd9
commit e06c83c8ee

View File

@@ -520,10 +520,78 @@ void quantize_row_q8_K64(const float * x, void * y, int64_t k) {
quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k);
}
#ifdef __AVX2__
namespace {
inline float hsum_float_4(__m128 x) {
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
x = _mm_add_ss(x, _mm_movehdup_ps(x));
return _mm_cvtss_f32(x);
}
inline float hsum_float_8(__m256 x) {
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
}
inline int hsum_i32_8(const __m256i a) {
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
inline float hmax_f32_8(__m256 x) {
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
return _mm_cvtss_f32(max4);
}
}
#endif
void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
float * dptr = (float *)vy;
int8_t * qy = (int8_t *)(dptr + 5);
int n64 = nk / 64;
#ifdef __AVX2__
__m256 sign_bit = _mm256_set1_ps(-0.f);
__m256 vmax[4] = {};
__m256 vsum[4] = {};
for (int i64 = 0; i64 < n64; ++i64) {
for (int k = 0; k < 4; ++k) {
auto v1 = _mm256_loadu_ps(x + 64*i64 + 16*k + 0);
auto v2 = _mm256_loadu_ps(x + 64*i64 + 16*k + 8);
vsum[k] = _mm256_add_ps(vsum[k], _mm256_add_ps(v1, v2));
v1 = _mm256_andnot_ps(sign_bit, v1);
v2 = _mm256_andnot_ps(sign_bit, v2);
vmax[k] = _mm256_max_ps(vmax[k], _mm256_max_ps(v1, v2));
}
}
__m256 sum = _mm256_add_ps(_mm256_add_ps(vsum[0], vsum[1]), _mm256_add_ps(vsum[2], vsum[3]));
dptr[4] = hsum_float_8(sum);
for (int k = 0; k < 4; ++k) {
float max = hmax_f32_8(vmax[k]);
dptr[k] = max/127;
vmax[k] = _mm256_set1_ps(dptr[k] > 0 ? 1/dptr[k] : 0.f);
}
__m256i ival[8];
const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
for (int i64 = 0; i64 < n64; ++i64) {
for (int k = 0; k < 4; ++k) {
__m256 v0 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 0));
__m256 v1 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 8));
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
ival[2*k+0] = _mm256_cvtps_epi32(v0);
ival[2*k+1] = _mm256_cvtps_epi32(v1);
}
for (int k = 0; k < 2; ++k) {
auto i0 = _mm256_packs_epi32(ival[4*k+0], ival[4*k+1]);
auto i1 = _mm256_packs_epi32(ival[4*k+2], ival[4*k+3]);
i0 = _mm256_packs_epi16(i0, i1);
i0 = _mm256_permutevar8x32_epi32(i0, perm);
_mm256_storeu_si256((__m256i *)qy, i0);
qy += 32;
}
}
#else
float amax[4] = {0.f, 0.f, 0.f, 0.f};
for (int i64 = 0; i64 < n64; ++i64) {
for (int k = 0; k < 4; ++k) {
@@ -547,6 +615,7 @@ void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
}
}
dptr[4] = sumf;
#endif
}
//
@@ -2368,23 +2437,6 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
return nrows * nblock * sizeof(block_iq6_k);
}
#ifdef __AVX2__
namespace {
inline int hsum_i32_8(const __m256i a) {
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
inline float hmax_f32_8(__m256 x) {
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
return _mm_cvtss_f32(max4);
}
}
#endif
void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);