mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
cuda: slightly faster MMQ for iq4_xs
This commit is contained in:
@@ -2436,6 +2436,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const int kbx = 0; // threadIdx.x / QI4_XS
|
||||
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
|
||||
|
||||
uint32_t aux32[2];
|
||||
auto a8 = (const uint8_t *)aux32;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
||||
int i = i0 + threadIdx.y;
|
||||
@@ -2446,15 +2449,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
|
||||
const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;
|
||||
|
||||
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
|
||||
const int2 v = get_int_from_table_16(aux_q4);
|
||||
const int q4 = get_int_b4(bxi->qs, kqsx);
|
||||
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
|
||||
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
|
||||
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user