From f1fb59b44b5070d7ec83e68df859d681f3b014c8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 8 Nov 2024 18:39:23 +0200 Subject: [PATCH] iq3_kt WIP: slowly improving PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive/slightly better than other quants. --- ggml/src/iqk/iqk_quantize.cpp | 126 +++++++++++++++++++++------------- 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 9336de71..d3e3d012 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3145,6 +3145,10 @@ __m256 hsum_float_8x8(__m256 * accm) { for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); } +__m256 hsum_float_4x8(__m256 * accm) { + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} #endif template class QuantizerIQKT { @@ -3470,64 +3474,87 @@ void QuantizerIQKT::find_best_ma best_idx[l] = jbest; } } else { - __m128 sqx[4]; - const __m128i add_idx = _mm_set_epi32(3, 2, 1, 0); - const __m128 sign_bit = _mm_set1_ps(-0.f); - float sx[4]; - int index[4]; - auto vid = _mm_set1_ps(id); + __m256 sqx[4]; + const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + const __m256 sign_bit = _mm256_set1_ps(-0.f); + float sx[8]; + int index[8]; + auto vid_p = _mm256_set1_ps(id); + auto vid_m = _mm256_set1_ps(-id); for (int l = 0; l < kNg; ++l) { auto xl = xb + 4*l; auto wl = weight + 4*l; - auto vx = _mm_mul_ps(vid, _mm_loadu_ps(xl)); - auto vw = _mm_loadu_ps(wl); - auto vbest = _mm_set1_ps(INFINITY); - auto best_index = _mm_set1_epi32(-1); + auto vx4 = _mm_loadu_ps(xl); + auto vx_p = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4)); + //auto vx_m = _mm256_mul_ps(vid_m, _mm256_set_m128(vx4, vx4)); + auto vw4 = _mm_loadu_ps(wl); + auto vw = _mm256_set_m128(vw4, vw4); + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); float best = INFINITY; int jbest = -1; - for (int j = 0; j < ncluster; j += 4) { - auto idx = _mm_add_epi32(_mm_set1_epi32(j), add_idx); + for (int j = 0; j < ncluster; j += 8) { + auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); for (int i = 0; i < 4; ++i) { - auto vq = _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i)); - auto vdiff = _mm_sub_ps(vq, vx); + auto vq = _mm256_set_m128(_mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i+4)), _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i))); + auto vdiff = _mm256_sub_ps(vq, vx_p); //sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff)); - vdiff = _mm_andnot_ps(sign_bit, vdiff); - sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff))); + vdiff = _mm256_andnot_ps(sign_bit, vdiff); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); + //vdiff = _mm256_sub_ps(vq, vx_m); + ////sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff)); + //vdiff = _mm256_andnot_ps(sign_bit, vdiff); + //sqx[i+4] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); } - auto score = hsum_float_4x4(sqx); - auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ); - best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx), - _mm_andnot_si128(_mm_castps_si128(mask), best_index)); - vbest = _mm_min_ps(vbest, score); + auto score = hsum_float_4x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + //score = hsum_float_4x8(sqx+4); + //mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + //best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + // _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + //vbest = _mm256_min_ps(vbest, score); } - _mm_store_ps(sx, vbest); - _mm_store_si128((__m128i *)index, best_index); - for (int i = 0; i < 4; ++i) { + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { if (sx[i] < best) { best = sx[i]; jbest = index[i]; } } auto& points = m_in_cluster[jbest]; - GGML_ASSERT(!points.empty() && points.size()%4 == 0); + GGML_ASSERT(!points.empty() && points.size()%8 == 0); int jbest_cluster = jbest; - vbest = _mm_set1_ps(INFINITY); - best_index = _mm_set1_epi32(-1); + vbest = _mm256_set1_ps(INFINITY); + best_index = _mm256_set1_epi32(-1); best = INFINITY; jbest = -1; - for (int j = 0; j < int(points.size()); j += 4) { - auto idx = _mm_loadu_si128((const __m128i*)(points.data() + j)); + for (int j = 0; j < int(points.size()); j += 8) { + auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j)); for (int i = 0; i < 4; ++i) { - auto vq = _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i]); - auto vdiff = _mm_sub_ps(vq, vx); + auto vq = _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*points[j+i+4]), + _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i+0])); + auto vdiff = _mm256_sub_ps(vq, vx_p); //sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff)); - vdiff = _mm_andnot_ps(sign_bit, vdiff); - sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff))); + vdiff = _mm256_andnot_ps(sign_bit, vdiff); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); + //vdiff = _mm256_sub_ps(vq, vx_m); + ////sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff)); + //vdiff = _mm256_andnot_ps(sign_bit, vdiff); + //sqx[i+4] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); } - auto score = hsum_float_4x4(sqx); - auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ); - best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx), - _mm_andnot_si128(_mm_castps_si128(mask), best_index)); - vbest = _mm_min_ps(vbest, score); + auto score = hsum_float_4x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + //score = hsum_float_4x8(sqx+4); + //mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + //best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + // _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + //vbest = _mm256_min_ps(vbest, score); } - _mm_store_ps(sx, vbest); - _mm_store_si128((__m128i *)index, best_index); - for (int i = 0; i < 4; ++i) { + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { if (sx[i] < best) { best = sx[i]; jbest = index[i]; } } if (jbest < 0) { @@ -3962,12 +3989,19 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f scales[ib] = 0; if (!amax) continue; float best = 0; - for (int itry = -5; itry <= 5; ++itry) { + //for (int itry = -5; itry <= 5; ++itry) { + for (int itry = -3; itry <= 3; ++itry) { quantizer.find_best_match(amax/(96.f + 4.f*itry), xb, weight, best_idx); - auto [d, score] = quantizer.find_best_scale(xb, weight, best_idx); - if (score > best) { - best = score; - scales[ib] = d; + auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + if (score_p > best) { + best = score_p; + scales[ib] = dp; + } + quantizer.find_best_match(-amax/(96.f + 4.f*itry), xb, weight, best_idx); + auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx); + if (score_m > best) { + best = score_m; + scales[ib] = dm; } } //float d = amax/96.f;