diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 9cc9dfb5..8c72fd49 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2910,7 +2910,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy, const int8_t * shifted_values = values + 16; uint16_t vps[block_size/2], vms[block_size/2], vs[block_size/2]; - //uint8_t Lp[block_size], Lm[block_size], L[block_size]; float xv[4], wv[4]; float amax_scale = 0; @@ -2949,11 +2948,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy, float id = (itry + values[0])/max; float sumqx_p = 0, sumq2_p = 0; float sumqx_m = 0, sumq2_m = 0; - //for (int j = 0; j < block_size; ++j) { - // float al = id*xb[j]; - // Lp[j] = best_index_iq4nl(values, al); - // Lm[j] = best_index_iq4nl(values, -al); - //} float this_d = 1/id; for (int k = 0; k < block_size/4; ++k) { xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2]; @@ -2977,24 +2971,7 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy, } vps[k] = vp; vms[k] = vm; - //uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12); - //uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12); - //auto vps = prune_iq4ks(vp, values, xv, wv, this_d); - //auto vms = prune_iq4ks(vm, values, xv, wv, -this_d); - //Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf; - //Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf; - //Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf; - //Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf; } - //for (int j = 0; j < block_size; ++j) { - // float w = weight[j]; - // float q = values[Lp[j]]; - // sumqx_p += w*q*xb[j]; - // sumq2_p += w*q*q; - // q = values[Lm[j]]; - // sumqx_m += w*q*xb[j]; - // sumq2_m += w*q*q; - //} bool copy_p = false, copy_m = false; if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; copy_p = true; @@ -3047,58 +3024,7 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy, } else if (copy_p) { std::memcpy(vs, vps, block_size); } - //for (int j = 0; j < block_size; ++j) { - // float al = id*xb[j]; - // Lp[j] = best_index_iq4nl(shifted_values, al); - // Lm[j] = best_index_iq4nl(shifted_values, -al); - //} - //for (int k = 0; k < block_size/4; ++k) { - // xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2]; - // wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2]; - // uint16_t vp = Lp[2*k+0] | (Lp[2*k+0+block_size/2] << 4) | (Lp[2*k+1] << 8) | (Lp[2*k+1+block_size/2] << 12); - // uint16_t vm = Lm[2*k+0] | (Lm[2*k+0+block_size/2] << 4) | (Lm[2*k+1] << 8) | (Lm[2*k+1+block_size/2] << 12); - // auto vps = prune_iq4ks(vp, shifted_values, xv, wv, this_d); - // auto vms = prune_iq4ks(vm, shifted_values, xv, wv, -this_d); - // Lp[2*k+0] = (vps >> 0) & 0xf; Lp[2*k+0+block_size/2] = (vps >> 4) & 0xf; - // Lp[2*k+1] = (vps >> 8) & 0xf; Lp[2*k+1+block_size/2] = (vps >> 12) & 0xf; - // Lm[2*k+0] = (vms >> 0) & 0xf; Lm[2*k+0+block_size/2] = (vms >> 4) & 0xf; - // Lm[2*k+1] = (vms >> 8) & 0xf; Lm[2*k+1+block_size/2] = (vms >> 12) & 0xf; - //} - //for (int j = 0; j < block_size; ++j) { - // float w = weight[j]; - // float q = shifted_values[Lp[j]]; - // sumqx_p += w*q*xb[j]; - // sumq2_p += w*q*q; - // q = shifted_values[Lm[j]]; - // sumqx_m += w*q*xb[j]; - // sumq2_m += w*q*q; - //} - //copy_p = copy_m = false; - //if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { - // d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true; - //} - //if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { - // d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true; - //} - //if (copy_m) { - // std::memcpy(L, Lm, block_size); - //} else if (copy_p) { - // std::memcpy(L, Lp, block_size); - //} } - for (int k = 0; k < block_size/8; ++k) { - auto v1 = table[vs[2*k+0] & 0x7fff]; - auto v2 = table[vs[2*k+1] & 0x7fff]; - y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15); - } - //for (int k = 0; k < block_size/8; ++k) { - // uint16_t v1 = Lp[4*k+0] | (Lp[4*k+0+block_size/2] << 4) | (Lp[4*k+1] << 8) | (Lp[4*k+1+block_size/2] << 12); - // uint16_t v2 = Lp[4*k+2] | (Lp[4*k+2+block_size/2] << 4) | (Lp[4*k+3] << 8) | (Lp[4*k+3+block_size/2] << 12); - // v1 = table[v1 & 0x7fff]; - // v2 = table[v2 & 0x7fff]; - // y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15); - //} - if (is_shifted) y[ibl].qs[(block_size/8)*ib] |= (1u << 30); scales[ib] = d; amax_scale = std::max(amax_scale, std::abs(d)); } @@ -3107,16 +3033,79 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy, *dptr = d; if (!d) return; float id = 1/d; + float sumqx = 0, sumq2 = 0; for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) { auto scales = all_scales + (super_block_size/block_size)*ibl; + const float * xbl = x + ibl*super_block_size; + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j]; + sigma2 *= 2.f/super_block_size; for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const float * xb = xbl + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ibl*super_block_size + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } int l = nearest_int(0.5f*(id*scales[ib]+127.f)); - l = std::max(0, std::min(127, l)) << 1; - for (int k = 0; k < 4; ++k) { - y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30); + l = (std::max(0, std::min(127, l)) << 1) - 127; + if (l) { + float dl = d*l; + float idl = 1/dl; + float mse_p = 0, mse_m = 0; + for (int k = 0; k < block_size/4; ++k) { + xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2]; + wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2]; + uint16_t vp = 0, vm = 0; + for (int j = 0; j < 4; ++j) { + float al = idl*xv[j]; + vp |= (best_index_iq4nl( values, al) << 4*j); + vm |= (best_index_iq4nl(shifted_values, al) << 4*j); + } + vp = prune_iq4ks(vp, values, xv, wv, dl); + vm = prune_iq4ks(vm, shifted_values, xv, wv, dl); + for (int j = 0; j < 4; ++j) { + float w = wv[j]; + float q = values[(vp >> 4*j) & 0xf]; + mse_p += w*(xv[j] - dl*q)*(xv[j] - dl*q); + q = shifted_values[(vm >> 4*j) & 0xf]; + mse_m += w*(xv[j] - dl*q)*(xv[j] - dl*q); + } + vps[k] = vp; + vms[k] = vm; + } + const uint16_t * v = vps; + const int8_t * block_values = values; + if (mse_m < mse_p) { + v = vms; + block_values = values + 16; + } + for (int k = 0; k < block_size/4; ++k) { + xv[0] = xb[2*k+0]; xv[1] = xb[2*k+0+block_size/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+block_size/2]; + wv[0] = weight[2*k+0]; wv[1] = weight[2*k+0+block_size/2]; wv[2] = weight[2*k+1]; wv[3] = weight[2*k+1+block_size/2]; + for (int j = 0; j < 4; ++j) { + float q = block_values[(v[k] >> 4*j) & 0xf] * l; + sumqx += wv[j]*q*xv[j]; + sumq2 += wv[j]*q*q; + } + } + l += 127; + if (mse_m < mse_p) l |= 1; + for (int k = 0; k < block_size/8; ++k) { + auto v1 = table[v[2*k+0] & 0x7fff]; + auto v2 = table[v[2*k+1] & 0x7fff]; + y[ibl].qs[(block_size/8)*ib + k] = v1 | (v2 << 15) | (((l >> 2*k) & 3) << 30); + } + } else { + l += 127; + for (int k = 0; k < block_size/8; ++k) { + y[ibl].qs[(block_size/8)*ib + k] |= (((l >> 2*k) & 3) << 30); + } } } } + if (sumq2 > 0) *dptr = sumqx/sumq2; } void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,