diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 85ceabfd..fd657373 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -76,6 +76,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",}, { "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",}, { "IQ2_KT", LLAMA_FTYPE_MOSTLY_IQ2_KT, " 2.125 bpw trellis quantization", }, + { "IQ3_KS", LLAMA_FTYPE_MOSTLY_IQ3_KS, " 3.19 bpw non-linear quantization", }, { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, { "IQ3_K_R4", LLAMA_FTYPE_MOSTLY_IQ3_K_R4, "IQ3_K repacked", }, { "IQ3_KL", LLAMA_FTYPE_MOSTLY_IQ3_KL, " 4 bpw non-linear quantization mix",}, diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ae7a55c6..e0035c7a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3465,6 +3465,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_KS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e2690cb3..973af2b8 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -599,6 +599,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b40079a3..61c09481 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -1333,6 +1333,51 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_ } } +template +static __global__ void dequantize_block_iq3_ks(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const ggml_half *)cx; + const block_iq3_ks * x = (const block_iq3_ks *)(cx + sizeof(ggml_half)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t is = tid/16; + const int64_t il = tid%16; + dst_t * y = yy + ii*QK_K + 128*is + 2*il; + const uint8_t * qs = x[i].qs + 32*is + 2*il; + const uint8_t * qh = x[i].qh + 2*il; + uint16_t extra = x[i].extra >> 4*is; + const float d0 = scale * (int(((x[i].scales[0] >> 4*is) & 0xf) | ((extra << 4) & 0x10)) - 16); + const float d1 = scale * (int(((x[i].scales[1] >> 4*is) & 0xf) | ((extra << 3) & 0x10)) - 16); + const float d2 = scale * (int(((x[i].scales[2] >> 4*is) & 0xf) | ((extra << 2) & 0x10)) - 16); + const float d3 = scale * (int(((x[i].scales[3] >> 4*is) & 0xf) | ((extra << 1) & 0x10)) - 16); + extra >>= 8; + const int8_t * values0 = iq3nl_values + ((extra & 1) << 3); + const int8_t * values1 = iq3nl_values + ((extra & 2) << 2); + const int8_t * values2 = iq3nl_values + ((extra & 4) << 1); + const int8_t * values3 = iq3nl_values + ((extra & 8) << 0); + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + uint8_t h = qh[j] >> 4*is; + y[j+ 0] = __float2bfloat16(d0 * values0[((qs[j] >> 0) & 3) | ((h << 2) & 4)]); + y[j+32] = __float2bfloat16(d1 * values1[((qs[j] >> 2) & 3) | ((h << 1) & 4)]); + y[j+64] = __float2bfloat16(d2 * values2[((qs[j] >> 4) & 3) | ((h >> 0) & 4)]); + y[j+96] = __float2bfloat16(d3 * values3[((qs[j] >> 6) & 3) | ((h >> 1) & 4)]); + } + } else { + for (int j = 0; j < 2; ++j) { + uint8_t h = qh[j] >> 4*is; + y[j+ 0] = d0 * values0[((qs[j] >> 0) & 3) | ((h << 2) & 4)]; + y[j+32] = d1 * values1[((qs[j] >> 2) & 3) | ((h << 1) & 4)]; + y[j+64] = d2 * values2[((qs[j] >> 4) & 3) | ((h >> 0) & 4)]; + y[j+96] = d3 * values3[((qs[j] >> 6) & 3) | ((h >> 1) & 4)]; + } + } +} + template static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1573,6 +1618,14 @@ static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq3_k<<>>(vx, y); } +template +static void dequantize_row_iq3_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ3_KS, n_per_row); + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq3_ks<<>>(vx, y, n_per_row, row_size); +} + template static void dequantize_row_iq3_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1719,6 +1772,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_K: return dequantize_row_iq3_k_cuda; + case GGML_TYPE_IQ3_KS: + return dequantize_row_iq3_ks_cuda; case GGML_TYPE_IQ4_KSS: return dequantize_row_iq4_kss_cuda; case GGML_TYPE_IQ4_KS: @@ -1821,6 +1876,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_K: return dequantize_row_iq3_k_cuda; + case GGML_TYPE_IQ3_KS: + return dequantize_row_iq3_ks_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ5_K: @@ -1916,6 +1973,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_K: return dequantize_row_iq3_k_cuda; + case GGML_TYPE_IQ3_KS: + return dequantize_row_iq3_ks_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ5_K: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index e69bcc4a..b00dc61d 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -405,6 +405,29 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +// TODO +__device__ __forceinline__ void vec_dot_iq3_ks_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + float scale = *(const float *)vbq; + const block_iq4_ks * bq4 = (const block_iq4_ks *)((const char *)vbq + sizeof(float)) + kbx; + const uint8_t * all_values = (const uint8_t *)iq4k_values; + + // iqs is 0...28 + const int ib32 = iqs/4; // Why iqs/4 ? + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; + const float dl = scale * ((bq4->scales[ib32] & 254) - 127); + int v1, v2; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2); + sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi); + sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { @@ -1302,6 +1325,14 @@ void mul_mat_vec_iq4_ks_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq3_ks_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq4_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index e7c6e1d2..c2416b1e 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -16,6 +16,11 @@ void mul_mat_vec_iq3_k_q8_1_cuda( const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); +void mul_mat_vec_iq3_ks_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + void mul_mat_vec_iq4_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/include/llama.h b/include/llama.h index 53adeb94..6c8bff95 100644 --- a/include/llama.h +++ b/include/llama.h @@ -201,6 +201,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_KT = 151, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_KT = 152, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KT = 153, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_KS = 154, // except 1d tensors // LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp index 1ea2084d..8823ab5b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4374,6 +4374,7 @@ struct llama_model_loader { case GGML_TYPE_IQ5_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; case GGML_TYPE_IQ2_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_K_R4;break; + case GGML_TYPE_IQ3_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KS; break; case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; case GGML_TYPE_IQ3_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_K_R4;break; case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; @@ -5115,6 +5116,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ5_KS: return "IQ5_KS - 5.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K_R4: return "IQ2_K_R4 - 2.375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_KS: return "IQ3_KS - 3.1875 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_K_R4: return "IQ3_K_R4 - 3.4325 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_KL: return "IQ3_KL - 4 bpw"; @@ -18642,7 +18644,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || - new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4|| + new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4|| new_type == GGML_TYPE_IQ3_KS || new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT || new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4) { if (nx % QK_K != 0) { @@ -18676,6 +18678,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { case GGML_TYPE_Q3_K_R4: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_KS: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_KSS: @@ -18810,7 +18813,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4 || @@ -19003,6 +19006,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_K && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KS && qs.model.hparams.n_gqa() >= 2) { + new_type = GGML_TYPE_IQ4_KS; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ4_K_R4; } @@ -19059,6 +19065,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q2_K_R4 || new_type == GGML_TYPE_IQ3_XXS_R4) new_type = GGML_TYPE_IQ3_K_R4; else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S) new_type = GGML_TYPE_Q4_K; else if (new_type == GGML_TYPE_IQ3_K) new_type = GGML_TYPE_IQ4_K; + else if (new_type == GGML_TYPE_IQ3_KS) new_type = GGML_TYPE_IQ4_KS; else if (new_type == GGML_TYPE_IQ3_S_R4) new_type = GGML_TYPE_Q4_K_R4; else if (new_type == GGML_TYPE_Q3_K_R4) new_type = GGML_TYPE_Q4_K_R4; else if (new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_IQ4_XS) new_type = GGML_TYPE_Q5_K; @@ -19185,7 +19192,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ5_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ5_KS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KL || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 || - ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT || + ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KS || ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4|| ftype == LLAMA_FTYPE_MOSTLY_IQ4_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S_R4) { new_type = GGML_TYPE_Q5_K; // should the IQ_K quants be applied here as the new type for the IQ_K ftypes ? @@ -19427,6 +19434,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ5_KS: default_type = GGML_TYPE_IQ5_KS; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; case LLAMA_FTYPE_MOSTLY_IQ2_K_R4:default_type = GGML_TYPE_IQ2_K_R4;break; + case LLAMA_FTYPE_MOSTLY_IQ3_KS: default_type = GGML_TYPE_IQ3_KS; break; case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_K_R4:default_type = GGML_TYPE_IQ3_K_R4;break; case LLAMA_FTYPE_MOSTLY_IQ3_KL: default_type = GGML_TYPE_IQ3_K; break;