From 8d91235b0e9b6f2ace11a0dbcfd513608314f32f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 21 Aug 2025 11:27:39 +0300 Subject: [PATCH] Use get_int_from_table_16 everywhere for 4-bit quants --- ggml/src/ggml-cuda/iqk_mmvq.cu | 21 ++------- ggml/src/ggml-cuda/mmq.cuh | 45 +++++++------------ .../mmq-instance-iq4_kss.cu | 16 +++---- 3 files changed, 26 insertions(+), 56 deletions(-) diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index c9534d6b..c596f00e 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -409,7 +409,6 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1( float scale = *(const float *)vbq; const block_iq4_ks * bq4 = (const block_iq4_ks *)((const char *)vbq + sizeof(float)) + kbx; - //const uint8_t * all_values = (const uint8_t *)iq4k_values; // iqs is 0...28 const int ib32 = iqs/4; // Why iqs/4 ? @@ -417,22 +416,11 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1( const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; const float dl = scale * ((bq4->scales[ib32] & 254) - 127); auto values = iq4k_values + ((bq4->scales[ib32] & 1) << 4); - //auto values = iq4k_table + ((bq4->scales[ib32] & 1) << 8); - //uint32_t aux32[2]; - //auto a8 = (const uint8_t *)aux32; - //int v1, v2; int sumi = 0; for (int j = 0; j < 4; ++j) { - //aux32[0] = (q4[j] >> 0) & 0x0f0f0f0f; - //aux32[1] = (q4[j] >> 4) & 0x0f0f0f0f; - //sumi = ggml_cuda_dp4a(int_from_table_x(a8+0, values), q8[j+0], sumi); - //sumi = ggml_cuda_dp4a(int_from_table_x(a8+4, values), q8[j+4], sumi); auto v = get_int_from_table_16(q4[j], values); sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi); sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi); - ////get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2); - ////sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi); - ////sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi); } *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } @@ -591,7 +579,6 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1( float scale = *(const float *)vbq; const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx; - const uint8_t * all_values = (const uint8_t *)iq4k_values; // iqs is 0...28 const int ib32 = iqs/4; // Why iqs/4 ? @@ -600,14 +587,14 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1( uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6); uint8_t ls = (s32 | (s32 >> 15)) & 0xff; const float dl = scale * ((ls & 254) - 127); - int v1, v2; + auto values = iq4k_values + ((ls & 1) << 4); int sumi = 0; for (int j = 0; j < 4; ++j) { uint32_t aux32 = q4[j] & 0xfffefffe; aux32 ^= (aux32 >> 1); - get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2); - sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi); - sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi); + auto v = get_int_from_table_16(aux32, values); + sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi); + sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi); } *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 0f00058f..b877545c 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2509,9 +2509,6 @@ template static __device__ __forceinlin const int kbx = 0; // threadIdx.x / QI4_XS const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS - uint32_t aux32[2]; - auto a8 = (const uint8_t *)aux32; - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -2523,15 +2520,14 @@ template static __device__ __forceinlin const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx; const int q4 = get_int_b4(bxi->qs, kqsx); - aux32[0] = (q4 >> 0) & 0x0f0f0f0f; - aux32[1] = (q4 >> 4) & 0x0f0f0f0f; + const int2 v = get_int_from_table_16(q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table); - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table); + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; #endif // INT8_MMA_AVAILABLE } @@ -2842,9 +2838,6 @@ template static __device__ __forceinlin const int kqsx = threadIdx.x / 4; - //uint32_t aux32[2]; - //auto a8 = (const uint8_t *)aux32; - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { int i = i0 + 4*threadIdx.y + threadIdx.x%4; @@ -2858,20 +2851,17 @@ template static __device__ __forceinlin const int ls = (bxi->scales[kqsx] & 254) - 127; auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4); - //auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8); #pragma unroll for (int j = 0; j < 4; ++j) { const int q4 = get_int_b4(bxi->qs, 4*kqsx+j); const int2 v = get_int_from_table_16(q4, values); - //aux32[0] = (q4 >> 0) & 0x0f0f0f0f; - //aux32[1] = (q4 >> 4) & 0x0f0f0f0f; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; //int_from_table_x(a8+0, values); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; //int_from_table_x(a8+4, values); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; //int_from_table_x(a8+0, values); - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; //int_from_table_x(a8+4, values); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; #endif // INT8_MMA_AVAILABLE } #ifdef INT8_MMA_AVAILABLE @@ -2898,9 +2888,6 @@ template static __device__ __forceinlin const int kqsx = threadIdx.x/4; - uint32_t aux32[2]; - const uint8_t * a8 = (const uint8_t *)aux32; - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { int i = i0 + 4*threadIdx.y + threadIdx.x%4; @@ -2915,19 +2902,19 @@ template static __device__ __forceinlin const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0; const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; - auto values = iq4k_table + ((bxi->scales[4*kqsx+ir] & 1) << 8); + auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4); + #pragma unroll for (int j = 0; j < 4; ++j) { const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir); - aux32[0] = (q4 >> 0) & 0x0f0f0f0f; - aux32[1] = (q4 >> 4) & 0x0f0f0f0f; + const int2 v = get_int_from_table_16(q4, values); const int k0 = 8*kqsx + 4*(j%2) + j/2; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, values); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = int_from_table_x(a8+4, values); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, values); - x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = int_from_table_x(a8+4, values); + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y; #endif // INT8_MMA_AVAILABLE } #ifdef INT8_MMA_AVAILABLE diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kss.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kss.cu index 3f107588..ce428411 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kss.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kss.cu @@ -14,9 +14,6 @@ template static __device__ __forceinlin const int kqsx = threadIdx.x / 4; - uint32_t aux32[2]; - auto a8 = (const uint8_t *)aux32; - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { int i = i0 + 4*threadIdx.y + threadIdx.x%4; @@ -31,20 +28,19 @@ template static __device__ __forceinlin uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6); uint8_t ls = (s32 | (s32 >> 15)) & 0xff; - auto values = iq4k_table + ((ls & 1) << 8); + auto values = iq4k_values + ((ls & 1) << 4); #pragma unroll for (int j = 0; j < 4; ++j) { uint32_t val = q4[j] & 0xfffefffe; val = val ^ (val >> 1); - aux32[0] = (val >> 0) & 0x0f0f0f0f; - aux32[1] = (val >> 4) & 0x0f0f0f0f; + auto v = get_int_from_table_16(val, values); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values); - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; #endif // INT8_MMA_AVAILABLE } #ifdef INT8_MMA_AVAILABLE