diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index c88033b6..4c68b6ca 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -45,6 +45,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", }, + { "IQ4_KNN", LLAMA_FTYPE_MOSTLY_IQ4_KNN, " 4.0 bpw non-linear quantization", }, { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, { "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",}, { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a467c297..7d7d4f66 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -406,6 +406,7 @@ extern "C" { GGML_TYPE_IQ4_KS = 144, GGML_TYPE_IQ2_KS = 145, GGML_TYPE_IQ4_KSS = 146, + GGML_TYPE_IQ4_KNN = 147, GGML_TYPE_COUNT, }; @@ -464,6 +465,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_KNN = 140, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f8824b0e..32f6d2b9 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -452,6 +452,11 @@ typedef struct { } block_iq4_kss; static_assert(sizeof(block_iq4_kss) == QK_K/8*sizeof(uint32_t), "wrong iq4_kss block size/padding"); +typedef struct { + uint8_t qs[QK_K/2]; +} block_iq4_knn; +static_assert(sizeof(block_iq4_knn) == QK_K/2, "wrong iq4_knn block size/padding"); + typedef struct { ggml_half d; uint16_t extra; diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e26f36a0..e8c1f08c 100644 --- a/ggml/src/ggml-cuda.cu +++ 
b/ggml/src/ggml-cuda.cu @@ -2830,6 +2830,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a5658a24..c556c9b9 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -550,6 +550,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index e9d15b5d..4ca01855 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -669,6 +669,29 @@ static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, ds } } +template +static __global__ void dequantize_block_iq4_knn(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + const float * dptr = (const float *)cx; + float d = *dptr; + const int8_t * values = (const int8_t *)(dptr + 1); + const block_iq4_knn * x = (const block_iq4_knn *)(values + 16); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + ii*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[q4[j] & 0xf]; + y[j+16] = d * values[q4[j] >> 4]; + } +} + template static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ 
yy) { const int64_t i = blockIdx.x; @@ -1019,6 +1042,14 @@ static void dequantize_row_iq4_kss_cuda(const void * vx, dst_t * y, const int64_ dequantize_block_iq4_kss<<>>(vx, y, n_per_row, row_size); } +template +static void dequantize_row_iq4_knn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KNN, n_per_row); + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_knn<<>>(vx, y, n_per_row, row_size); +} + template static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1193,6 +1224,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_ks_cuda; case GGML_TYPE_IQ4_KSS: return dequantize_row_iq4_kss_cuda; + case GGML_TYPE_IQ4_KNN: + return dequantize_row_iq4_knn_cuda; case GGML_TYPE_IQ2_KS: return dequantize_row_iq2_ks_cuda; case GGML_TYPE_IQ2_K: @@ -1268,6 +1301,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_ks_cuda; case GGML_TYPE_IQ4_KSS: return dequantize_row_iq4_kss_cuda; + case GGML_TYPE_IQ4_KNN: + return dequantize_row_iq4_knn_cuda; case GGML_TYPE_IQ2_KS: return dequantize_row_iq2_ks_cuda; case GGML_TYPE_IQ2_K: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index dec54b5e..548c7b86 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -268,6 +268,31 @@ __device__ __forceinline__ float vec_dot_iq4_kss_q8_1( return dl * __low2float(bq8_1[ib32].ds) * sumi; } +#define VDR_IQ4_KNN_Q8_1_MMVQ 4 +#define VDR_IQ4_KNN_Q8_1_MMQ 4 + +__device__ __forceinline__ float vec_dot_iq4_knn_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const float * dptr = (const float *)vbq; + const float d = *dptr; + const 
uint8_t * values = (const uint8_t *)(dptr + 1); + const block_iq4_knn * bq4 = (const block_iq4_knn *)(values + 16) + kbx; + + // iqs is 0...28 in steps of 4 (VDR_IQ4_KNN_Q8_1_MMVQ ints per step), so ib32 = iqs/4 selects one of the eight 32-value sub-blocks + const int ib32 = iqs/4; + const int32_t * q8 = (const int32_t *)bq8_1[ib32].qs; + const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; + int v1, v2; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + get_int_from_table_16_shift(q4[j], 0, values, v1, v2); + sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi); + sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi); + } + return d * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ5_K_Q8_1_MMVQ 4 #define VDR_IQ5_K_Q8_1_MMQ 4 @@ -739,6 +764,13 @@ void mul_mat_vec_iq4_kss_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +void mul_mat_vec_iq4_knn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KNN, VDR_IQ4_KNN_Q8_1_MMVQ, vec_dot_iq4_knn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + void mul_mat_vec_iq2_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 0678c026..6afb571b 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -36,6 +36,10 @@ void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); +void mul_mat_vec_iq4_knn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + void mul_mat_vec_iq2_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, 
const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 107caf45..6a8d2a1b 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -465,6 +465,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ4_KNN: + mul_mat_vec_iq4_knn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 68ec6126..c69bf96e 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15198,6 +15198,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ1_TN: break; case GGML_TYPE_IQ4_KS: break; case GGML_TYPE_IQ4_KSS: break; + case GGML_TYPE_IQ4_KNN: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 35ed68d0..f2580447 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1113,6 +1113,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 4, }, + [GGML_TYPE_IQ4_KNN] = { + .type_name = "iq4_knn", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_knn), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_knn, + .from_float = quantize_row_iq4_knn, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_knn_ref, + .vec_dot = vec_dot_iq4_knn_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 20, + }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -3932,6 +3945,7 @@ enum ggml_type 
ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break; case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break; + case GGML_FTYPE_MOSTLY_IQ4_KNN: wtype = GGML_TYPE_IQ4_KNN; break; case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break; case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; @@ -10434,6 +10448,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: @@ -10825,6 +10840,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: @@ -10966,6 +10982,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: @@ -14153,6 +14170,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: @@ -14534,6 +14552,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: @@ -14809,6 +14828,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: @@ -15411,6 +15431,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: + 
case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ3_K: @@ -22230,6 +22251,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_KNN: result = quantize_iq4_knn(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 26bc5ecb..03fef7e2 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3258,4 +3258,324 @@ void vec_dot_iq4_kss_q8_k(int n, float * s, size_t bs, const void * vx, size_t b GGML_UNUSED(by); } +namespace { +struct WorkIQ4NN { + constexpr static int nhbin = 256; + WorkIQ4NN(int n) : n_per_row(n) { + weight.resize(n_per_row); + L.resize(n_per_row); + sum0.resize(nhbin+1); + sum1.resize(nhbin+1); + sum2.resize(nhbin+1); + values.resize(16); + adjusted_values.resize(16); + counts.resize(16); + bins.resize(17); + } + std::vector weight; + std::vector sum0, sum1, sum2; + std::vector values, adjusted_values, counts; + std::vector L; + std::vector bins; + int n_per_row; +}; + +inline float get_mse_d(int imin, int imax, const double* sum0, const double* sum1, const double* sum2) { + double n = sum0[imax] - 
sum0[imin]; + if (!n) return 0.f; + double s = sum1[imax] - sum1[imin]; + return sum2[imax] - sum2[imin] - s*s/n; +} + +inline int divide_bin_d(int n, const double* sum0, const double* sum1, const double* sum2, float& new_mse) { + float best_mse = get_mse_d(0, n, sum0, sum1, sum2); + int best_i = -1; + for (int i = 1; i < n; ++i) { + float mse = get_mse_d(0, i, sum0, sum1, sum2) + get_mse_d(i, n, sum0, sum1, sum2); + if (mse < best_mse) { + best_mse = mse; + best_i = i; + } + } + new_mse = best_mse; + return best_i; +} + +int make_the_bins_d(int nbin, int nhave, const double * sum0, const double * sum1, const double * sum2, int * bins) { + float tmp_mse; + while (nhave < nbin + 1) { + float best_delta = 0; + int best_bin = -1, div = -1; + for (int bin=0; bin 1) { + int bin_div = divide_bin_d(n, sum0 + bins[bin], sum1 + bins[bin], sum2 + bins[bin], tmp_mse); + if (bin_div >= 0 && tmp_mse < cur_mse) { + float delta = cur_mse - tmp_mse; + if (delta > best_delta) { + best_delta = delta; best_bin = bin; + div = bins[bin] + bin_div; + } + } + } + } + if (best_bin < 0) { + //printf("Oops: failed to find bin\n"); + //printf("nbin = %d, nh = %d, nhave = %d\n",nbin,nh,nhave); + //for (int bin=0; bin best_bin; --i) bins[i+1] = bins[i]; + bins[best_bin+1] = div; + ++nhave; + } + for (int it = 0; it < 10; ++it) { + int nchanged = 0; + for (int bin=0; bin= 0) { + if (bins[bin] + bin_div != bins[bin+1]) { + bins[bin+1] = bins[bin] + bin_div; + ++nchanged; + } + } + } + for (int bin=1; bin= 0) { + if (bins[bin] + bin_div != bins[bin+1]) { + bins[bin+1] = bins[bin] + bin_div; + ++nchanged; + } + } + } + if (nchanged == 0) break; + } + return nhave; +} + +int make_bins_d(int nbin, int nh, const double * sum0, const double * sum1, const double * sum2, int * bins) { + int nhave = 0; + bins[nhave++] = 0; + bins[nhave++] = nh; + return make_the_bins_d(nbin, nhave, sum0, sum1, sum2, bins); +} + +inline int find_index(int n, const int * bins, int l) { + if (l <= bins[0]) return 0; + if (l 
>= bins[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (l < bins[mav]) mu = mav; else ml = mav; + } + return mu-1; +} + +void quantize_row_iq4_knn_impl(const float * x, char * qrow, const float * imatrix, WorkIQ4NN& work) { + GGML_UNUSED(qrow); + GGML_UNUSED(imatrix); + GGML_UNUSED(work); + + int n_per_row = work.n_per_row; + + float * dptr = (float *)qrow; + *dptr = 0; + int8_t * int_values = (int8_t *)(dptr + 1); + uint8_t * qs = (uint8_t *)(int_values + 16); + memset(int_values, 0, 16); + memset(qs, 0, work.n_per_row/2); + + float sigma2 = 0; + float xmin = x[0], xmax = x[0]; + for (int j = 0; j < work.n_per_row; ++j) { + sigma2 += x[j]*x[j]; + xmin = std::min(xmin, x[j]); + xmax = std::max(xmax, x[j]); + } + if (xmax == xmin) { // constant row: guard must compare against xmin, else id = (nhbin-1)/(xmax - xmin) below divides by zero + *dptr = xmin; + int_values[0] = 1; + return; + } + sigma2 *= 2.f/work.n_per_row; + auto weight = work.weight.data(); + if (imatrix) { + for (int j = 0; j < n_per_row; ++j) weight[j] = imatrix[j] * sqrtf(sigma2 + x[j]*x[j]); + } else { + for (int j = 0; j < n_per_row; ++j) weight[j] = 0.25f*sigma2 + x[j]*x[j]; + } + + auto sum0 = work.sum0.data(); + auto sum1 = work.sum1.data(); + auto sum2 = work.sum2.data(); + std::memset(sum0, 0, (work.nhbin+1)*sizeof(double)); + std::memset(sum1, 0, (work.nhbin+1)*sizeof(double)); + std::memset(sum2, 0, (work.nhbin+1)*sizeof(double)); + + float id = (work.nhbin - 1)/(xmax - xmin); + for (int j = 0; j < n_per_row; ++j) { + int l = nearest_int(id*(x[j] - xmin)); + work.L[j] = l; + double w = double(weight[j]); + double xv = double(x[j]); + sum0[l+1] += w; + sum1[l+1] += w*xv; + sum2[l+1] += w*xv*xv; + } + + for (int j = 0; j < work.nhbin; ++j) { + sum0[j+1] += sum0[j]; + sum1[j+1] += sum1[j]; + sum2[j+1] += sum2[j]; + } + + auto bins = work.bins.data(); + int nbin = make_bins_d(16, work.nhbin, sum0, sum1, sum2, bins); + + //printf("Got %d bins:\n", nbin); + //for (int i = 0; i < nbin; ++i) printf("%2d %3d\n", i, bins[i]); + + GGML_ASSERT(nbin <= 17); + 
std::memset(work.values.data(), 0, 16*sizeof(float)); + memset(work.adjusted_values.data(), 0, 16*sizeof(float)); + memset(work.counts.data(), 0, 16*sizeof(float)); + for (int j = 0; j < n_per_row; ++j) { + int l = find_index(nbin, bins, work.L[j]); + l = std::min(15, l); + work.L[j] = l; + float w = weight[j]; + work.adjusted_values[l] += w*x[j]; + work.counts[l] += w; + } + + const int ntry = 11; + + int nchanged = 0; + for (int itry = 0; itry < ntry; ++itry) { + //printf("======== Iteration %d\n", itry); + for (int i = 0; i < 16; ++i) { + if (work.counts[i] > 0) work.values[i] = work.adjusted_values[i]/work.counts[i]; + //printf("%2d %g %g %g\n", i, work.values[i], work.adjusted_values[i], work.counts[i]); + work.adjusted_values[i] = work.counts[i] = 0; + } + nchanged = 0; + for (int j = 0; j < n_per_row; ++j) { + int idx0 = work.L[j]; + int idx = idx0; + float diff = fabsf(x[j] - work.values[idx]); + if (idx0 > 0) { + float this_diff = fabsf(x[j] - work.values[idx0-1]); + if (this_diff < diff) { + diff = this_diff; idx = idx0-1; + } + } + if (idx0 < nbin-1) { + float this_diff = fabsf(x[j] - work.values[idx0+1]); + if (this_diff < diff) { + diff = this_diff; idx = idx0+1; + } + } + if (idx != idx0) { + ++nchanged; + work.L[j] = idx; + } + float w = weight[j]; + work.adjusted_values[idx] += w*x[j]; + work.counts[idx] += w; + } + //printf("nchanged = %d\n", nchanged); + if (nchanged == 0) break; + } + if (nchanged > 0) { + for (int i = 0; i < 16; ++i) { + if (work.counts[i] > 0) work.values[i] = work.adjusted_values[i]/work.counts[i]; + } + } + + float max = 0, amax = 0; + for (int i = 0; i < 16; ++i) { + float ax = fabsf(work.values[i]); + if (ax > amax) { + amax = ax; max = work.values[i]; + } + } + + float d = -max/128; + //printf("amax = %g, max = %g d = %g\n", amax, max, d); + *dptr = d; + id = d ? 
1/d : 0.f; + for (int i = 0; i < 16; ++i) { + int l = nearest_int(id*work.values[i]); + int_values[i] = std::max(-128, std::min(127, l)); + //printf("int_values[%d] = %d\n", i, int_values[i]); + } + + int nb32 = n_per_row/32; + auto L = work.L.data(); + for (int ib = 0; ib < nb32; ++ib) { + for (int j = 0; j < 16; ++j) { + qs[j] = L[j] | (L[j+16] << 4); + } + qs += 16; + L += 32; + } +} +} + +size_t quantize_iq4_knn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ4_KNN, n_per_row); + auto qrow = (char *)dst; + WorkIQ4NN work(n_per_row); + for (int row = 0; row < nrows; ++row) { + quantize_row_iq4_knn_impl(src, qrow, imatrix, work); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void quantize_row_iq4_knn_ref(const float * x, block_iq4_knn * y, int64_t k) { + quantize_iq4_knn(x, y, 1, k, nullptr); +} + +void quantize_row_iq4_knn(const float * x, void * y, int64_t k) { + quantize_iq4_knn(x, (block_iq4_knn *)y, 1, k, nullptr); +} + +void dequantize_row_iq4_knn(const block_iq4_knn * x, float * y, int64_t k) { + const float * dptr = (const float *)x; + const float d = *dptr; + const int8_t * values = (const int8_t *)(dptr + 1); + const uint8_t * qs = (const uint8_t *)(values + 16); + int nblock = k/32; + for (int ib = 0; ib < nblock; ++ib) { + for (int j = 0; j < 16; ++j) { + y[j+ 0] = d * values[qs[j] & 0xf]; + y[j+16] = d * values[qs[j] >> 4]; + } + y += 32; + qs += 16; + } +} + +void vec_dot_iq4_knn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KNN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h 
index e0dde0d8..2e685b61 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -67,6 +67,12 @@ size_t quantize_iq4_kss(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_iq4_kss(const block_iq4_kss * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_kss_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq4_knn_ref(const float * GGML_RESTRICT x, block_iq4_knn * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_knn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_knn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_knn(const block_iq4_knn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_knn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/include/llama.h b/include/llama.h index b2906693..577f965b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -181,6 +181,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_KNN = 149, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index cae91619..919f80ce 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3796,6 +3796,7 @@ 
struct llama_model_loader { case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break; + case GGML_TYPE_IQ4_KNN: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KNN; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; @@ -4500,6 +4501,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: return "IQ4_KSS - 4.0 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_KNN: return "IQ4_KNN - 4.0 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_KL: return "IQ3_KL - 4 bpw"; @@ -15654,7 +15656,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = !qs.has_output ? 
GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; } else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_output) { + ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KNN) && !qs.has_output) { new_type = GGML_TYPE_IQ5_K; } else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_IQ6_K) { @@ -15747,7 +15749,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && qs.model.hparams.n_gqa() >= 2) { + ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS|| + ftype == LLAMA_FTYPE_MOSTLY_IQ4_KNN) && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K && qs.model.hparams.n_gqa() >= 2) { @@ -15831,7 +15834,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } } else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_imatrix) { + ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KNN) && !qs.has_imatrix) { new_type = GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; @@ -15923,7 +15926,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == 
GGML_TYPE_IQ2_TN || new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN || new_type == GGML_TYPE_IQ4_KS || - new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS) { + new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_IQ4_KNN) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -15956,6 +15959,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KNN: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_IQ4_K: @@ -16070,6 +16074,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: default_type = GGML_TYPE_IQ4_KSS; break; + case LLAMA_FTYPE_MOSTLY_IQ4_KNN: default_type = GGML_TYPE_IQ4_KNN; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_KL: default_type = GGML_TYPE_IQ3_K; break;