From 2078f269ef3c44e69c69194c6e04d70bbb5ce7cd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 21 Aug 2025 13:21:33 +0300 Subject: [PATCH] Use bperm trick for iq3_ks - 5% PP performance gain --- ggml/src/ggml-cuda/mmq.cuh | 39 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index b877545c..455e20ee 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2769,8 +2769,6 @@ template static __device__ __forceinlin constexpr int qstep = 8; const int kqsx = threadIdx.x % qstep; - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; @@ -2786,33 +2784,32 @@ template static __device__ __forceinlin uint16_t extra = bxi->extra >> 8; int qh = get_int_b2(bxi->qh, kqsx); + uint32_t extra32 = extra * 0x01010101; + #pragma unroll for (int l = 0; l < qstep/4; ++l) { const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); - aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404); - aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404); - aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404); - aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((qh << 2) & 0x04040404) | ((extra32 << 3) & 0x08080808) + | ((qh << 4) & 0x40404040) | ((extra32 << 5) & 0x80808080); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((qh << 1) & 0x04040404) | ((extra32 << 2) & 0x08080808) + | ((qh << 3) & 0x40404040) | ((extra32 << 4) & 0x80808080); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); - const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40)); - const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 5) & 0x40)); - const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 4) & 0x40)); - const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 3) & 0x40)); - - extra >>= 4; - qh >>= 4; + extra32 >>= 4; + qh >>= 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = v2.y; #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; #endif // INT8_MMA_AVAILABLE }