diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index bbd53487..b2b1f819 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -542,8 +542,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl float scales[QK_K/kBlockSize]; float weight[kBlockSize]; float sumx[kBlockSize+1], sumw[kBlockSize+1]; - float sw[QK_K/kBlockSize]; - int8_t Ls[QK_K/kBlockSize]; + uint8_t L[QK_K]; std::array, kBlockSize> pairs; @@ -561,7 +560,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl uint16_t extra = 0; - float max_abs_scale = 0; + float max_abs_scale = 0, max_scale = 0; for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { const float * xb = xbl + kBlockSize*ib; @@ -571,9 +570,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl } else { for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; } - sw[ib] = 0; for (int j = 0; j < kBlockSize; ++j) { - sw[ib] += weight[j]; pairs[j] = {xb[j], j}; } std::sort(pairs.begin(), pairs.end()); @@ -586,6 +583,8 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl float best = 0, d = 0; bool is_shifted = false; float sumqx, sumq2; + int besti1 = -1, besti2 = -1, besti3 = -1; + bool reverse = false; for (int i1 = 0; i1 < kBlockSize; ++i1) { for (int i2 = i1; i2 < kBlockSize; ++i2) { for (int i3 = i2; i3 < kBlockSize; ++i3) { @@ -594,6 +593,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[0]*iq2nl_values[0] + (sumw[i2] - sumw[i1])*iq2nl_values[1]*iq2nl_values[1] + (sumw[i3] - sumw[i2])*iq2nl_values[2]*iq2nl_values[2] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[3]*iq2nl_values[3]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + besti1 = i1; besti2 = i2; besti3 = i3; reverse = false; d = sumqx/sumq2; best = d*sumqx; is_shifted = false; } sumqx = (sumx[i1] - sumx[ 0])*shifted_values[0] + (sumx[i2] - sumx[i1])*shifted_values[1] @@ -601,6 +601,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[0]*shifted_values[0] + (sumw[i2] - sumw[i1])*shifted_values[1]*shifted_values[1] + (sumw[i3] - sumw[i2])*shifted_values[2]*shifted_values[2] + (sumw[kBlockSize] - sumw[i3])*shifted_values[3]*shifted_values[3]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + besti1 = i1; besti2 = i2; besti3 = i3; reverse = false; d = sumqx/sumq2; best = d*sumqx; is_shifted = true; } sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[3] + (sumx[i2] - sumx[i1])*iq2nl_values[2] @@ -608,6 +609,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[3]*iq2nl_values[3] + (sumw[i2] - sumw[i1])*iq2nl_values[2]*iq2nl_values[2] + (sumw[i3] - sumw[i2])*iq2nl_values[1]*iq2nl_values[1] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[0]*iq2nl_values[0]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + besti1 = i1; besti2 = i2; besti3 = i3; reverse = true; d = sumqx/sumq2; best = d*sumqx; is_shifted = false; } sumqx = (sumx[i1] - sumx[ 0])*shifted_values[3] + (sumx[i2] - sumx[i1])*shifted_values[2] @@ -615,6 +617,7 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[3]*shifted_values[3] + (sumw[i2] - sumw[i1])*shifted_values[2]*shifted_values[2] + (sumw[i3] - sumw[i2])*shifted_values[1]*shifted_values[1] + (sumw[kBlockSize] - sumw[i3])*shifted_values[0]*shifted_values[0]; if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + besti1 = i1; besti2 = i2; besti3 = i3; reverse = true; d = sumqx/sumq2; best = d*sumqx; is_shifted = true; } } @@ -623,21 +626,66 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl scales[ib] = d; if (is_shifted) extra |= (1 << ib); + if (reverse) { + for (int j = 0; j < besti1; ++j) L[ib*kBlockSize+pairs[j].second] = 3; + for (int j = besti1; j < besti2; ++j) L[ib*kBlockSize+pairs[j].second] = 2; + for (int j = besti2; j < besti3; ++j) L[ib*kBlockSize+pairs[j].second] = 1; + for (int j = besti3; j < kBlockSize; ++j) L[ib*kBlockSize+pairs[j].second] = 0; + } else { + for (int j = 0; j < besti1; ++j) L[ib*kBlockSize+pairs[j].second] = 0; + for (int j = besti1; j < besti2; ++j) L[ib*kBlockSize+pairs[j].second] = 1; + for (int j = besti2; j < besti3; ++j) L[ib*kBlockSize+pairs[j].second] = 2; + for (int j = besti3; j < kBlockSize; ++j) L[ib*kBlockSize+pairs[j].second] = 3; + } + float abs_scale = fabsf(scales[ib]); - max_abs_scale = std::max(max_abs_scale, abs_scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scales[ib]; + } } if (!max_abs_scale) continue; - float d = make_qx_quants(QK_K/kBlockSize, 8, scales, Ls, sw); + + float d = -max_scale/8; + float best_id = 1/d; + + float best = 0; + for (int itry = -17; itry <= 17; ++itry) { + float id = (-8 + 0.1f*itry)/max_scale; + double sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + auto xb = xbl + kBlockSize*ib; + auto Lb = L + kBlockSize*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize; + for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j]; + } + auto block_values = extra & (1 << ib) ? shifted_values : iq2nl_values; + int ls = nearest_int(id*scales[ib]); + ls = std::max(-8, std::min(7, ls)); + for (int j = 0; j < kBlockSize; ++j) { + float w = weight[j]; + float q = block_values[Lb[j]]*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; best_id = id; + } + } if (!d) continue; - //float d = -max_scale/8; y[ibl].extra = extra; - float id = 1/d; + + best_id = 0.5f*(best_id + 1/d); float sumqx = 0, sumq2 = 0; for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { - int ls = nearest_int(id*scales[ib]); + int ls = nearest_int(best_id*scales[ib]); ls = std::max(-8, std::min(7, ls)); y[ibl].scales[ib/2] |= ((ls + 8) << 4*(ib%2)); float dl = d * ls; @@ -766,8 +814,6 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f float sval[4] = {float(iq2nl_values[4]), float(iq2nl_values[5]), float(iq2nl_values[6]), float(iq2nl_values[7])}; float sums[16]; - const int8_t * shifted_values = iq2nl_values + 4; - const int nblock = n_per_row/QK_K; float max_scale = 0, amax_scale = 0;