From b68c2cb0e06c4072eb238a78cb41c16c9da186ae Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Tue, 15 Oct 2024 13:49:51 +0300
Subject: [PATCH] iq4_kss: slightly better quantization

Add quantize_row_iq4_kss_impl(), which quantizes a row directly to
IQ4_KSS, and call it from quantize_iq4_kss() in place of
quantize_row_iq4_k_impl_bs128() followed by prune_iq4ks_to_iq4kss().
Also rename the local w to wv in prune_iq4ks_to_iq4kss().

---
Notes on this revision of the patch (the SHA and index hashes identify
the original commit, not this revision):
- Line structure was restored from a whitespace-collapsed copy, so the
  indentation of context lines is reconstructed; apply with
  --ignore-whitespace if the context does not match. The first -/+ pair
  of the second hunk originally differed only in whitespace.
- quantize_row_iq4_kss_impl(): the initial scale uses the values
  parameter instead of the global iq4k_values, the scratch arrays are
  sized to the block_size/4 entries actually used (with sizeof-based
  memset/memcpy), the scale bits are shifted as uint32_t rather than as a
  signed int, the duplicated base/shifted search passes are merged into
  one loop, and commented-out experiments were dropped.

 ggml/src/iqk/iqk_quantize.cpp | 159 ++++++++++++++++++++++++++++++++--
 1 file changed, 151 insertions(+), 8 deletions(-)

diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index b0f294a5..9cc9dfb5 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -2893,10 +2893,154 @@ uint16_t prune_iq4ks(uint16_t v, const int8_t * values, const float * x, const f
     q4[jbest] = bestq;
     return (q4[0] | (q4[1] << 4) | (q4[2] << 8) | (q4[3] << 12));
 }
