From 32ff1f956f391483327b2eacae9adbf9dcc73b80 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 16 Jun 2025 16:57:16 +0300 Subject: [PATCH] iq3_kt: use integer trellis + CUDA dequantize and MMVQ --- ggml/src/ggml-cuda/common.cuh | 21 +++++++++----- ggml/src/ggml-cuda/convert.cu | 4 +-- ggml/src/ggml-cuda/iqk_mmvq.cu | 50 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 15 ++++++---- ggml/src/ggml-cuda/mmvq.cu | 12 +++++--- ggml/src/iqk/iqk_quantize.cpp | 10 +++++-- 6 files changed, 91 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 291378f4..a0cdab28 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -578,6 +578,20 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; @@ -648,13 +662,6 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; -template<> -struct ggml_cuda_type_traits { - static constexpr int qk = QK_K; - static constexpr int qr = QR4_XS; - static constexpr int qi = QI4_XS; -}; - ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 7fddcc89..b40079a3 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -394,10 +394,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f; + const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 1.01f; //1.015f; uint8_t mask = 1 << (ib/4); for (int j = 0; j < 8; ++j) { - y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f); + y[j] = dl * std::abs(trellis_next_int(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f); } } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index bec6a739..c19215de 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -504,6 +504,48 @@ __device__ __forceinline__ void vec_dot_iq2_kt_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +__device__ __forceinline__ void vec_dot_iq3_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq3_kt * bq3 = (const block_iq3_kt *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const int ls = (bq3->scales[ib32%4] >> 4*(ib32/4)) & 0xf; + const float dl = scale * ls * 1.015f; + auto ql = (const uint16_t *)bq3->ql; + uint32_t mask = 0x01010101 << ib32; + const uint32_t * qh = (const uint32_t *)bq3->qh; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + uint32_t val = ql[4*ib32+j] + 4096; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)); + v4 |= q << 8*k; + } + uint32_t signs = __vcmpne4(qh[2*j+0] & mask, 0); + v4 = __vsub4(v4 ^ signs, signs); + sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi); + v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)); + v4 |= q << 8*k; + } + signs = __vcmpne4(qh[2*j+1] & mask, 0); + v4 = __vsub4(v4 ^ signs, signs); + sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ4_KSS_Q8_1_MMVQ 4 #define VDR_IQ4_KSS_Q8_1_MMQ 4 @@ -1304,6 +1346,14 @@ void mul_mat_vec_iq2_kt_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq3_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index a77bef54..e7c6e1d2 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -101,12 +101,17 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda( const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); -void mul_mat_vec_iq4_kt_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); - void mul_mat_vec_iq2_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq3_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq4_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 76126d76..6412be30 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -527,12 +527,15 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; - case GGML_TYPE_IQ4_KT: - mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); - break; case GGML_TYPE_IQ2_KT: mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ3_KT: + mul_mat_vec_iq3_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_KT: + mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; @@ -693,8 +696,9 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: - case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: return true; default: return false; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 65cd1c3e..0384e49a 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7427,7 +7427,11 @@ public: for (int k = 0; k < kGroupSize; ++k) { x = ka*x; s = x & 0x3f3f3f3f; - result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + if constexpr (is_abs) { + result[k] = scale*std::abs(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } else { + result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } } } else { constexpr uint32_t ka = 89226354; @@ -8289,7 +8293,7 @@ void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx namespace { -using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>; +using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true, true>; const QuantizerIQ3KT& iq3kt_quantizer() { static std::mutex mutex; std::lock_guard lock(mutex); @@ -8500,7 +8504,7 @@ size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) { #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return; #endif using Q = QuantizerIQ3KT; constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;