Use bperm trick for iq3_k -> 8% PP performance gain

This commit is contained in:
Iwan Kawrakow
2025-08-21 14:04:18 +03:00
parent 3eaacf235e
commit fa9e69fdd6

View File

@@ -16,8 +16,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -37,29 +35,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
#pragma unroll
for (int l = 0; l < 2; ++l) {
auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6);
//auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6);
uint32_t extra32 = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x88888888;
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404);
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
int val0 = int_from_table_2(aux8+ 0, values_l);
int val1 = int_from_table_2(aux8+ 4, values_l);
int val2 = int_from_table_2(aux8+ 8, values_l);
int val3 = int_from_table_2(aux8+12, values_l);
uint32_t val1 = ((ql >> 0) & 0x33333333) | extra32 | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040);
uint32_t val2 = ((ql >> 2) & 0x33333333) | extra32 | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040);
int2 v1 = get_int_from_table_16(val1, iq3nl_values);
int2 v2 = get_int_from_table_16(val2, iq3nl_values);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = val0;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = val1;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = val2;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = val3;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = val0;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = val1;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = val2;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = val3;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y;
#endif // INT8_MMA_AVAILABLE
qh >>= 4;