From 693e9d1a1618ff01391c24e588c0941b10958be5 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 21 Aug 2025 18:00:17 +0300 Subject: [PATCH] Use bperm trick for iq2_k_r4 gemm -> ~3% gain --- ggml/src/ggml-cuda/iqk_cuda_common.h | 9 +++++ ggml/src/ggml-cuda/mmq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 9 ----- .../mmq-instance-iq2_k_r4.cu | 34 ++++++++++++++++--- 4 files changed, 40 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/iqk_cuda_common.h b/ggml/src/ggml-cuda/iqk_cuda_common.h index fbe655c4..3548900e 100644 --- a/ggml/src/ggml-cuda/iqk_cuda_common.h +++ b/ggml/src/ggml-cuda/iqk_cuda_common.h @@ -127,3 +127,12 @@ __device__ __forceinline__ int int_from_table_x(const uint8_t * a8, const uint16 return values[a8[0] | (a8[1] << 4)] | (values[a8[2] | (a8[3] << 4)] << 16); } +#ifdef __CUDA_ARCH__ +static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) { + const uint32_t * values32 = (const uint32_t *)values; + uint32_t v1 = __byte_perm(values32[0], values32[1], q4); + uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16); + return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531)); +} +#endif + diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index bebc7c87..82514c8d 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -187,7 +187,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { break; case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: - mmq_supported = ne11 < 2048; + mmq_supported = ne11 <= 3072; break; case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 613463dc..648e8d71 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2554,15 +2554,6 @@ template static __device__ __forceinlin } } -#ifdef __CUDA_ARCH__ -static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) { - const uint32_t * values32 = (const uint32_t *)values; - uint32_t v1 = __byte_perm(values32[0], values32[1], q4); - uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16); - return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531)); -} -#endif - template static __device__ __forceinline__ void load_tiles_iq2_ks( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu index e40d55a0..f0f4a8f0 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu @@ -14,8 +14,6 @@ template static __device__ __forceinlin float * x_df = (float *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE - const int * all_values = (const int *)iq2k_table; - const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 #pragma unroll @@ -32,10 +30,37 @@ template static __device__ __forceinlin const float d = __half2float(bxi->d[ir]); - #pragma unroll +#ifdef __CUDA_ARCH__ + #pragma unroll for (int l = 0; l < 2; ++l) { - auto values_l = all_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8); + uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404; + extra = extra | (extra << 4); + + const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | extra; + uint32_t val2 = ((ql >> 2) & 0x33333333) | extra; + int2 v1 = get_int_from_table_8(val1, iq2nl_values); + int2 v2 = get_int_from_table_8(val2, iq2nl_values); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#else + #pragma unroll + for (int l = 0; l < 2; ++l) { + + auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8); const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); @@ -51,6 +76,7 @@ template static __device__ __forceinlin x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l); #endif // INT8_MMA_AVAILABLE } +#endif // __CUDA_ARCH__ int is = 8*kqsx + ir; float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8);