cuda: slightly faster MMQ for iq4_ks

This commit is contained in:
Iwan Kawrakow
2025-07-09 10:46:21 +03:00
parent b5751170d3
commit 6836839907

View File

@@ -2765,6 +2765,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kqsx = threadIdx.x / 4;
uint32_t aux32[2];
auto a8 = (const uint8_t *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -2776,18 +2779,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const float * dptr = (const float *)(x + i*stride);
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
const int ls = (bxi->scales[kqsx] & 254) - 127;
auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8);
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
const int2 v = get_int_from_table_16(aux_q4, values);
const int q4 = get_int_b4(bxi->qs, 4*kqsx+j);
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE