iq3_kt WIP: slowly improving

PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking the model
by 0.015 bpw via iq4_k instead of q5_k for attn_v.
Iwan Kawrakow
2024-11-09 11:42:14 +02:00
parent dfcc8a9cf3
commit 8f0d075f5e
2 changed files with 12 additions and 3 deletions
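
A quick back-of-the-envelope check of the 0.015 bpw figure (my numbers, not from the commit): in LLaMA-3.1-8B each attn_v.weight is 4096 x 1024, so 32 layers contribute about 134M of the ~8.03B parameters, and q5_k (5.5 bpw) vs iq4_k (4.5 bpw) saves 1.0 bpw on those tensors:

    1.0 bpw * 134M / 8030M ≈ 0.017 bpw

which is in the right ballpark; the exact 0.015 depends on the rest of the type-selection logic.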


@@ -3941,6 +3941,14 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
     int nblock = n_per_row / Q::kSuperBlockSize;
+    float amax_row = 0;
+    for (int j = 0; j < n_per_row; ++j) amax_row = std::max(amax_row, std::abs(x[j]));
+    if (!amax_row) {
+        *dptr = 0.f;
+        std::memset(y, 0, nblock*sizeof(block_iq3_kt));
+        return;
+    }
     float amax_scale = 0, max_scale = 0;
     for (int ibl = 0; ibl < nblock; ++ibl) {
@@ -3969,15 +3977,16 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
             }
             scales[ib] = 0;
             if (!amax) continue;
+            float scale_0 = std::max(80.f, 127.f*amax/amax_row);
             float best = 0;
             for (int itry = -3; itry <= 3; ++itry) {
-                quantizer.find_best_match(amax/(96.f + kStep*itry), xb, weight, best_idx);
+                quantizer.find_best_match(amax/(scale_0 + kStep*itry), xb, weight, best_idx);
                 auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
                 if (score_p > best) {
                     best = score_p;
                     scales[ib] = dp;
                 }
-                quantizer.find_best_match(-amax/(96.f + kStep*itry), xb, weight, best_idx);
+                quantizer.find_best_match(-amax/(scale_0 + kStep*itry), xb, weight, best_idx);
                 auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx);
                 if (score_m > best) {
                     best = score_m;
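
What the second hunk changes, as I read it: the starting divisor for the per-block scale search is no longer a fixed 96 but scale_0 = max(80, 127*amax/amax_row), so blocks close to the row maximum start near 127 while weak blocks are clamped at 80; the candidates are then amax/(scale_0 + kStep*itry) for itry in [-3, 3], with both signs tried. The zero-row early return added in the first hunk is what keeps amax/amax_row well defined. Below is a standalone sketch (not the repo's code) of just the candidate enumeration; kStep's value is not visible in this diff, so 8.f is only a placeholder, and in the real loop each candidate goes through quantizer.find_best_match / find_best_scale with the best-scoring one kept.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical helper, for illustration only: the candidate scales the search
// above would probe for one block, given the block max and the row max.
static std::vector<float> candidate_scales(float amax, float amax_row, float kStep = 8.f) {
    std::vector<float> out;
    if (amax == 0 || amax_row == 0) return out;            // all-zero block or row: nothing to probe
    float scale_0 = std::max(80.f, 127.f*amax/amax_row);   // per-block starting divisor (was a fixed 96.f)
    for (int itry = -3; itry <= 3; ++itry) {
        out.push_back( amax/(scale_0 + kStep*itry));       // positive-sign candidate
        out.push_back(-amax/(scale_0 + kStep*itry));       // negative-sign candidate
    }
    return out;
}

int main() {
    // A block whose max is a quarter of the row max gets the clamped divisor 80.
    for (float s : candidate_scales(0.5f, 2.0f)) printf("%g\n", s);
}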


@@ -15819,7 +15819,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
                                                        : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
         }
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) {
-            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K
+            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ4_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K
                                                      : !qs.has_imatrix ? GGML_TYPE_IQ3_K : GGML_TYPE_IQ3_KT;
         }
         else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 2) {
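
For the second file: under LLAMA_FTYPE_MOSTLY_IQ3_KT, attn_v on models with a GQA factor of at least 4 is bumped from Q4_K to the non-linear IQ4_K. Unrolled, the nested ternary amounts to the standalone illustration below (my helper name, not the repo's; the real code assigns ggml_type enums inside llama_tensor_get_type):

#include <cstdio>

// Hypothetical helper, for illustration only: the attn_v type choice for
// LLAMA_FTYPE_MOSTLY_IQ3_KT after this commit, returned as a string.
static const char * iq3_kt_attn_v_type(int n_gqa, bool has_imatrix) {
    if (n_gqa >= 4)   return "IQ4_K";   // strongly grouped-query models: spend more bits on attn_v
    if (n_gqa >= 2)   return "IQ3_K";
    if (!has_imatrix) return "IQ3_K";   // without an imatrix, avoid the trellis type here
    return "IQ3_KT";
}

int main() {
    // LLaMA-3.1-8B: 32 attention heads over 8 KV heads -> GQA factor 4 -> IQ4_K
    printf("%s\n", iq3_kt_attn_v_type(4, true));
}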