From 1aceccae357b57c71ba0dc0f5a5d749c1ab71985 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 14 May 2025 17:06:27 +0300 Subject: [PATCH] MMQ for iq3_k --- ggml/src/ggml-cuda/mmq.cu | 4 + ggml/src/ggml-cuda/mmq.cuh | 86 +++++++++++++++++++ .../template-instances/mmq-instance-iq3_k.cu | 5 ++ 3 files changed, 95 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k.cu diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 68f4d73d..5c6f36b2 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -97,6 +97,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ2_K: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ3_K: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ4_K: mul_mat_q_case(ctx, args, stream); break; @@ -145,6 +148,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 54de45c7..df212b97 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -83,6 +83,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: @@ -186,6 +187,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; @@ -228,6 +230,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; @@ -2441,6 +2444,80 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq3_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq3nl_values; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_k * bxi = (const block_iq3_k *)(x + i*stride) + kbx0; + + const float d = bxi->d; + + uint16_t extra = bxi->extra >> (kqsx/4); + int qh = get_int_b2(bxi->qh, kqsx); + + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + + const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404) | (((extra << 3) & 8) * 0x01010101); + aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404) | (((extra << 1) & 8) * 0x01010101); + aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404) | (((extra >> 1) & 8) * 0x01010101); + aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404) | (((extra >> 3) & 8) * 0x01010101); + extra >>= 8; + qh >>= 4; + + const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]); + const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]); + const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]); + const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = *(const int *)&val1; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = *(const int *)&val2; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = *(const int *)&val3; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3; +#endif // INT8_MMA_AVAILABLE + } + + uint16_t sh = bxi->scales_h >> 2*kqsx; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1)); +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1)); +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq4_ks( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -2933,6 +3010,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; @@ -3403,6 +3488,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k.cu new file mode 100644 index 00000000..7edf778c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_K);