diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 685d9678..ee34452a 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2436,6 +2436,9 @@ template static __device__ __forceinlin const int kbx = 0; // threadIdx.x / QI4_XS const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS + uint32_t aux32[2]; + auto a8 = (const uint8_t *)aux32; + #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -2446,15 +2449,16 @@ template static __device__ __forceinlin const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx; - const int aux_q4 = get_int_b4(bxi->qs, kqsx); - const int2 v = get_int_from_table_16(aux_q4); + const int q4 = get_int_b4(bxi->qs, kqsx); + aux32[0] = (q4 >> 0) & 0x0f0f0f0f; + aux32[1] = (q4 >> 4) & 0x0f0f0f0f; const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table); #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table); + x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table); #endif // INT8_MMA_AVAILABLE }