cuda: slightly faster MMQ for iq4_xs

This commit is contained in:
Iwan Kawrakow
2025-07-09 11:09:14 +03:00
parent 6836839907
commit 62f5ab2cf2

View File

@@ -2436,6 +2436,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = 0; // threadIdx.x / QI4_XS
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
uint32_t aux32[2];
auto a8 = (const uint8_t *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
@@ -2446,15 +2449,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
const int2 v = get_int_from_table_16(aux_q4);
const int q4 = get_int_b4(bxi->qs, kqsx);
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
#else
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
#endif // INT8_MMA_AVAILABLE
}