mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 01:50:01 +00:00
Use bperm trick for iq2_k_r4 gemm -> ~3% gain
This commit is contained in:
@@ -127,3 +127,12 @@ __device__ __forceinline__ int int_from_table_x(const uint8_t * a8, const uint16
|
||||
return values[a8[0] | (a8[1] << 4)] | (values[a8[2] | (a8[3] << 4)] << 16);
|
||||
}
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
// Looks up eight entries of a small byte table in one go using __byte_perm.
//
// q4     : packed selector word holding eight 4-bit selectors (one per nibble).
//          Per the CUDA __byte_perm documentation only the low 3 bits of each
//          nibble pick a byte, so each selector addresses one of 8 table bytes.
//          NOTE(review): callers presumably keep every nibble in 0..7 - confirm.
// values : table, read here as two consecutive 32-bit words (8 bytes total).
//          NOTE(review): the cast to uint32_t* assumes `values` is 4-byte
//          aligned - confirm at the call sites.
// returns: int2 where .x packs the entries selected by nibbles 0,2,4,6 of q4
//          (the low nibble of each byte) and .y packs the entries selected by
//          nibbles 1,3,5,7 (the high nibble of each byte).
static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) {
|
||||
// View the table as two 32-bit words so __byte_perm can index its 8 bytes.
const uint32_t * values32 = (const uint32_t *)values;
|
||||
// v1: byte k holds the table entry chosen by nibble k of q4 (k = 0..3).
uint32_t v1 = __byte_perm(values32[0], values32[1], q4);
|
||||
// v2: same for nibbles 4..7. __byte_perm ignores the upper 16 selector bits,
// so whatever the signed right shift brings into those bits does not matter.
uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16);
|
||||
// 0x6420 gathers bytes 0,2 of v1 and bytes 0,2 of v2 (even nibbles of q4);
// 0x7531 gathers bytes 1,3 of v1 and bytes 1,3 of v2 (odd nibbles of q4).
return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
mmq_supported = ne11 < 2048;
|
||||
mmq_supported = ne11 <= 3072;
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
|
||||
@@ -2554,15 +2554,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
// Looks up eight entries of a small byte table in one go using __byte_perm.
//
// q4     : packed selector word holding eight 4-bit selectors (one per nibble).
//          Per the CUDA __byte_perm documentation only the low 3 bits of each
//          nibble pick a byte, so each selector addresses one of 8 table bytes.
//          NOTE(review): callers presumably keep every nibble in 0..7 - confirm.
// values : table, read here as two consecutive 32-bit words (8 bytes total).
//          NOTE(review): the cast to uint32_t* assumes `values` is 4-byte
//          aligned - confirm at the call sites.
// returns: int2 where .x packs the entries selected by nibbles 0,2,4,6 of q4
//          (the low nibble of each byte) and .y packs the entries selected by
//          nibbles 1,3,5,7 (the high nibble of each byte).
static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) {
|
||||
// View the table as two 32-bit words so __byte_perm can index its 8 bytes.
const uint32_t * values32 = (const uint32_t *)values;
|
||||
// v1: byte k holds the table entry chosen by nibble k of q4 (k = 0..3).
uint32_t v1 = __byte_perm(values32[0], values32[1], q4);
|
||||
// v2: same for nibbles 4..7. __byte_perm ignores the upper 16 selector bits,
// so whatever the signed right shift brings into those bits does not matter.
uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16);
|
||||
// 0x6420 gathers bytes 0,2 of v1 and bytes 0,2 of v2 (even nibbles of q4);
// 0x7531 gathers bytes 1,3 of v1 and bytes 1,3 of v2 (odd nibbles of q4).
return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_ks(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int * all_values = (const int *)iq2k_table;
|
||||
|
||||
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
|
||||
|
||||
#pragma unroll
|
||||
@@ -32,10 +30,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
|
||||
const float d = __half2float(bxi->d[ir]);
|
||||
|
||||
#pragma unroll
|
||||
#ifdef __CUDA_ARCH__
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
|
||||
auto values_l = all_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
|
||||
uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404;
|
||||
extra = extra | (extra << 4);
|
||||
|
||||
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
|
||||
uint32_t val1 = ((ql >> 0) & 0x33333333) | extra;
|
||||
uint32_t val2 = ((ql >> 2) & 0x33333333) | extra;
|
||||
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
|
||||
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
|
||||
auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
|
||||
|
||||
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
|
||||
|
||||
@@ -51,6 +76,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
#endif // __CUDA_ARCH__
|
||||
|
||||
int is = 8*kqsx + ir;
|
||||
float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8);
|
||||
|
||||
Reference in New Issue
Block a user