diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index fd0eb7a6..9336de71 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -3164,7 +3164,8 @@ public:
     const float * values() const { return m_values.data(); }
 
     inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const;
-    inline float find_best_scale(const float * xb, const float * weight, const int * best_idx) const;
+    inline void find_best_match(const float * xb, const float * weight, int * best_idx) const;
+    inline std::pair<float, float> find_best_scale(const float * xb, const float * weight, const int * best_idx) const;
 
     static inline void set_values(uint32_t i, float * result, float scale) {
         constexpr uint32_t ka = 89226354;
@@ -3223,18 +3224,18 @@ QuantizerIQKT::QuantizerIQKT() {
 }
 
 template
-float QuantizerIQKT::find_best_scale(const float * xb, const float * weight, const int * best_idx) const {
+std::pair<float, float> QuantizerIQKT::find_best_scale(
+        const float * xb, const float * weight, const int * best_idx) const {
     float sumqx = 0, sumq2 = 0;
-#ifdef z__AVX2__
-    // TODO: fix this for kGroupSize != 8
+#ifdef __AVX2__
     auto vqx = _mm256_setzero_ps();
     auto vq2 = _mm256_setzero_ps();
-    for (int l = 0; l < kNg; ++l) {
-        auto vx = _mm256_loadu_ps(xb+8*l);
-        auto vw = _mm256_loadu_ps(weight+8*l);
-        auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l]) :
-                                    _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l+1]),
-                                                    _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l+0]));
+    for (int l = 0; l < kBlockSize; l += 8) {
+        auto vx = _mm256_loadu_ps(xb+l);
+        auto vw = _mm256_loadu_ps(weight+l);
+        auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) :
+                                    _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]),
+                                                    _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0]));
         auto vqw = _mm256_mul_ps(vq, vw);
         vqx = _mm256_fmadd_ps(vqw, vx, vqx);
         vq2 = _mm256_fmadd_ps(vqw, vq, vq2);
@@ -3252,7 +3253,163 @@ float QuantizerIQKT::find_best_s
         }
     }
 #endif
