Use Q8_K_128 for IQ1_S_R4 and IQ1_M_R4 matrix multiplications (#194)

* iq1_s_r4: Use Q8_K_128 instead of Q8_1_X4 for gemm (AVX2/Zen4)

* iq1_m_r4: Use Q8_K_128 instead of Q8_1_X4 for gemm (AVX2/Zen4)

* iq1_s_r4: Use Q8_K_128 instead of Q8_1_X4 for gemm (Neon)

* iq1_m_r4: Use Q8_K_128 instead of Q8_0_X4 for gemm (Neon)

* Simdify q8_K128 quantization also on Neon

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-02-09 09:14:52 +02:00
committed by GitHub
parent 716508d196
commit 6658922b94
6 changed files with 169 additions and 43 deletions

View File

@@ -2733,6 +2733,7 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
return nrows * nblock * sizeof(block_iq6_k);
}
namespace {
template <int q8_type>
void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);
@@ -2843,7 +2844,7 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
x += QK_K;
}
#endif
}
}
void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
@@ -2858,6 +2859,120 @@ void quantize_row_q8_KR8(const float * x, void * vy, int64_t k) {
iqk_quantize_row_q8_K_T<2>(x, vy, k);
}
namespace {
// TODO: merge this with the above template
void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
constexpr int kBlockSize = 128;
assert(k % kBlockSize == 0);
const int nb = k / kBlockSize;
auto y = (block_q8_K128 *)vy;
#ifdef __AVX2__
const __m256 signBit = _mm256_set1_ps(-0.0f);
const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
for (int i = 0; i < nb; i++) {
const float * xb = x + i*kBlockSize;
__m256 maxAbs = _mm256_setzero_ps();
const float * xx = xb;
for (int ib = 0; ib < kBlockSize/8; ++ib) {
const __m256 v = _mm256_loadu_ps(xx); xx += 8;
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
}
const float maxScalar = hmax_f32_8(maxAbs);
const float d = maxScalar / 127.f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
xx = xb;
int8_t * q8 = y[i].qs;
for (int ib = 0; ib < kBlockSize/32; ++ib) {
__m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
__m256i i0 = _mm256_cvtps_epi32(v0);
__m256i i1 = _mm256_cvtps_epi32(v1);
__m256i i2 = _mm256_cvtps_epi32(v2);
__m256i i3 = _mm256_cvtps_epi32(v3);
y[i].bsums[ib] = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
i0 = _mm256_packs_epi16( i0, i2 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)q8, i0);
q8 += 32;
}
}
#elif defined __ARM_NEON
int32x4_t ival[8];
for (int i = 0; i < nb; i++) {
const float * xb = x + i*kBlockSize;
auto vmax = vdupq_n_f32(0.f);
for (int j = 0; j < kBlockSize; j += 4) {
vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(xb + j)));
}
auto smax = vmaxvq_f32(vmax);
if (!smax) {
std::memset(&y[i], 0, sizeof(y[i]));
continue;
}
y[i].d = smax/127;
auto vid = vdupq_n_f32(127/smax);
for (int ib = 0; ib < kBlockSize/32; ++ib) {
auto isum = vdupq_n_s32(0);
for (int k = 0; k < 8; ++k) {
auto val = vld1q_f32(xb + 32*ib + 4*k);
ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid));
isum = vaddq_s32(isum, ival[k]);
}
y[i].bsums[ib] = vaddvq_s32(isum);
for (int k = 0; k < 4; ++k) {
auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1]));
vst1_s8(y[i].qs + 32*ib + 8*k, vmovn_s16(i16));
}
}
}
#else
for (int i = 0; i < nb; i++) {
float amax = 0;
for (int j = 0; j < kBlockSize; ++j) {
float ax = std::abs(x[j]);
amax = std::max(amax, ax);
}
if (!amax) {
y[i].d = 0;
memset(y[i].qs, 0, kBlockSize);
memset(y[i].bsums, 0, kBlockSize/32*(sizeof(int16_t)));
x += kBlockSize;
continue;
}
const float iscale = 127.f/amax;
for (int j = 0; j < kBlockSize; ++j) {
int v = nearest_int(iscale*x[j]);
y[i].qs[j] = v;
}
for (int j = 0; j < kBlockSize/32; ++j) {
int sum = 0;
for (int ii = 0; ii < 32; ++ii) {
sum += y[i].qs[j*32 + ii];
}
y[i].bsums[j] = sum;
}
y[i].d = 1/iscale;
x += kBlockSize;
}
#endif
}
}
void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
iqk_quantize_row_q8_K128(x, vy, k);
}
namespace {
static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size,
int n_per_row, const float * x, char * cy,