From 3b46d3afd5475f5c01439f6570e532f3584e7eab Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 29 Jan 2025 08:19:08 +0200 Subject: [PATCH] Make q4_0_r4 work with tensor row sizes that are not a multiple of 128 .., on AVX2 --- ggml/src/iqk/iqk_mul_mat.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ce3c6376..b54846f9 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2640,6 +2640,15 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2); } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto qy = (const block_q8_1 *)q8.y[0]; + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); + prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); + auto sumi = accum_q4_0_quants(v, qy[ib].qs); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1); + acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc2); + } acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1); info.store(ix, 0, acc1); } @@ -2677,6 +2686,18 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f)); + prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = accum_q4_0_quants(v, qy[ib].qs); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, acc[iy]); acc[iy] = _mm256_setzero_ps();