-    return sumq2 > 0 ? sumqx/sumq2 : 0.f;
+    return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f);
+}
+
+// Variant of find_best_match that takes no scale: for each of the kNg groups of the block it first
+// picks the best-scoring cluster in m_clusters, then the best-scoring m_values entry among the
+// points listed in m_in_cluster for that cluster, and writes that entry's index to best_idx[l].
+// NOTE(review): the score looks like sum((w*x*q)^2)/sum(w*q*q), i.e. a sum of squares, whereas
+// find_best_scale() scores with (sum(w*x*q))^2/sum(w*q*q) - confirm which one is intended.
+template
+void QuantizerIQKT::find_best_match(const float * xb, const float * weight, int * best_idx) const {
+    int ncluster = m_clusters.size()/kGroupSize;
+#ifdef __AVX2__
+    if constexpr (kGroupSize == 8) {
+        __m256 sqx[8];
+        const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+        float sx[8];
+        int index[8];
+        for (int l = 0; l < kNg; ++l) {
+            auto xl = xb + 8*l;
+            auto wl = weight + 8*l;
+            auto vx = _mm256_loadu_ps(xl);
+            // An all-zero group scores 0 against every candidate; treat it like the 4-wide path does.
+            auto sumx2 = hsum_float_8(_mm256_mul_ps(vx, vx));
+            if (!sumx2) {
+                best_idx[l] = 0; continue;
+            }
+            auto vw = _mm256_loadu_ps(wl);
+            auto vbest = _mm256_set1_ps(0.f);
+            auto best_index = _mm256_set1_epi32(-1);
+            float best = 0; int jbest = -1;
+            for (int j = 0; j < ncluster; j += 8) {
+                auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
+                for (int i = 0; i < 8; ++i) {
+                    auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
+                    auto sumqx = _mm256_mul_ps(vw, _mm256_mul_ps(vx, vq));
+                    auto sumq2 = hsum_float_8(_mm256_mul_ps(vw, _mm256_mul_ps(vq, vq)));
+                    sqx[i] = _mm256_mul_ps(_mm256_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm256_mul_ps(sumqx, sumqx));
+                }
+                auto score = hsum_float_8x8(sqx);
+                auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ);
+                best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+                                             _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+                vbest = _mm256_max_ps(vbest, score);
+            }
+            // sx and index are plain stack arrays with no 32-byte alignment guarantee: store unaligned.
+            _mm256_storeu_ps(sx, vbest);
+            _mm256_storeu_si256((__m256i *)index, best_index);
+            for (int i = 0; i < 8; ++i) {
+                if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
+            }
+            // jbest stays -1 when no cluster scored above 0, so validate it before indexing.
+            GGML_ASSERT(jbest >= 0 && jbest < int(m_in_cluster.size()));
+            auto& points = m_in_cluster[jbest];
+            GGML_ASSERT(!points.empty() && points.size()%8 == 0);
+            int jbest_cluster = jbest;
+            vbest = _mm256_set1_ps(0.f);
+            best_index = _mm256_set1_epi32(-1);
+            best = 0; jbest = -1;
+            for (int j = 0; j < int(points.size()); j += 8) {
+                auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j));
+                for (int i = 0; i < 8; ++i) {
+                    auto vq = _mm256_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
+                    auto sumqx = _mm256_mul_ps(vw, _mm256_mul_ps(vx, vq));
+                    auto sumq2 = hsum_float_8(_mm256_mul_ps(vw, _mm256_mul_ps(vq, vq)));
+                    sqx[i] = _mm256_mul_ps(_mm256_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm256_mul_ps(sumqx, sumqx));
+                }
+                auto score = hsum_float_8x8(sqx);
+                auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ);
+                best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+                                             _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+                vbest = _mm256_max_ps(vbest, score);
+            }
+            _mm256_storeu_ps(sx, vbest);
+            _mm256_storeu_si256((__m256i *)index, best_index);
+            for (int i = 0; i < 8; ++i) {
+                if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
+            }
+            if (jbest < 0) {
+                fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
+                GGML_ASSERT(false);
+            }
+            best_idx[l] = jbest;
+        }
+    } else {
+        __m128 sqx[4];
+        const __m128i add_idx = _mm_set_epi32(3, 2, 1, 0);
+        float sx[4];
+        int index[4];
+        for (int l = 0; l < kNg; ++l) {
+            auto xl = xb + 4*l;
+            auto wl = weight + 4*l;
+            auto vx = _mm_loadu_ps(xl);
+            auto sumx2 = hsum_float_4(_mm_mul_ps(vx, vx));
+            if (!sumx2) {
+                best_idx[l] = 0; continue;
+            }
+            auto vw = _mm_loadu_ps(wl);
+            auto vbest = _mm_set1_ps(0);
+            auto best_index = _mm_set1_epi32(-1);
+            float best = 0; int jbest = -1;
+            for (int j = 0; j < ncluster; j += 4) {
+                auto idx = _mm_add_epi32(_mm_set1_epi32(j), add_idx);
+                for (int i = 0; i < 4; ++i) {
+                    auto vq = _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
+                    auto sumqx = _mm_mul_ps(vw, _mm_mul_ps(vx, vq));
+                    auto sumq2 = hsum_float_4(_mm_mul_ps(vw, _mm_mul_ps(vq, vq)));
+                    sqx[i] = _mm_mul_ps(_mm_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm_mul_ps(sumqx, sumqx));
+                }
+                auto score = hsum_float_4x4(sqx);
+                auto mask = _mm_cmp_ps(score, vbest, _CMP_GT_OQ);
+                best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx),
+                                          _mm_andnot_si128(_mm_castps_si128(mask), best_index));
+                vbest = _mm_max_ps(vbest, score);
+            }
+            // sx and index are plain stack arrays with no 16-byte alignment guarantee: store unaligned.
+            _mm_storeu_ps(sx, vbest);
+            _mm_storeu_si128((__m128i *)index, best_index);
+            for (int i = 0; i < 4; ++i) {
+                if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
+            }
+            GGML_ASSERT(jbest >= 0 && jbest < int(m_in_cluster.size()));
+            auto& points = m_in_cluster[jbest];
+            GGML_ASSERT(!points.empty() && points.size()%4 == 0);
+            int jbest_cluster = jbest;
+            vbest = _mm_set1_ps(0);
+            best_index = _mm_set1_epi32(-1);
+            best = 0; jbest = -1;
+            for (int j = 0; j < int(points.size()); j += 4) {
+                auto idx = _mm_loadu_si128((const __m128i*)(points.data() + j));
+                for (int i = 0; i < 4; ++i) {
+                    auto vq = _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
+                    auto sumqx = _mm_mul_ps(vw, _mm_mul_ps(vx, vq));
+                    auto sumq2 = hsum_float_4(_mm_mul_ps(vw, _mm_mul_ps(vq, vq)));
+                    sqx[i] = _mm_mul_ps(_mm_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm_mul_ps(sumqx, sumqx));
+                }
+                auto score = hsum_float_4x4(sqx);
+                auto mask = _mm_cmp_ps(score, vbest, _CMP_GT_OQ);
+                best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx),
+                                          _mm_andnot_si128(_mm_castps_si128(mask), best_index));
+                vbest = _mm_max_ps(vbest, score);
+            }
+            _mm_storeu_ps(sx, vbest);
+            _mm_storeu_si128((__m128i *)index, best_index);
+            for (int i = 0; i < 4; ++i) {
+                if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
+            }
+            if (jbest < 0) {
+                fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
+                GGML_ASSERT(false);
+            }
+            best_idx[l] = jbest;
+        }
+    }
+#else
+    // TODO
+    std::memset(best_idx, 0, kNg*sizeof(int));
+#endif
 }
 
 template
