mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iq2_bn_r4: simdify q8_K16 quantization (AVX2)
PP-512 becomes 834 t/s and TG-128 now saturates to the same performance as iq2_bn for 4 threads.
This commit is contained in:
@@ -520,10 +520,78 @@ void quantize_row_q8_K64(const float * x, void * y, int64_t k) {
|
||||
quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k);
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
namespace {
|
||||
inline float hsum_float_4(__m128 x) {
|
||||
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
|
||||
x = _mm_add_ss(x, _mm_movehdup_ps(x));
|
||||
return _mm_cvtss_f32(x);
|
||||
}
|
||||
inline float hsum_float_8(__m256 x) {
|
||||
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
|
||||
}
|
||||
inline int hsum_i32_8(const __m256i a) {
|
||||
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
|
||||
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
|
||||
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
|
||||
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
||||
}
|
||||
inline float hmax_f32_8(__m256 x) {
|
||||
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
|
||||
return _mm_cvtss_f32(max4);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
|
||||
float * dptr = (float *)vy;
|
||||
int8_t * qy = (int8_t *)(dptr + 5);
|
||||
int n64 = nk / 64;
|
||||
#ifdef __AVX2__
|
||||
__m256 sign_bit = _mm256_set1_ps(-0.f);
|
||||
__m256 vmax[4] = {};
|
||||
__m256 vsum[4] = {};
|
||||
for (int i64 = 0; i64 < n64; ++i64) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto v1 = _mm256_loadu_ps(x + 64*i64 + 16*k + 0);
|
||||
auto v2 = _mm256_loadu_ps(x + 64*i64 + 16*k + 8);
|
||||
vsum[k] = _mm256_add_ps(vsum[k], _mm256_add_ps(v1, v2));
|
||||
v1 = _mm256_andnot_ps(sign_bit, v1);
|
||||
v2 = _mm256_andnot_ps(sign_bit, v2);
|
||||
vmax[k] = _mm256_max_ps(vmax[k], _mm256_max_ps(v1, v2));
|
||||
}
|
||||
}
|
||||
__m256 sum = _mm256_add_ps(_mm256_add_ps(vsum[0], vsum[1]), _mm256_add_ps(vsum[2], vsum[3]));
|
||||
dptr[4] = hsum_float_8(sum);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
float max = hmax_f32_8(vmax[k]);
|
||||
dptr[k] = max/127;
|
||||
vmax[k] = _mm256_set1_ps(dptr[k] > 0 ? 1/dptr[k] : 0.f);
|
||||
}
|
||||
__m256i ival[8];
|
||||
const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
|
||||
for (int i64 = 0; i64 < n64; ++i64) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
__m256 v0 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 0));
|
||||
__m256 v1 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 8));
|
||||
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
|
||||
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
|
||||
ival[2*k+0] = _mm256_cvtps_epi32(v0);
|
||||
ival[2*k+1] = _mm256_cvtps_epi32(v1);
|
||||
}
|
||||
for (int k = 0; k < 2; ++k) {
|
||||
auto i0 = _mm256_packs_epi32(ival[4*k+0], ival[4*k+1]);
|
||||
auto i1 = _mm256_packs_epi32(ival[4*k+2], ival[4*k+3]);
|
||||
i0 = _mm256_packs_epi16(i0, i1);
|
||||
i0 = _mm256_permutevar8x32_epi32(i0, perm);
|
||||
_mm256_storeu_si256((__m256i *)qy, i0);
|
||||
qy += 32;
|
||||
}
|
||||
}
|
||||
#else
|
||||
float amax[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i64 = 0; i64 < n64; ++i64) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
@@ -547,6 +615,7 @@ void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
|
||||
}
|
||||
}
|
||||
dptr[4] = sumf;
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
@@ -2368,23 +2437,6 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
|
||||
return nrows * nblock * sizeof(block_iq6_k);
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
namespace {
|
||||
inline int hsum_i32_8(const __m256i a) {
|
||||
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
|
||||
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
|
||||
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
|
||||
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
||||
}
|
||||
inline float hmax_f32_8(__m256 x) {
|
||||
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
|
||||
return _mm_cvtss_f32(max4);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
|
||||
Reference in New Issue
Block a user