From 00b4bff2864911c2b67bd79d413531603dd4328f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 11 Nov 2024 12:34:00 +0200 Subject: [PATCH] Adding iq4_kt - not competitive at this point --- examples/quantize/quantize.cpp | 1 + ggml/include/ggml.h | 2 + ggml/src/ggml-common.h | 6 + ggml/src/ggml-cuda.cu | 1 + ggml/src/ggml-cuda/convert.cu | 43 ++++++ ggml/src/ggml-cuda/dmmv.cu | 21 ++- ggml/src/ggml-quants.c | 1 + ggml/src/ggml.c | 22 ++++ ggml/src/iqk/iqk_quantize.cpp | 231 +++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_quantize.h | 6 + include/llama.h | 1 + src/llama.cpp | 11 +- 12 files changed, 343 insertions(+), 3 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index dee0249f..d266ed97 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -33,6 +33,7 @@ static const std::vector QUANT_OPTIONS = { { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, { "IQ3_KT", LLAMA_FTYPE_MOSTLY_IQ3_KT, " 3.125 bpw quantization", }, + { "IQ4_KT", LLAMA_FTYPE_MOSTLY_IQ4_KT, " 4.125 bpw quantization", }, { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", }, { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", }, { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 4430ac28..380e7dfd 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -408,6 +408,7 @@ extern "C" { GGML_TYPE_IQ4_KSS = 146, GGML_TYPE_IQ2_KT = 147, GGML_TYPE_IQ3_KT = 148, + GGML_TYPE_IQ4_KT = 149, GGML_TYPE_COUNT, }; @@ -468,6 +469,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_KT = 140, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_KT = 141, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_KT = 142, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 02231960..24934498 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -467,6 +467,12 @@ typedef struct { } block_iq3_kt; static_assert(sizeof(block_iq3_kt) == QK_K/4 + QK_K/8 + QK_K/64, "wrong iq3_kt block size/padding"); +typedef struct { + int8_t scales[QK_K/64]; + uint8_t ql[QK_K/2]; +} block_iq4_kt; +static_assert(sizeof(block_iq4_kt) == QK_K/2 + QK_K/64, "wrong iq4_kt block size/padding"); + typedef struct { ggml_half d; uint16_t extra; diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 7a582e55..61ccba23 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2850,6 +2850,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 43f07465..2ebfe573 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -402,6 +402,38 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst } } +template +static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const float *)cx; + const block_iq4_kt * x = (const block_iq4_kt *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + const uint16_t * ql = (const uint16_t *)x[i].ql; + uint32_t idx1 = ql[2*ib+0] + 4096; + uint32_t idx2 = ql[2*ib+1] + 4096; + const float dl = scale * x[i].scales[ib/8] * 31.75f; + uint32_t s[2]; + const half * h = (const half *)s; + for (int j = 0; j < 4; ++j) { + idx1 = ka*idx1 + kb; s[0] = (idx1 & kmask) ^ km32; + idx2 = ka*idx2 + kb; s[1] = (idx2 & kmask) ^ km32; + y[j+0] = dl * (float)(h[0] + h[1]); + y[j+4] = dl * (float)(h[2] + h[3]); + } +} + template static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -945,6 +977,13 @@ static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq3_kt<<>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row)); } +template +static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq4_kt<<>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row)); +} + template static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1185,6 +1224,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq2_kt_cuda; case GGML_TYPE_IQ3_KT: return dequantize_row_iq3_kt_cuda; + case GGML_TYPE_IQ4_KT: + return dequantize_row_iq4_kt_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ2_S: @@ -1260,6 +1301,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq2_kt_cuda; case GGML_TYPE_IQ3_KT: return dequantize_row_iq3_kt_cuda; + case GGML_TYPE_IQ4_KT: + return dequantize_row_iq4_kt_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ2_S: diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index c784610f..440ec727 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -169,6 +169,10 @@ static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ v } } +static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { +} + static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -735,6 +739,16 @@ static void dequantize_mul_mat_vec_iq3_kt_cuda(const void * vx, const dfloat * y dequantize_mul_mat_vec_iq3_kt<<>>(vx, y, dst, ncols, nrows, row_size); } +static void dequantize_mul_mat_vec_iq4_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + constexpr int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KT, ncols); + dequantize_mul_mat_vec_iq4_kt<<>>(vx, y, dst, ncols, nrows, row_size); +} + static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; @@ -797,7 +811,7 @@ void ggml_cuda_op_dequantize_mul_mat_vec( src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || - src0->type == GGML_TYPE_IQ2_KT || src0->type == GGML_TYPE_IQ3_KT; + src0->type == GGML_TYPE_IQ2_KT || src0->type == GGML_TYPE_IQ3_KT || src0->type == GGML_TYPE_IQ4_KT; if (src1_convert_f16) { src1_dfloat = src1_dfloat_a.alloc(ne00); @@ -834,6 +848,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec( case GGML_TYPE_IQ3_KT: dequantize_mul_mat_vec_iq3_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; + case GGML_TYPE_IQ4_KT: + dequantize_mul_mat_vec_iq4_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q3_K: dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; @@ -867,6 +884,6 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || - src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || + src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT || src0_type == GGML_TYPE_F16; } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 81f2a27d..e6295b9d 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15192,6 +15192,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ2_KS: break; case GGML_TYPE_IQ2_KT: break; case GGML_TYPE_IQ3_KT: break; + case GGML_TYPE_IQ4_KT: break; case GGML_TYPE_IQ3_K: break; case GGML_TYPE_IQ4_K: break; case GGML_TYPE_IQ5_K: break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 94db044a..a0285d3f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1219,6 +1219,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 4, }, + [GGML_TYPE_IQ4_KT] = { + .type_name = "iq4_kt", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_kt), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_kt, + .from_float = quantize_row_iq4_kt, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref, + .vec_dot = vec_dot_iq4_kt_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 4, + }, [GGML_TYPE_IQ3_K] = { .type_name = "iq3_k", .blck_size = QK_K, @@ -3936,6 +3949,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break; case GGML_FTYPE_MOSTLY_IQ2_KT: wtype = GGML_TYPE_IQ2_KT; break; case GGML_FTYPE_MOSTLY_IQ3_KT: wtype = GGML_TYPE_IQ3_KT; break; + case GGML_FTYPE_MOSTLY_IQ4_KT: wtype = GGML_TYPE_IQ4_KT; break; case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break; @@ -10461,6 +10475,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -10905,6 +10920,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -11046,6 +11062,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -14233,6 +14250,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -14614,6 +14632,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -14889,6 +14908,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -15491,6 +15511,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -22319,6 +22340,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_KT: result = quantize_iq2_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_KT: result = quantize_iq3_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_KT: result = quantize_iq4_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index c27a035a..71195f5d 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -4163,3 +4163,234 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx #endif } + +// ======================================== iq4_kt + +namespace{ + +using QuantizerIQ4KT = QuantizerIQKT<64, 4, 16, 128>; + +const QuantizerIQ4KT& iq4kt_quantizer() { + static std::mutex mutex; + std::lock_guard lock(mutex); + static QuantizerIQ4KT quantizer; + return quantizer; +} + +void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) { + + constexpr float kSigmaScale = 2.0f; + using Q = QuantizerIQ4KT; + + static_assert(Q::kNumVal%8 == 0); + + float * dptr = (float *)vy; + + block_iq4_kt * y = (block_iq4_kt *)(dptr + 1); + + int best_idx[2*Q::kNg]; + + auto& quantizer = iq4kt_quantizer(); + + int nblock = n_per_row / Q::kSuperBlockSize; + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_row = 0; + for (int j = 0; j < n_per_row; ++j) amax_row = std::max(amax_row, std::abs(x[j])); + if (!amax_row) { + *dptr = 0.f; + std::memset(y, 0, nblock*sizeof(block_iq4_kt)); + return; + } + + float amax_scale = 0, max_scale = 0; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq4_kt)); + + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + auto ql = (uint16_t *)y[ibl].ql; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + float ax = std::abs(xb[j]); + amax = std::max(amax, ax); + } + if (!amax) { + scales[ib] = 0; + ql += Q::kNg; + continue; + } + //float scale_0 = 127.f*amax/amax_row; + //float scale_0 = std::max(64.f, 127.f*amax/amax_row); + float scale_0 = std::max(80.f, 127.f*amax/amax_row); + quantizer.find_best_match( amax/scale_0, xb, weight, best_idx); + auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx+Q::kNg); + auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx+Q::kNg); + if (score_p > score_m) { + scales[ib] = dp; + for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j]; + } else { + scales[ib] = dm; + for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j+Q::kNg]; + } + + //float mse = 0; + //for (int j = 0; j < Q::kNg; ++j) { + // auto values = quantizer.values() + ql[j]*Q::kGroupSize; + // for (int k = 0; k < Q::kGroupSize; ++k) { + // float diff = scales[ib]*values[k] - xb[j*Q::kGroupSize+k]; + // mse += diff*diff; + // } + //} + //printf("rmse(%d) = %g\n", ib, sqrt(mse/Q::kBlockSize)); + ql += Q::kNg; + + //scales[ib] = score_p > score_m ? dp : dm; + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + float d = -max_scale/128; + float id = d ? 1/d : 0.f; + //float mse = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + //const float * xbl = x + ibl*Q::kSuperBlockSize; + //auto ql = (uint16_t *)y[ibl].ql; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = nearest_int(id*scales[ib]); + y[ibl].scales[ib] = std::min(ls, 127); + //float dl = d*y[ibl].scales[ib]; + //const float * xb = xbl + Q::kBlockSize*ib; + //for (int j = 0; j < Q::kNg; ++j) { + // auto values = quantizer.values() + ql[j]*Q::kGroupSize; + // for (int k = 0; k < Q::kGroupSize; ++k) { + // float diff = dl*values[k] - xb[j*Q::kGroupSize+k]; + // mse += diff*diff; + // } + // ql += Q::kNg; + //} + } + } + //printf("rmse = %g\n", sqrt(mse/n_per_row)); + + *dptr = d; + if (!d) return; + + for (int iloop = 0; iloop < 1; ++iloop) { + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + auto qs = (uint16_t *)y[ibl].ql; + const float * xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + int ls = y[ibl].scales[ib]; + float dl = d*ls; + quantizer.find_best_match(dl, xb, weight, best_idx); + + for (int j = 0; j < Q::kNg; ++j) { + qs[j] = best_idx[j]; + auto xl = xb + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + //auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize; + auto ql = quantizer.values() + qs[j]*Q::kGroupSize; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + qs += Q::kNg; + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + *dptr = d; + if (!d) return; + } else { + break; + } + } +} +} + +void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq4_kt * y = (block_iq4_kt *)vy; + quantize_row_iq4_kt_ref(x, y, k); +} + +size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row); + std::vector scales(n_per_row/QuantizerIQ4KT::kBlockSize); + std::vector weights(n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq4_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { + using Q = QuantizerIQ4KT; + assert(k % Q::kSuperBlockSize == 0); + const int nb = k / Q::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = *dptr * Q::kScale; + x = (const block_iq4_kt *)(dptr + 1); + auto& deq = iq4kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint16_t * ql = (const uint16_t *)x[ibl].ql; + for (int ib = 0; ib < Q::kNblock; ++ib) { + float sl = d * x[ibl].scales[ib]; + for (int ig = 0; ig < Q::kNg; ++ig) { + deq.set_values(ql[ig], y, sl); + y += Q::kGroupSize; + } + ql += Q::kNg; + } + } +} + +void vec_dot_iq4_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 0b5e0818..364165e2 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -73,6 +73,12 @@ size_t quantize_iq3_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst void dequantize_row_iq3_kt(const block_iq3_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq3_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_kt(const block_iq4_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); #ifdef __cplusplus diff --git a/include/llama.h b/include/llama.h index f44ebe41..4571c4ff 100644 --- a/include/llama.h +++ b/include/llama.h @@ -181,6 +181,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_KT = 149, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_KT = 150, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_KT = 151, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index 4233a1cf..b7254e73 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3846,6 +3846,7 @@ struct llama_model_loader { case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ3_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT; break; + case GGML_TYPE_IQ4_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; @@ -4554,6 +4555,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_KT: return "IQ3_KT - 3.125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_KT: return "IQ4_KT - 4.125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; @@ -15822,6 +15824,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ4_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K : !qs.has_imatrix ? GGML_TYPE_IQ3_K : GGML_TYPE_IQ3_KT; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KT) { + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ6_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ5_K + : !qs.has_imatrix ? GGML_TYPE_IQ4_K : GGML_TYPE_IQ4_KT; + } else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ4_K; } @@ -16027,7 +16033,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ2_KT || - new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_IQ3_KT) { + new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_IQ3_KT || + new_type == GGML_TYPE_IQ4_KT) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -16052,6 +16059,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: @@ -16166,6 +16174,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ3_KT: default_type = GGML_TYPE_IQ3_KT; break; + case LLAMA_FTYPE_MOSTLY_IQ4_KT: default_type = GGML_TYPE_IQ4_KT; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break;