Use bperm trick for iq2_k_r4 gemm -> ~3% gain

This commit is contained in:
Iwan Kawrakow
2025-08-21 18:00:17 +03:00
parent 1d91c16869
commit 693e9d1a16
4 changed files with 40 additions and 14 deletions

View File

@@ -127,3 +127,12 @@ __device__ __forceinline__ int int_from_table_x(const uint8_t * a8, const uint16
return values[a8[0] | (a8[1] << 4)] | (values[a8[2] | (a8[3] << 4)] << 16);
}
#ifdef __CUDA_ARCH__
// Looks up 8 table entries at once using the __byte_perm ("bperm") byte-permute
// intrinsic. `values` points to an 8-entry int8 lookup table, reinterpreted as
// two 32-bit words; `q4` packs eight byte-selector nibbles (low 16 bits select
// the first four bytes, high 16 bits the next four — __byte_perm only consumes
// the low 16 selector bits, hence the `q4 >> 16` for the second half).
static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) {
const uint32_t * values32 = (const uint32_t *)values;
// Each nibble of the selector picks one byte out of the 8-byte pair
// {values32[0], values32[1]} (selector byte indices 0..7).
uint32_t v1 = __byte_perm(values32[0], values32[1], q4);
uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16);
// De-interleave: 0x6420 gathers the even-position bytes of (v1, v2) into .x,
// 0x7531 gathers the odd-position bytes into .y.
return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531));
}
#endif

View File

@@ -187,7 +187,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
break;
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
mmq_supported = ne11 < 2048;
mmq_supported = ne11 <= 3072;
break;
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:

View File

@@ -2554,15 +2554,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
#ifdef __CUDA_ARCH__
// 8-way table lookup via the __byte_perm byte-permute intrinsic: `values` is an
// 8-entry int8 table viewed as two 32-bit words, and `q4` carries eight 4-bit
// byte selectors (low 16 bits for the first four lookups, high 16 bits for the
// rest, since __byte_perm only reads the low 16 selector bits).
static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) {
const uint32_t * values32 = (const uint32_t *)values;
// Selector nibbles index bytes 0..7 of the pair {values32[0], values32[1]}.
uint32_t v1 = __byte_perm(values32[0], values32[1], q4);
uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16);
// Split into even bytes (selector 0x6420) and odd bytes (selector 0x7531).
return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531));
}
#endif
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_ks(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

View File

@@ -14,8 +14,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const int * all_values = (const int *)iq2k_table;
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
#pragma unroll
@@ -32,10 +30,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const float d = __half2float(bxi->d[ir]);
#pragma unroll
#ifdef __CUDA_ARCH__
#pragma unroll
for (int l = 0; l < 2; ++l) {
auto values_l = all_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404;
extra = extra | (extra << 4);
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
uint32_t val1 = ((ql >> 0) & 0x33333333) | extra;
uint32_t val2 = ((ql >> 2) & 0x33333333) | extra;
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y;
#endif // INT8_MMA_AVAILABLE
}
#else
#pragma unroll
for (int l = 0; l < 2; ++l) {
auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
@@ -51,6 +76,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
#endif // INT8_MMA_AVAILABLE
}
#endif // __CUDA_ARCH__
int is = 8*kqsx + ir;
float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8);