From 4c76471979374e0f06c98d04a00946347f592e42 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 8 Oct 2024 10:52:53 +0300 Subject: [PATCH] iq4_k_xxs: basics --- examples/quantize/quantize.cpp | 1 + ggml/include/ggml.h | 2 + ggml/src/ggml-common.h | 6 + ggml/src/ggml.c | 22 ++++ ggml/src/iqk/iqk_quantize.cpp | 213 +++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_quantize.h | 6 + include/llama.h | 1 + src/llama.cpp | 12 +- 8 files changed, 259 insertions(+), 4 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 2b240299..581f14d8 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -43,6 +43,7 @@ static const std::vector QUANT_OPTIONS = { { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, + { "IQ4_XXS", LLAMA_FTYPE_MOSTLY_IQ4_XXS, " 4.06 bpw non-linear quantization", }, { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 13aaeafb..c21a61d5 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -403,6 +403,7 @@ extern "C" { GGML_TYPE_IQ6_K = 141, GGML_TYPE_IQ2_TN = 142, GGML_TYPE_IQ1_TN = 143, + GGML_TYPE_IQ4_XXS = 144, GGML_TYPE_COUNT, }; @@ -458,6 +459,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ6_K = 134, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_TN = 135, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_TN = 136, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_XXS = 137, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 02ecf071..ddd6b528 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -441,6 +441,12 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +typedef struct { + uint8_t scales[2]; + uint8_t qs[QK_K/2]; +} block_iq4_xxs; +static_assert(sizeof(block_iq4_xxs) == 2 + QK_K/2, "wrong iq4_xxs block size/padding"); + typedef struct { ggml_half d; uint16_t extra; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7ad666eb..44b5ad54 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1087,6 +1087,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ4_XXS] = { + .type_name = "iq4_xxs", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_xxs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_xxs, + .from_float = quantize_row_iq4_xxs, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xxs_ref, + .vec_dot = vec_dot_iq4_xxs_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 4, + }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -3891,6 +3904,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ1_TN: wtype = GGML_TYPE_IQ1_TN; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; + case GGML_FTYPE_MOSTLY_IQ4_XXS: wtype = GGML_TYPE_IQ4_XXS; break; case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; @@ -10390,6 +10404,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: @@ -10778,6 +10793,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: @@ -10916,6 +10932,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: @@ -14095,6 +14112,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: @@ -14473,6 +14491,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: @@ -14745,6 +14764,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: @@ -15344,6 +15364,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ1_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: @@ -22160,6 +22181,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_TN: result = quantize_iq1_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_XXS: result = quantize_iq4_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 3ff6b4da..02df8b16 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2166,3 +2166,216 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { #endif } + +namespace { +static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size, + int n_per_row, const float * x, char * cy, + float * all_scales, float * weight, + const int8_t * values, + const float * quant_weights, + const int ntry) { + + GGML_ASSERT(super_block_size == 256 && block_size == 128); + + float * dptr = (float *)cy; + block_iq4_xxs * y = (block_iq4_xxs *)(dptr + 1); + + const int8_t * shifted_values = values + 16; + + float amax_scale = 0; + + for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) { + memset(&y[ibl], 0, sizeof(block_iq4_xxs)); + const float * xbl = x + ibl*super_block_size; + auto scales = all_scales + ibl*(super_block_size/block_size); + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j]; + sigma2 *= 2.f/super_block_size; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const float * xb = xbl + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ibl*super_block_size + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? -max/values[0] : max/values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + bool is_shifted = false; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; + } + id = (itry + shifted_values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(shifted_values, al); + float q = shifted_values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(shifted_values, -al); + q = shifted_values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + if (is_shifted) y[ibl].scales[ib] = 0x01; + scales[ib] = d; + amax_scale = std::max(amax_scale, std::abs(d)); + } + } + float d = amax_scale/127; + *dptr = d; + if (!d) return; + float id = d ? 1/d : 0.f; + float sumqx = 0, sumq2 = 0; + float mse = 0; + for (int ibl = 0; ibl < n_per_row/super_block_size; ++ibl) { + const float * xbl = x + ibl*super_block_size; + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += xbl[j]*xbl[j]; + sigma2 *= 2.f/super_block_size; + auto scales = all_scales + (super_block_size/block_size)*ibl; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const int8_t * block_values = y[ibl].scales[ib] & 0x01 ? shifted_values : values; + int l = nearest_int(0.5f*(id*scales[ib]+127.f)); + l = std::max(0, std::min(127, l)) << 1; + //printf("d = %g, id = %g, scales = %g, l = %d, dl = %g\n", d, id, scales[ib], l, d*(l - 127)); + y[ibl].scales[ib] |= l; + l -= 127; + float dl = d * l; + float idl = dl ? 1/dl : 0.f; + const float * xb = xbl + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ibl*super_block_size + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + auto qs = y[ibl].qs + ib*(block_size/2); + for (int j = 0; j < block_size/2; ++j) { + uint8_t i1 = best_index_iq4nl(block_values, idl*xb[j]); + uint8_t i2 = best_index_iq4nl(block_values, idl*xb[j+block_size/2]); + qs[j] = i1 | (i2 << 4); + float w1 = weight[j]; + float w2 = weight[j+block_size/2]; + float q1 = block_values[i1]*l; + float q2 = block_values[i2]*l; + sumqx += w1*q1*xb[j] + w2*q2*xb[j+block_size/2]; + sumq2 += w1*q1*q1 + w2*q2*q2; + float diff = xb[j] - d*q1; mse += diff*diff; + diff = xb[j+block_size/2] - d*q2; mse += diff*diff; + } + } + } + //printf("rmse = %g\n", sqrt(mse/n_per_row)); + if (sumq2 > 0) *dptr = sumqx/sumq2; +} +} + +void quantize_row_iq4_xxs_ref(const float * x, block_iq4_xxs * y, int64_t k) { + quantize_iq4_xxs(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq4_xxs(const float * x, void * y, int64_t k) { + quantize_iq4_xxs(x, (void *)y, 1, k, nullptr); +} + +size_t quantize_iq4_xxs(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + //printf("============ %s(%d, %d)\n", __func__, int(nrows), int(n_per_row)); + constexpr int kBlockSize = 128; + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ4_XXS, n_per_row); + char * qrow = (char *)dst; + float weight[kBlockSize]; + std::vector all_scales(n_per_row/kBlockSize); + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq4_k_impl_bs128(QK_K, kBlockSize, n_per_row, src, qrow, all_scales.data(), weight, iq4k_values, imatrix, 7); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq4_xxs(const block_iq4_xxs * x, float * y, int64_t k) { + constexpr int kBlockSize = 128; + GGML_ASSERT(k%QK_K == 0); + const float * dptr = (const float *)x; + float d = *dptr; + x = (const block_iq4_xxs *)(dptr + 1); + int nblock = k/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto qs = x[ibl].qs; + for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { + float dl = d * ((int)(x[ibl].scales[ib] & 254) - 127); + const int8_t * values = iq4k_values + ((x[ibl].scales[ib] & 1) << 4); + for (int j = 0; j < kBlockSize/2; ++j) { + y[j ] = dl * values[qs[j] & 0xf]; + y[j+kBlockSize/2] = dl * values[qs[j] >> 4]; + } + y += kBlockSize; + qs += kBlockSize/2; + } + } +} + +void vec_dot_iq4_xxs_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +} + diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index e5c16fc9..bab46501 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -55,6 +55,12 @@ size_t quantize_iq1_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst void dequantize_row_iq1_tn(const block_iq1_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq1_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq4_xxs_ref(const float * GGML_RESTRICT x, block_iq4_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_xxs(const block_iq4_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_xxs_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); #ifdef __cplusplus diff --git a/include/llama.h b/include/llama.h index 43c0091e..1fefc182 100644 --- a/include/llama.h +++ b/include/llama.h @@ -177,6 +177,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ6_K = 142, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_TN = 143, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_TN = 144, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_XXS = 145, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index 9ed109c6..40b24cb2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3793,6 +3793,7 @@ struct llama_model_loader { case GGML_TYPE_IQ2_TN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_TN; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ4_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XXS; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; @@ -4494,6 +4495,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XXS: return "IQ4_XXS - 4.06 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw"; @@ -15623,7 +15625,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K) { new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; } - else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_output) { + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XXS) && !qs.has_output) { new_type = GGML_TYPE_IQ5_K; } else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_IQ6_K) { @@ -15710,7 +15712,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 2) { + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XXS) && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K && qs.model.hparams.n_gqa() >= 2) { @@ -15787,7 +15789,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; } } - else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) { + else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XXS) && !qs.has_imatrix) { new_type = GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; @@ -15867,7 +15869,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN || - new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN) { + new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN || new_type == GGML_TYPE_IQ4_XXS) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -15898,6 +15900,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_Q3_K: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_XXS: case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_IQ4_K: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; @@ -16008,6 +16011,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_TN: default_type = GGML_TYPE_IQ2_TN; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ4_XXS: default_type = GGML_TYPE_IQ4_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break; case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break;