diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 218bd253..28103350 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -427,13 +427,13 @@ static float find_best_scale(int block_size, const float * xb, const float * wei } static void analyze_x_v2(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) { - constexpr int kNumVal = 1 << 12; + constexpr int kNumVal = 1 << 15; constexpr int kBlockSize = 32; constexpr int kGroupSize = 8; constexpr int kNg = kBlockSize/kGroupSize; constexpr int kSuperBlockSize = 256; static_assert(kNumVal%8 == 0); - auto codes = make_values(kNumVal, kGroupSize, 31.5f); + auto codes = make_values(kNumVal, kGroupSize, 31.75f); int nthread = std::max(1, int(std::thread::hardware_concurrency()/2)); int chunk = (nrows + 8*nthread - 1)/(8*nthread); std::mutex mutex; @@ -461,14 +461,29 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa #endif for (int row = first; row < last; ++row) { auto xr = values + row*n_per_row; + float sigma2 = 0; + for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j]; + sigma2 /= n_per_row; + //int last_ibl = -1; for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + //int ibl = ib*kBlockSize/kSuperBlockSize; + //if (ibl != last_ibl) { + // auto xbl = xr + ibl*kSuperBlockSize; + // int n = kSuperBlockSize*(ibl + 1) <= n_per_row ? kSuperBlockSize : n_per_row - ibl*kSuperBlockSize; + // float sumx2 = 0; + // for (int i = 0; i < n; ++i) sumx2 += xbl[i]*xbl[i]; + // sigma2 = sumx2/n; + // last_ibl = ibl; + //} auto xb = xr + kBlockSize*ib; + for (int i = 0; i < kBlockSize; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5); float id = d ? 1/d : 0.f; #ifdef __AVX2__ auto vid = _mm256_set1_ps(id); for (int l = 0; l < kNg; ++l) { auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xb+8*l)); + auto vw = _mm256_loadu_ps(weight.data() + 8*l); auto vbest = _mm256_set1_ps(INFINITY); auto best_index = _mm256_set1_epi32(-1); float best = INFINITY; int jbest = -1; @@ -477,7 +492,7 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa for (int i = 0; i < 8; ++i) { auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*(j+i)); auto vdiff = _mm256_sub_ps(vq, vx); - sqx[i] = _mm256_mul_ps(vdiff, vdiff); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); } auto score = hsum_float_8x8(sqx); auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); @@ -496,13 +511,15 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa auto vq2 = _mm256_setzero_ps(); for (int l = 0; l < kNg; ++l) { auto vx = _mm256_loadu_ps(xb+8*l); + auto vw = _mm256_loadu_ps(weight.data() + 8*l); auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*best_idx[ib*kNg + l]); - vqx = _mm256_fmadd_ps(vq, vx, vqx); - vq2 = _mm256_fmadd_ps(vq, vq, vq2); + auto vqw = _mm256_mul_ps(vq, vw); + vqx = _mm256_fmadd_ps(vqw, vx, vqx); + vq2 = _mm256_fmadd_ps(vqw, vq, vq2); } auto sumqx = hsum_float_8(vqx); auto sumq2 = hsum_float_8(vq2); - scales[ib] = sumqx/sumq2; + scales[ib] = sumq2 > 0 ? sumqx/sumq2 : 0.f; #else #endif }