From a961a48e884cef1d4076408d3ea9d67a301c6764 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 6 Nov 2024 11:05:07 +0200
Subject: [PATCH] WIP

---
Review notes (ignored when the patch is applied, not part of the commit message):
 * This copy had lost its line breaks, its indentation and every <...> span.
   Line breaks and the template arguments of std::vector / std::unique_lock
   were restored from the way each object is used. Indentation of context and
   removed lines is a best guess, so a whitespace-tolerant apply
   (git apply --ignore-whitespace) may be needed.
 * The author's e-mail address is absent from the From: header of this copy.
 * find_best_scale(): the function returns amax/96.f right after the max
   search, so the scale search below that return is currently unreachable.
 * analyze_x_v2(): the #else branch is empty, so without __AVX2__ best_idx and
   scales are never computed and keep their initial zeros.
 * analyze_x_v2(): sx and index are plain stack arrays, so they are written
   with _mm256_storeu_ps / _mm256_storeu_si256 here. The received text used
   the aligned store intrinsics; the blob hashes on the index line therefore
   describe the received version, not this one.
 * analyze_x_v2(): kSuperBlockSize is declared but not used.

 examples/quantize-stats/quantize-stats.cpp | 186 ++++++++++++++++++++-
 1 file changed, 180 insertions(+), 6 deletions(-)

diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp
index 99f03495..218bd253 100644
--- a/examples/quantize-stats/quantize-stats.cpp
+++ b/examples/quantize-stats/quantize-stats.cpp
@@ -296,7 +296,7 @@ static const int8_t scale_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10
 //    return result;
 //}
 