@@ -3329,6 +3486,7 @@ void QuantizerIQKT::find_best_ma
     } else {
         __m128 sqx[4];
         const __m128i add_idx = _mm_set_epi32(3, 2, 1, 0);
+        const __m128 sign_bit = _mm_set1_ps(-0.f);
         float sx[4];
         int index[4];
         auto vid = _mm_set1_ps(id);
@@ -3345,7 +3503,9 @@ void QuantizerIQKT::find_best_ma
                 for (int i = 0; i < 4; ++i) {
                     auto vq = _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
                     auto vdiff = _mm_sub_ps(vq, vx);
-                    sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
+                    //sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
+                    vdiff = _mm_andnot_ps(sign_bit, vdiff);
+                    sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff)));
                 }
                 auto score = hsum_float_4x4(sqx);
                 auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
@@ -3369,7 +3529,9 @@ void QuantizerIQKT::find_best_ma
                 for (int i = 0; i < 4; ++i) {
                     auto vq = _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
                     auto vdiff = _mm_sub_ps(vq, vx);
-                    sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
+                    //sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
+                    vdiff = _mm_andnot_ps(sign_bit, vdiff);
+                    sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff)));
                 }
                 auto score = hsum_float_4x4(sqx);
                 auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
@@ -3589,7 +3751,8 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
             }
             float d = amax/96.f;
             quantizer.find_best_match(d, xb, weight, best_idx);
-            scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
+            auto pair = quantizer.find_best_scale(xb, weight, best_idx);
+            scales[ib] = pair.first;
             for (int j = 0; j < Q::kNg; ++j) qs[j] = best_idx[j];
             qs += Q::kNg;
 
@@ -3665,7 +3828,8 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
                 const float * xb = xbl + Q::kBlockSize*ib;
                 const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
                 for (int j = 0; j < Q::kNg; ++j) best_idx[j] = qs[ib*Q::kNg+j];
-                scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
+                auto pair = quantizer.find_best_scale(xb, weight, best_idx);
+                scales[ib] = pair.first;
             }
         }
         float id = d ? 1/d : 0.f;
@@ -3809,9 +3973,21 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
                 float ax = std::abs(xb[j]);
                 amax = std::max(amax, ax);
             }
-            float d = amax/96.f;
-            quantizer.find_best_match(d, xb, weight, best_idx);
-            scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
+            scales[ib] = 0;
+            if (!amax) continue;
+            float best = 0;
+            for (int itry = -5; itry <= 5; ++itry) {
+                quantizer.find_best_match(amax/(96.f + 4.f*itry), xb, weight, best_idx);
+                auto [d, score] = quantizer.find_best_scale(xb, weight, best_idx);
+                if (score > best) {
+                    best = score;
+                    scales[ib] = d;
+                }
+            }
+            //float d = amax/96.f;
+            //quantizer.find_best_match(d, xb, weight, best_idx);
+            ////quantizer.find_best_match(xb, weight, best_idx);
+            //scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
             for (int j = 0; j < Q::kNg; ++j) {
                 int jj = ib*Q::kNg + j;