mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Use bperm trick for iq2_ks gemm -> 7% gain
This commit is contained in:
@@ -2554,6 +2554,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) {
|
||||
const uint32_t * values32 = (const uint32_t *)values;
|
||||
uint32_t v1 = __byte_perm(values32[0], values32[1], q4);
|
||||
uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16);
|
||||
return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_ks(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
@@ -2566,11 +2575,45 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int * all_values = (const int *)iq2k_table;
|
||||
|
||||
const int kqsx = threadIdx.x%16;
|
||||
|
||||
#pragma unroll
|
||||
#ifdef __CUDA_ARCH__
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
|
||||
int i = i0 + 2*threadIdx.y + threadIdx.x/16;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_iq2_ks * bxi = (const block_iq2_ks *)(x + i*stride + sizeof(half)) + kbx0;
|
||||
|
||||
uint16_t extra = bxi->extra >> 4*(kqsx/8);
|
||||
int q2 = get_int_b2(bxi->qs, kqsx);
|
||||
|
||||
uint32_t extra32 = uint32_t(extra & 0xf) * 0x01010101;
|
||||
uint32_t val1 = ((q2 >> 0) & 0x33333333) | ((extra32 << 2) & 0x04040404) | ((extra32 << 4) & 0x40404040);
|
||||
uint32_t val2 = ((q2 >> 2) & 0x33333333) | ((extra32 << 1) & 0x04040404) | ((extra32 << 3) & 0x40404040);
|
||||
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
|
||||
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = v1.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = v2.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = v1.y;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = v2.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = v1.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = v2.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = v1.y;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = v2.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
#else // __CUDA_ARCH__
|
||||
|
||||
const int * all_values = (const int *)iq2k_table;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
|
||||
int i = i0 + 2*threadIdx.y + threadIdx.x/16;
|
||||
|
||||
@@ -2595,6 +2638,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5));
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
#endif // __CUDA_ARCH__
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
|
||||
|
||||
Reference in New Issue
Block a user