-static std::vector<float> make_values(int nval, int n_per_val) {
+static std::vector<float> make_values(int nval, int n_per_val, float scale = 16.f) {
     std::vector<float> result(nval*n_per_val);
     uint16_t m16 = ggml_fp32_to_fp16(0.922f);
     uint32_t m32 = (uint32_t(m16) << 16) | m16;
@@ -308,10 +308,8 @@ static std::vector<float> make_values(int nval, int n_per_val) {
             x = a*x + b;
             uint32_t s = (x & 0b10001111111111111000111111111111) ^ m32;
             float val = ggml_fp16_to_fp32(s & 65535) + ggml_fp16_to_fp32(s >> 16);
-            //int ival = nearest_int(31.5f*val);
-            int ival = nearest_int(16.f*val);
+            int ival = nearest_int(scale*val);
             data[k] = ival;
-            //data[k] = ggml_fp16_to_fp32(s & 65535) + ggml_fp16_to_fp32(s >> 16);
         }
         data += n_per_val;
     }
@@ -369,7 +367,183 @@ inline int best_index_scale(const int8_t * values, float x) {
     ix = scale_index[ix];
     return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15;
 }
+inline int best_index_iq4nl(const int8_t * values, float x) { return best_index_scale(values, x); }
 
+static float find_best_scale(int block_size, const float * xb, const float * weight, const int8_t * values, int ntry) {
+    float amax = 0, max = 0;
+    for (int j = 0; j < block_size; ++j) {
+        float ax = fabsf(xb[j]);
+        if (ax > amax) {
+            amax = ax; max = xb[j];
+        }
+    }
+    return amax/96.f; //120.f; //127.f;
+    if (!amax) return 0.f;
+    float d = ntry > 0 ? -max/values[0] : max/values[0];
+    float id = 1/d;
+    float sumqx_p = 0, sumq2_p = 0;
+    float sumqx_m = 0, sumq2_m = 0;
+    for (int j = 0; j < block_size; ++j) {
+        float w = weight[j];
+        float al = id*xb[j];
+        int l = best_index_iq4nl(values, al);
+        float q = values[l];
+        sumqx_p += w*q*xb[j];
+        sumq2_p += w*q*q;
+        l = best_index_iq4nl(values, -al);
+        q = values[l];
+        sumqx_m += w*q*xb[j];
+        sumq2_m += w*q*q;
+    }
+    d = sumqx_p/sumq2_p;
+    float best = d*sumqx_p;
+    if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+        d = sumqx_m/sumq2_m; best = d*sumqx_m;
+    }
+    for (int itry = -ntry; itry <= ntry; ++itry) {
+        id = (itry + values[0])/max;
+        sumqx_p = sumq2_p = 0;
+        sumqx_m = sumq2_m = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float w = weight[j];
+            float al = id*xb[j];
+            int l = best_index_iq4nl(values, al);
+            float q = values[l];
+            sumqx_p += w*q*xb[j];
+            sumq2_p += w*q*q;
+            l = best_index_iq4nl(values, -al);
+            q = values[l];
+            sumqx_m += w*q*xb[j];
+            sumq2_m += w*q*q;
+        }
+        if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
+            d = sumqx_p/sumq2_p; best = d * sumqx_p;
+        }
+        if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+            d = sumqx_m/sumq2_m; best = d * sumqx_m;
+        }
+    }
+    return d;
+}
+
+static void analyze_x_v2(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) {
+    constexpr int kNumVal = 1 << 12;
+    constexpr int kBlockSize = 32;
+    constexpr int kGroupSize = 8;
+    constexpr int kNg = kBlockSize/kGroupSize;
+    constexpr int kSuperBlockSize = 256;
+    static_assert(kNumVal%8 == 0);
+    auto codes = make_values(kNumVal, kGroupSize, 31.5f);
+    int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
+    int chunk = (nrows + 8*nthread - 1)/(8*nthread);
+    std::mutex mutex;
+    int counter = 0;
+    float mse = 0, mse_q = 0;
+    auto compute = [&mutex, &counter, &mse, &mse_q, &codes, values, nrows, n_per_row, chunk] () {
+        float lmse = 0, lmse_q = 0;
+        std::vector<float> scales(n_per_row/kBlockSize);
+        std::vector<int> best_idx(n_per_row/kGroupSize);
+        std::vector<float> weight(kBlockSize, 1.f);
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int first = counter; counter += chunk;
+            if (first >= nrows) {
+                mse += lmse; mse_q += lmse_q;
+                return;
+            }
+            lock.unlock();
+            int last = std::min(first + chunk, nrows);
+#ifdef __AVX2__
+            __m256 sqx[8];
+            __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+            float sx[8];
+            int index[8];
+#endif
+            for (int row = first; row < last; ++row) {
+                auto xr = values + row*n_per_row;
+                for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
+                    auto xb = xr + kBlockSize*ib;
+                    float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5);
+                    float id = d ? 1/d : 0.f;
+#ifdef __AVX2__
+                    auto vid = _mm256_set1_ps(id);
+                    for (int l = 0; l < kNg; ++l) {
+                        auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xb+8*l));
+                        auto vbest = _mm256_set1_ps(INFINITY);
+                        auto best_index = _mm256_set1_epi32(-1);
+                        float best = INFINITY; int jbest = -1;
+                        for (int j = 0; j < kNumVal; j += 8) {
+                            auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
+                            for (int i = 0; i < 8; ++i) {
+                                auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*(j+i));
+                                auto vdiff = _mm256_sub_ps(vq, vx);
+                                sqx[i] = _mm256_mul_ps(vdiff, vdiff);
+                            }
+                            auto score = hsum_float_8x8(sqx);
+                            auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
+                            best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
+                                                         _mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
+                            vbest = _mm256_min_ps(vbest, score);
+                        }
+                        _mm256_storeu_ps(sx, vbest);
+                        _mm256_storeu_si256((__m256i *)index, best_index);
+                        for (int i = 0; i < 8; ++i) {
+                            if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
+                        }
+                        best_idx[ib*kNg + l] = jbest;
+                    }
+                    auto vqx = _mm256_setzero_ps();
+                    auto vq2 = _mm256_setzero_ps();
+                    for (int l = 0; l < kNg; ++l) {
+                        auto vx = _mm256_loadu_ps(xb+8*l);
+                        auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*best_idx[ib*kNg + l]);
+                        vqx = _mm256_fmadd_ps(vq, vx, vqx);
+                        vq2 = _mm256_fmadd_ps(vq, vq, vq2);
+                    }
+                    auto sumqx = hsum_float_8(vqx);
+                    auto sumq2 = hsum_float_8(vq2);
+                    scales[ib] = sumqx/sumq2;
+#else
+#endif
+                }
+                float amax_scale = std::abs(scales[0]);
+                float max_scale = scales[0];
+                for (int ib = 1; ib < n_per_row/kBlockSize; ++ib) {
+                    float ax = std::abs(scales[ib]);
+                    if (ax > amax_scale) {
+                        amax_scale = ax;
+                        max_scale = scales[ib];
+                    }
+                }
+                float d = max_scale/scale_values[0];
+                float id = d ? 1/d : 0.f;
+                for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
+                    int ls = best_index_scale(scale_values, id*scales[ib]);
+                    float dl = d * scale_values[ls];
+                    auto xb = xr + kBlockSize*ib;
+                    for (int l = 0; l < kNg; ++l) {
+                        auto q = codes.data() + kGroupSize*best_idx[ib*kNg+l];
+                        for (int k = 0; k < kGroupSize; ++k) {
+                            float diff1 = xb[kGroupSize*l + k] - scales[ib]*q[k];
+                            float diff2 = xb[kGroupSize*l + k] - dl*q[k];
+                            lmse += diff1*diff1;
+                            lmse_q += diff2*diff2;
+                        }
+                    }
+                }
+            }
+        }
+    };
+    std::vector<std::thread> workers(nthread-1);
+    for (auto& w : workers) w = std::thread(compute);
+    compute();
+    for (auto& w : workers) w.join();
+    tot_mse += mse;
+    tot_mse_q += mse_q;
+    tot_elements += n_per_row*nrows;
+    printf("%s: %g %g %g %g\n", name, sqrt(mse/(n_per_row*nrows)), sqrt(tot_mse/tot_elements),
+            sqrt(mse_q/(n_per_row*nrows)), sqrt(tot_mse_q/tot_elements));
+}
 
 static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) {
     constexpr int kNumVal = 1 << 12;
@@ -635,7 +809,7 @@ static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_mse_
     }
     if (t->type == GGML_TYPE_F32) {
         //analyze_iq4ks(t->name, t->ne[1], t->ne[0], (const float *)t->data, tot_mse, tot_elements);
-        analyze_x(t->name, t->ne[1], t->ne[0], (const float *)t->data, tot_mse, tot_mse_q, tot_elements);
+        analyze_x_v2(t->name, t->ne[1], t->ne[0], (const float *)t->data, tot_mse, tot_mse_q, tot_elements);
     } else {
         std::vector<float> aux(t->ne[0]*t->ne[1]);
         if (t->type == GGML_TYPE_F16) {
@@ -644,7 +818,7 @@ static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_mse_
             ggml_bf16_to_fp32_row((const ggml_bf16_t *)t->data, aux.data(), aux.size());
         }
         //analyze_iq4ks(t->name, t->ne[1], t->ne[0], aux.data(), tot_mse, tot_elements);
-        analyze_x(t->name, t->ne[1], t->ne[0], aux.data(), tot_mse, tot_mse_q, tot_elements);
+        analyze_x_v2(t->name, t->ne[1], t->ne[0], aux.data(), tot_mse, tot_mse_q, tot_elements);
     }
 }
 