+// Quantize one row of n_per_row floats to IQ4_KSS.
+//
+// cy receives one float (the row scale) followed by n_per_row/256 block_iq4_kss
+// super-blocks; elements past the last full super-block are not processed.
+//   all_scales    - scratch, one float per 32-value block of the row
+//   weight        - scratch, at least 32 floats
+//   values        - two 16-entry codebooks: values[0..15] and the shifted values[16..31]
+//   quant_weights - optional per-element importance weights, may be null
+//   table         - lookup applied to the low 15 bits of each packed group of four indices
+//   ntry          - scale candidates are tried for itry in [-ntry, ntry]
+static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
+        float * all_scales, float * weight,
+        const int8_t * values,
+        const float * quant_weights,
+        const uint16_t * table,
+        const int ntry) {
+
+    constexpr int super_block_size = 256;
+    constexpr int block_size = 32;
+
+    float * dptr = (float *)cy;
+    *dptr = 0;
+    block_iq4_kss * y = (block_iq4_kss *)(dptr + 1);
+
+    const int8_t * shifted_values = values + 16;
+
+    // One uint16_t packs four 4-bit indices, so a block needs block_size/4 of them.
+    uint16_t vps[block_size/4], vms[block_size/4], vs[block_size/4];
+    float xv[4], wv[4];
+
+    float amax_scale = 0;
+
+    for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
+        std::memset(&y[ibl], 0, sizeof(block_iq4_kss));
+        const float * xbl = x + ibl*super_block_size;
+        auto scales = all_scales + ibl*(super_block_size/block_size);
+        float sigma2 = 0;
+        for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j];
+        sigma2 *= 2.f/super_block_size;
+        for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+            const float * xb = xbl + ib*block_size;
+            if (quant_weights) {
+                const float * qw = quant_weights + ibl*super_block_size + ib*block_size;
+                for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+            } else {
+                for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+            }
+            // max is the signed value of the element with the largest magnitude.
+            float amax = 0, max = 0;
+            for (int j = 0; j < block_size; ++j) {
+                float ax = fabsf(xb[j]);
+                if (ax > amax) {
+                    amax = ax; max = xb[j];
+                }
+            }
+            if (!amax) {
+                // All-zero block: keep the zeroed quants and record a zero scale.
+                scales[ib] = 0;
+                continue;
+            }
+            float best = 0;
+            bool is_shifted = false;
+            float d = -max/values[0];
+            std::memset(vs, 0, sizeof(vs));
+            for (int itry = -ntry; itry <= ntry; ++itry) {
+                // Same search for both codebooks: base first, then shifted.
+                for (int ishift = 0; ishift < 2; ++ishift) {
+                    const int8_t * vals = ishift ? shifted_values : values;
+                    const float id = (itry + vals[0])/max;
+                    const float this_d = 1/id;
+                    float sumqx_p = 0, sumq2_p = 0;
+                    float sumqx_m = 0, sumq2_m = 0;
+                    for (int k = 0; k < block_size/4; ++k) {
+                        // Group k holds elements 2k, 2k+16, 2k+1 and 2k+17 of the block.
+                        xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2];
+                        wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2];
+                        // vp holds the indices nearest to id*x, vm those nearest to -id*x.
+                        uint16_t vp = 0, vm = 0;
+                        for (int j = 0; j < 4; ++j) {
+                            float al = id*xv[j];
+                            vp |= (best_index_iq4nl(vals, al) << 4*j);
+                            vm |= (best_index_iq4nl(vals, -al) << 4*j);
+                        }
+                        // NOTE(review): vm is pruned with +this_d although it was built from -x;
+                        // an earlier draft of this code passed -this_d here. Confirm which is intended.
+                        vp = prune_iq4ks(vp, vals, xv, wv, this_d);
+                        vm = prune_iq4ks(vm, vals, xv, wv, this_d);
+                        for (int j = 0; j < 4; ++j) {
+                            float w = wv[j];
+                            float q = vals[(vp >> 4*j) & 0xf];
+                            sumqx_p += w*q*xv[j];
+                            sumq2_p += w*q*q;
+                            q = vals[(vm >> 4*j) & 0xf];
+                            sumqx_m += w*q*xv[j];
+                            sumq2_m += w*q*q;
+                        }
+                        vps[k] = vp;
+                        vms[k] = vm;
+                    }
+                    // Keep the candidate with the largest sumqx^2/sumq2; d = sumqx/sumq2 is
+                    // its weighted least-squares scale. vm is tested after vp, so it is the
+                    // one copied whenever it beats the (possibly just updated) best.
+                    bool copy_p = false, copy_m = false;
+                    if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
+                        d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = ishift != 0; copy_p = true;
+                    }
+                    if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
+                        d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = ishift != 0; copy_m = true;
+                    }
+                    if (copy_m) {
+                        std::memcpy(vs, vms, sizeof(vs));
+                    } else if (copy_p) {
+                        std::memcpy(vs, vps, sizeof(vs));
+                    }
+                }
+            }
+            // Two table entries go into each qs word, the second shifted up by 15 bits.
+            for (int k = 0; k < block_size/8; ++k) {
+                auto v1 = table[vs[2*k+0] & 0x7fff];
+                auto v2 = table[vs[2*k+1] & 0x7fff];
+                y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15);
+            }
+            if (is_shifted) y[ibl].qs[(block_size/8)*ib] |= (1u << 30);
+            scales[ib] = d;
+            amax_scale = std::max(amax_scale, std::abs(d));
+        }
+    }
+    float d = amax_scale/127;
+    *dptr = d;
+    if (!d) return;
+    float id = 1/d;
+    for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) {
+        auto scales = all_scales + (super_block_size/block_size)*ibl;
+        for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+            // Bit 0 of l is clear after the << 1, so bit 30 of the first word (set above) is kept.
+            int l = nearest_int(0.5f*(id*scales[ib]+127.f));
+            l = std::max(0, std::min(127, l)) << 1;
+            for (int k = 0; k < 4; ++k) {
+                y[ibl].qs[4*ib+k] |= ((uint32_t)((l >> 2*k) & 3) << 30);
+            }
+        }
+    }
+}
+
 void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,
         const float * quant_weights, float * weight, float * all_scales) {
     constexpr int kBlockSize = 32;
-    float xv[4], w[4];
+    float xv[4], wv[4];
     uint16_t vps[kBlockSize/4];
     const float * dptr_ks = (const float *)cx;
     const float d_ks = *dptr_ks;
@@ -2925,14 +3069,14 @@ void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * c
             float dl = d_ks * ((iq4ks[ibl].scales[ib] & 254) - 127);
             float sumqx = 0, sumq2 = 0;
             for (int k = 0; k < kBlockSize/4; ++k) {
-                xv[0] = xb[2*k+0]; xv[1] = xb[2*k+kBlockSize/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+kBlockSize/2];
-                w[0] = weight[2*k+0]; w[1] = weight[2*k+kBlockSize/2]; w[2] = weight[2*k+1]; w[3] = weight[2*k+1+kBlockSize/2];
-                auto vp = prune_iq4ks(q4[k], values, xv, w, dl);
+                xv[0] = xb[2*k+0]; xv[1] = xb[2*k+kBlockSize/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+kBlockSize/2];
+                wv[0] = weight[2*k+0]; wv[1] = weight[2*k+kBlockSize/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+kBlockSize/2];
+                auto vp = prune_iq4ks(q4[k], values, xv, wv, dl);
                 vps[k] = table[vp & 0x7fff];
                 for (int j = 0; j < 4; ++j) {
                     float q = values[(vp >> 4*j) & 0xf];
-                    sumqx += w[j]*q*xv[j];
-                    sumq2 += w[j]*q*q;
+                    sumqx += wv[j]*q*xv[j];
+                    sumq2 += wv[j]*q*q;
                 }
             }
             for (int k = 0; k < kBlockSize/8; ++k) {
@@ -2974,8 +3118,7 @@ size_t quantize_iq4_kss(const float * src, void * dst, int64_t nrows, int64_t n_
     auto qrow = (char *)dst;
     auto table = scramble_table();
     for (int row = 0; row < nrows; ++row) {
-        quantize_row_iq4_k_impl_bs128(QK_K, kBlockSize, n_per_row, src, work.data(), all_scales.data(), weight, iq4k_values, imatrix, 7);
-        prune_iq4ks_to_iq4kss(n_per_row, table, work.data(), src, qrow, imatrix, weight, all_scales.data());
+        quantize_row_iq4_kss_impl(n_per_row, src, qrow, all_scales.data(), weight, iq4k_values, imatrix, table, 7);
         src += n_per_row;
         qrow += row_size;
     }