From f0efb1f52a4556e97142bbda07d377ac359c16c3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 26 May 2025 18:35:03 +0300 Subject: [PATCH] CUDA: iq4_ks_r4 GEMV and GEMM --- ggml/src/ggml-cuda.cu | 1 + ggml/src/ggml-cuda/convert.cu | 64 +++++++++++++++++++++++++++++++-- ggml/src/ggml-cuda/iqk_mmvq.cu | 45 +++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 5 +++ ggml/src/ggml-cuda/mmvq.cu | 4 +++ 5 files changed, 116 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 6331bc17..c9b5a4f4 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3473,6 +3473,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_K_R4: return true; default: diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 2ccca01b..8862f7d3 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -801,6 +801,50 @@ static __global__ void dequantize_block_iq4_k_r4(const void * __restrict__ vx, d } } +template +static __global__ void dequantize_block_iq4_ks_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + + int64_t nblock = n_per_row/256; + int64_t row = ii/nblock; + int64_t row4 = row/4; + int64_t ir = row%4; + int64_t ibl = ii%nblock; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + + const float * dptr = (const float *)((const char *)vx + 4*row4*row_size); + const float d = dptr[ir]; + const block_iq4_ks_r4 * x = (const block_iq4_ks_r4 *)(dptr + 4); + dst_t * y = yy + 256*ii + 32*ib; + + float dl = d * ((x[ibl].scales[4*ib + ir] & 254) - 127); + auto values = iq4k_values + ((x[ibl].scales[4*ib + ir] & 1) << 4); + auto qs = x[ibl].qs + 64*ib + 4*ir; + if constexpr (std::is_same_v) { + y[il+ 0] = __float2bfloat16(dl * values[qs[il+ 0] & 0xf]); + y[il+ 8] = __float2bfloat16(dl * values[qs[il+ 0] >> 4]); + y[il+16] = __float2bfloat16(dl * values[qs[il+16] & 0xf]); + y[il+24] = __float2bfloat16(dl * values[qs[il+16] >> 4]); + y[il+ 4] = __float2bfloat16(dl * values[qs[il+32] & 0xf]); + y[il+12] = __float2bfloat16(dl * values[qs[il+32] >> 4]); + y[il+20] = __float2bfloat16(dl * values[qs[il+48] & 0xf]); + y[il+28] = __float2bfloat16(dl * values[qs[il+48] >> 4]); + } else { + y[il+ 0] = dl * values[qs[il+ 0] & 0xf]; + y[il+ 4] = dl * values[qs[il+32] & 0xf]; + y[il+ 8] = dl * values[qs[il+ 0] >> 4]; + y[il+12] = dl * values[qs[il+32] >> 4]; + y[il+16] = dl * values[qs[il+16] & 0xf]; + y[il+20] = dl * values[qs[il+48] & 0xf]; + y[il+24] = dl * values[qs[il+16] >> 4]; + y[il+28] = dl * values[qs[il+48] >> 4]; + } +} + template static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -1395,7 +1439,7 @@ static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t template static void dequantize_row_iq3_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; - const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ3_K, n_per_row); const int nb = (k + QK_K - 1) / QK_K; dequantize_block_iq3_k_r4<<>>(vx, y, n_per_row, row_size); } @@ -1403,7 +1447,7 @@ static void dequantize_row_iq3_k_r4_cuda(const void * vx, dst_t * y, const int64 template static void dequantize_row_iq2_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; - const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_K, n_per_row); const int nb = (k + QK_K - 1) / QK_K; dequantize_block_iq2_k_r4<<>>(vx, y, n_per_row, row_size); } @@ -1423,6 +1467,14 @@ static void dequantize_row_iq4_k_r4_cuda(const void * vx, dst_t * y, const int64 dequantize_block_iq4_k_r4<<>>(vx, y, n_per_row, row_size); } +template +static void dequantize_row_iq4_ks_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row); + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_ks_r4<<>>(vx, y, n_per_row, row_size); +} + template static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1433,7 +1485,7 @@ static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t template static void dequantize_row_iq5_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; - const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ5_K, n_per_row); const int nb = (k + QK_K - 1) / QK_K; dequantize_block_iq5_k_r4<<>>(vx, y, n_per_row, row_size); } @@ -1540,6 +1592,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { return dequantize_row_iq3_k_r4_cuda; case GGML_TYPE_IQ4_K_R4: return dequantize_row_iq4_k_r4_cuda; + case GGML_TYPE_IQ4_KS_R4: + return dequantize_row_iq4_ks_r4_cuda; case GGML_TYPE_IQ5_K_R4: return dequantize_row_iq5_k_r4_cuda; default: @@ -1630,6 +1684,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_k_r4_cuda; case GGML_TYPE_IQ4_K_R4: return dequantize_row_iq4_k_r4_cuda; + case GGML_TYPE_IQ4_KS_R4: + return dequantize_row_iq4_ks_r4_cuda; case GGML_TYPE_IQ5_K_R4: return dequantize_row_iq5_k_r4_cuda; default: @@ -1717,6 +1773,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_k_r4_cuda; case GGML_TYPE_IQ4_K_R4: return dequantize_row_iq4_k_r4_cuda; + case GGML_TYPE_IQ4_KS_R4: + return dequantize_row_iq4_ks_r4_cuda; case GGML_TYPE_IQ5_K_R4: return dequantize_row_iq5_k_r4_cuda; default: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 20bacd97..6b8c1a25 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -36,6 +36,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI5_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + // Reminder: // constexpr int qk = ggml_cuda_type_traits::qk; @@ -309,6 +316,36 @@ __device__ __forceinline__ void vec_dot_iq4_k_r4_q8_1( } } +__device__ __forceinline__ void vec_dot_iq4_ks_r4_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + const float * dptr = (const float *)vbq; + const block_iq4_ks_r4 * bq4 = (const block_iq4_ks_r4 *)(dptr + 4) + kbx; + + // iqs is 0...28 in steps of 2 + const int ib16 = iqs/2; + const float d8 = __low2float(bq8_1[ib16/2].ds); + const int32_t * q8 = (const int *)bq8_1[ib16/2].qs + 4*(ib16%2); + + int ib32 = ib16/2; + int is = ib16%2; + const uint32_t * scales32 = (const uint32_t *)bq4->scales; + int scales = __vsub4(scales32[ib32] & 0xfefefefe, 0x7f7f7f7f); + const int8_t * s8 = (const int8_t *)&scales; + int2 val; + const int * q4 = (const int *)bq4->qs + 16*ib32; + for (int i = 0; i < 4; ++i) { + auto values = iq4k_values + ((bq4->scales[4*ib32+i] & 1) << 4); + int sumi = 0; + val = get_int_from_table_16(q4[i+4*is+0], values); + sumi = ggml_cuda_dp4a(val.x, q8[0], ggml_cuda_dp4a(val.y, q8[2], sumi)); + val = get_int_from_table_16(q4[i+4*is+8], values); + sumi = ggml_cuda_dp4a(val.x, q8[1], ggml_cuda_dp4a(val.y, q8[3], sumi)); + const float d = dptr[i] * d8; + result[i] += d * sumi * s8[i]; + } +} + #define VDR_IQ4_KS_Q8_1_MMVQ 4 #define VDR_IQ4_KS_Q8_1_MMQ 4 @@ -1013,6 +1050,14 @@ void mul_mat_vec_iq4_k_r4_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq4_ks_r4_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq5_k_r4_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 228c513b..ae56de0f 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -80,3 +80,8 @@ void mul_mat_vec_iq5_k_r4_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq4_ks_r4_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d7bed266..705e1be0 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -551,6 +551,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ4_K_R4: mul_mat_vec_iq4_k_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ4_KS_R4: + mul_mat_vec_iq4_ks_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; case GGML_TYPE_IQ5_K_R4: mul_mat_vec_iq5_k_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; @@ -670,6 +673,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_K_R4: return true; default: