CUDA: quantized GEMM for IQ4_K, IQ5_K, IQ6_K (#417)

* MMQ for iq4_k: WIP (not working)

* MMQ for iq4_k: working now

* MMQ for iq5_k

* Cleanup

* MMQ for iq5_k: slightly faster

* MMQ for iq6_k

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-05-14 14:04:11 +03:00
Committed by: GitHub
Parent: fba62d61c0
Commit: 51db1bf2d2

5 changed files with 268 additions and 0 deletions
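A note on what these kernels do: MMQ ("matrix multiplication, quantized") keeps the weight tensor in its quantized format, quantizes the activations to q8_1 on the fly, and runs the dot products on packed int8 data. The load_tiles_* functions added below unpack each quant type into a shared-memory tile of int8 values plus per-sub-block float scales; the multiply-accumulate itself uses either int8 MMA instructions or dp4a. As a minimal sketch (not part of this commit), the primitive the dp4a path bottoms out in:

    // Minimal sketch, not from this commit: __dp4a multiplies four packed
    // int8 pairs and accumulates into an int32 (compute capability >= 6.1).
    __device__ int dot_packed_int8(int a, int b, int acc) {
        return __dp4a(a, b, acc); // acc + sum_{k=0..3} a.byte[k] * b.byte[k]
    }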


@@ -94,6 +94,15 @@ void ggml_cuda_op_mul_mat_q(
        case GGML_TYPE_IQ4_KS:
            mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ4_K:
            mul_mat_q_case<GGML_TYPE_IQ4_K>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ5_K:
            mul_mat_q_case<GGML_TYPE_IQ5_K>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ6_K:
            mul_mat_q_case<GGML_TYPE_IQ6_K>(ctx, args, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
@@ -132,6 +141,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_KS:
        case GGML_TYPE_IQ4_K:
        case GGML_TYPE_IQ5_K:
        case GGML_TYPE_IQ6_K:
            mmq_supported = true;
            break;
        default:
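The two hunks above are the routing changes: the first sends the new types to mul_mat_q_case in the type switch of ggml_cuda_op_mul_mat_q, the second adds them to the whitelist behind ggml_cuda_should_use_mmq, whose signature appears in the hunk header. A hedged caller-side sketch (illustrative only; the real call sites also weigh compute capability cc and batch size ne11 before choosing MMQ over the dequantize+cuBLAS path):

    // Illustrative only: with this commit the whitelist check passes for
    // the new types, so the quantized GEMM path can be selected.
    if (ggml_cuda_should_use_mmq(GGML_TYPE_IQ5_K, cc, ne11)) {
        // ... dispatch to ggml_cuda_op_mul_mat_q, which now handles IQ5_K ...
    }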


@@ -83,6 +83,9 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_KS:
        case GGML_TYPE_IQ4_K:
        case GGML_TYPE_IQ5_K:
        case GGML_TYPE_IQ6_K:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        default:
            GGML_ABORT("fatal error");
@@ -181,6 +184,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
        case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0;
        case GGML_TYPE_IQ4_K  : return MMQ_DP4A_TXS_Q8_0_16;
        case GGML_TYPE_IQ5_K  : return MMQ_DP4A_TXS_Q8_0_16;
        case GGML_TYPE_IQ6_K  : return MMQ_DP4A_TXS_Q8_0_16;
        default               : return tile_x_sizes{0, 0, 0};
    }
}
@@ -219,6 +225,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
        case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_IQ4_K  : return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_IQ5_K  : return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_IQ6_K  : return MMQ_MMA_TILE_X_K_Q3_K;
        default               : return 0;
    }
}
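Why these two hunks pick the *_Q8_0_16 / Q3_K tile constants rather than the plain Q8_0 ones used by IQ4_KS: after unpacking, these types carry one float scale per 16 weights, twice as many as formats with one scale per 32 weights, so they need the tile layouts that reserve scale slots at that granularity. A back-of-the-envelope check (illustrative C++, not library code):

    // Each 256-weight superblock of IQ4_K/IQ5_K/IQ6_K decodes to 16 float
    // scales: 8 sub-blocks of 32 weights, each with a separate scale for
    // its low-nibble and high-nibble halves.
    constexpr int weights_per_superblock = 256;
    constexpr int weights_per_scale      = 16;
    constexpr int scales_per_superblock  = weights_per_superblock / weights_per_scale; // == 16
    static_assert(scales_per_superblock == 16, "twice the scale count of q8_0-style layouts");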
@@ -2416,6 +2425,211 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
    }
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_k(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    constexpr int qstep = 8;
    const int kqsx = threadIdx.x % qstep; // 32-weight sub-block index within the 256-weight superblock

    uint32_t aux32[2];
    const uint8_t * aux8 = (const uint8_t *)aux32;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
        int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_iq4_k * bxi = (const block_iq4_k *)(x + i*stride) + kbx0;
        const uint16_t extra = bxi->extra >> 2*kqsx;

        // the two 'extra' bits select which 16-entry half of the non-linear
        // LUT the low and high nibbles of this sub-block use
        auto values_l = iq4k_values + ((extra & 1) << 4);
        auto values_h = iq4k_values + ((extra & 2) << 3);

#pragma unroll
        for (int l = 0; l < qstep/2; ++l) {
            const int q4 = get_int_b4(bxi->qs, (qstep/2)*kqsx + l);
            aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
            aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
            const char4 val0 = make_char4(values_l[aux8[0]], values_l[aux8[1]], values_l[aux8[2]], values_l[aux8[3]]);
            const char4 val1 = make_char4(values_h[aux8[4]], values_h[aux8[5]], values_h[aux8[6]], values_h[aux8[7]]);
#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 0] = *(const int *)&val0;
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 4] = *(const int *)&val1;
#else
            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 0] = *(const int *)&val0;
            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 4] = *(const int *)&val1;
#endif // INT8_MMA_AVAILABLE
        }

        // decode the two 6-bit sub-block scales: low 4 bits from scales_l,
        // high 2 bits from scales_h, biased by -32
        const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2);
        const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32;
        const int ls2 = ((bxi->scales_l[kqsx] >>  4) | ((sh << 2) & 0x30)) - 32;
        const float d = bxi->d;
#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1;
        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2;
#else
        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1;
        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2;
#endif // INT8_MMA_AVAILABLE
    }
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_k(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    constexpr int qstep = 8;
    const int kqsx = threadIdx.x % qstep;

    auto values = iq5nl_values;

    uint32_t aux32[2];
    const uint8_t * aux8 = (const uint8_t *)aux32;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
        int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_iq5_k * bxi = (const block_iq5_k *)(x + i*stride) + kbx0;

        int qh = get_int_b4(bxi->qh, kqsx);
        uint16_t extra = bxi->extra >> (kqsx/4);

#pragma unroll
        for (int l = 0; l < qstep/2; ++l) {
            const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
            // (extra & 1) * 0x20202020 broadcasts the 'extra' flag into bit 5
            // of all four bytes; the multiply is very slightly faster than
            // the ternary version commented out below
            aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) * 0x20202020);
            aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) * 0x08080808);
            //aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) ? 0x20202020 : 0);
            //aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) ? 0x20202020 : 0);
            qh >>= 2;
            extra >>= 4;
            const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
            const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0;
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1;
#else
            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0;
            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1;
#endif // INT8_MMA_AVAILABLE
        }

        // same 6-bit scale decode as IQ4_K
        const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2);
        const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32;
        const int ls2 = ((bxi->scales_l[kqsx] >>  4) | ((sh << 2) & 0x30)) - 32;
        const float d = bxi->d;
#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1;
        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2;
#else
        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1;
        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2;
#endif // INT8_MMA_AVAILABLE
    }
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq6_k(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    constexpr int qstep = 8;
    const int kqsx = threadIdx.x % qstep;

    auto values = iq6nl_values;

    int qh[2];
    uint32_t aux32[2];
    const uint8_t * aux8 = (const uint8_t *)aux32;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
        int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_iq6_k * bxi = (const block_iq6_k *)(x + i*stride) + kbx0;
        const float d = bxi->d;

        uint16_t extra = bxi->extra >> (kqsx/4);
        qh[0] = get_int_b4(bxi->qh, kqsx+0);
        qh[1] = get_int_b4(bxi->qh, kqsx+8);

#pragma unroll
        for (int l = 0; l < qstep/2; ++l) {
            const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
            // 6-bit quants: 4 low bits from qs, 2 high bits from qh, plus the
            // broadcast 'extra' flag, then mapped through the non-linear LUT
            aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh[l/2] & 0x03030303) << 4) | ((extra & 1) * 0x40404040);
            aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh[l/2] & 0x0c0c0c0c) << 2) | ((extra & 4) * 0x10101010);
            qh[l/2] >>= 4;
            extra >>= 4;
            const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
            const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0;
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1;
#else
            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0;
            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1;
#endif // INT8_MMA_AVAILABLE
        }

        // IQ6_K stores plain signed 8-bit scales, one per 16 weights
#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * bxi->scales[2*kqsx+0];
        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * bxi->scales[2*kqsx+1];
#else
        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * bxi->scales[2*kqsx+0];
        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * bxi->scales[2*kqsx+1];
#endif // INT8_MMA_AVAILABLE
    }
}

template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
        const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
@@ -2637,6 +2851,30 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_K> {
    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_k<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_K> {
    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq5_k<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ6_K> {
    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq6_k<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS> {
    static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
@@ -3082,6 +3320,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
// -------------------------------------------------------------------------------------------------------------------------
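For reference, the 6-bit scale packing that load_tiles_iq4_k and load_tiles_iq5_k decode above can be exercised in isolation. A hedged host-side sketch (hypothetical helper, not part of the commit), mirroring the in-kernel decode:

    #include <cstdint>

    // Hypothetical stand-alone version of the scale decode: the low 4 bits
    // of each 6-bit scale live in scales_l (two scales per byte), the high
    // 2 bits in scales_h (four scales per byte); both are biased by -32.
    static void decode_scale_pair(const uint8_t * scales_l, const uint8_t * scales_h,
                                  int kqsx, int & ls1, int & ls2) {
        const uint8_t sh = scales_h[kqsx/2] >> 4*(kqsx%2);
        ls1 = ((scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32; // low-nibble half
        ls2 = ((scales_l[kqsx] >>  4) | ((sh << 2) & 0x30)) - 32; // high-nibble half
    }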


@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ4_K);


@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ5_K);


@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
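The three new template-instance files follow the existing generate_cu_files.py pattern: each quant type gets its own translation unit so the expensive instantiation of mul_mat_q_case compiles in parallel instead of in one monolithic file. In upstream llama.cpp the macro is an explicit template instantiation along these lines; assuming this fork keeps the same shape:

    // Assumed definition (based on upstream llama.cpp's mmq.cuh, shown for
    // illustration): DECL_MMQ_CASE explicitly instantiates the kernel
    // launcher for one quantization type.
    #define DECL_MMQ_CASE(type)                                             \
        template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, \
                                           const mmq_args & args, cudaStream_t stream)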