From 89a94f978df87520cf16729a0cf043e172f1fe07 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 8 Jul 2025 20:23:42 +0300 Subject: [PATCH] cuda: slightly faster MMQ for iq3_k, iq3_k_r4 --- ggml/src/ggml-cuda/iqk_cuda_common.h | 15 +++++ ggml/src/ggml-cuda/iqk_mmvq.cu | 15 ----- ggml/src/ggml-cuda/mmq.cuh | 60 +++++++------------ .../mmq-instance-iq3_k_r4.cu | 26 ++++---- 4 files changed, 50 insertions(+), 66 deletions(-) diff --git a/ggml/src/ggml-cuda/iqk_cuda_common.h b/ggml/src/ggml-cuda/iqk_cuda_common.h index 95d9b40e..9ed073d3 100644 --- a/ggml/src/ggml-cuda/iqk_cuda_common.h +++ b/ggml/src/ggml-cuda/iqk_cuda_common.h @@ -73,3 +73,18 @@ __device__ __forceinline__ int int_from_table_4(const uint32_t idx, const int * return values[ggml_cuda_dp4a(idx, 0x40100401, 0)]; } +static const __device__ uint16_t iq3k_table[128] = { + 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f, + 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f, + 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f, + 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f, + 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33, + 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33, + 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133, + 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333, +}; + +__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) { + return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16); +} + diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 54d03f78..d897063f 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -950,21 +950,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1( #define VDR_IQ3_K_Q8_1_MMVQ 4 #define VDR_IQ3_K_Q8_1_MMQ 4 -static const __device__ uint16_t iq3k_table[128] = { - 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f, - 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f, - 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f, - 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f, - 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33, - 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33, - 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133, - 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333, -}; - -__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) { - return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16); -} - __device__ __forceinline__ void vec_dot_iq3_k_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) { const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 8a87e3e8..f2d0d735 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2623,8 +2623,6 @@ template static __device__ __forceinlin constexpr int qstep = 8; const int kqsx = threadIdx.x % qstep; - auto values = iq3nl_values; - uint32_t aux32[4]; const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll @@ -2646,57 +2644,43 @@ template static __device__ __forceinlin for (int l = 0; l < qstep/4; ++l) { const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); - aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404) | (((extra << 3) & 8) * 0x01010101); - aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404) | (((extra << 1) & 8) * 0x01010101); - aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404) | (((extra >> 1) & 8) * 0x01010101); - aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404) | (((extra >> 3) & 8) * 0x01010101); + aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404); + aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404); + aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404); + aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404); + + const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40)); + const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 4) & 0x40)); + const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 2) & 0x40)); + const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 0) & 0x40)); + extra >>= 8; qh >>= 4; - const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]); - const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]); - const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]); - const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]); - #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = *(const int *)&val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = *(const int *)&val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = *(const int *)&val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = *(const int *)&val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = val1; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = val2; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = val3; #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3; #endif // INT8_MMA_AVAILABLE } uint16_t sh = bxi->scales_h >> 2*kqsx; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1)); - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1)); + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1)); #else - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1)); - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1)); + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1)); #endif // INT8_MMA_AVAILABLE } } -static const __device__ uint16_t iq3k_table[128] = { - 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f, - 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f, - 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f, - 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f, - 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33, - 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33, - 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133, - 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333, -}; - -__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) { - return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16); -} - template static __device__ __forceinline__ void load_tiles_iq3_ks( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu index 7b6096a3..a588969f 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu @@ -37,7 +37,7 @@ template static __device__ __forceinlin #pragma unroll for (int l = 0; l < 2; ++l) { - auto values_l = iq3nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 3); + auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6); const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404); @@ -45,21 +45,21 @@ template static __device__ __forceinlin aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404); aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404); - const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]); - const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]); - const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]); - const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]); + int val0 = int_from_table_2(aux8+ 0, values_l); + int val1 = int_from_table_2(aux8+ 4, values_l); + int val2 = int_from_table_2(aux8+ 8, values_l); + int val3 = int_from_table_2(aux8+12, values_l); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = val1; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = val2; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = val3; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val1; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val2; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = val0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = val1; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = val2; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = val3; #endif // INT8_MMA_AVAILABLE qh >>= 4;