mmq_id: adding iq5_k, iq5_k_r4, q6_0

This commit is contained in:
Iwan Kawrakow
2025-08-25 16:52:47 +03:00
parent 8f3c813ab0
commit 601c6006d9
4 changed files with 278 additions and 1 deletions

View File

@@ -153,6 +153,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
case GGML_TYPE_Q5_1:
mul_mat_q_case_id<GGML_TYPE_Q5_1>(ctx, args, stream);
break;
case GGML_TYPE_Q6_0:
mul_mat_q_case_id<GGML_TYPE_Q6_0>(ctx, args, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_q_case_id<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
@@ -240,6 +243,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
case GGML_TYPE_IQ5_KS_R4:
mul_mat_q_case_id<GGML_TYPE_IQ5_KS_R4>(ctx, args, stream);
break;
case GGML_TYPE_IQ5_K:
mul_mat_q_case_id<GGML_TYPE_IQ5_K>(ctx, args, stream);
break;
case GGML_TYPE_IQ5_K_R4:
mul_mat_q_case_id<GGML_TYPE_IQ5_K_R4>(ctx, args, stream);
break;
default:
GGML_ABORT("fatal error");
break;
@@ -450,6 +459,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
@@ -479,6 +489,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ5_K_R4:
mmq_supported = true;
break;
default:
@@ -513,7 +525,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
return true;
}
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0
|| type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q6_0) {
return true;
}
if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {

View File

@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q5_1:
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_Q6_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q8_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_MXFP4:
@@ -94,6 +96,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ5_K_R4:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
GGML_ABORT("fatal error");
@@ -368,6 +372,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q6_0 : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
@@ -398,6 +403,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ4_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
default: return tile_x_sizes{0, 0, 0};
}
}
@@ -420,6 +427,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q6_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
@@ -450,6 +458,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ4_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
default: return 0;
}
}
@@ -861,6 +871,71 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI6_0;
const int kqsx = threadIdx.x % QI6_0;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
if (need_check) {
i = min(i, i_max);
}
const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbx;
const int ql = get_int_b2(bxi->qs, kqsx);
const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2);
int qs0 = ((ql >> 0) & 0x0F0F0F0F) | ((qh << 4) & 0x30303030);
int qs1 = ((ql >> 4) & 0x0F0F0F0F) | ((qh << 2) & 0x30303030);
qs0 = __vsubss4(qs0, 0x20202020); // subtract 32
qs1 = __vsubss4(qs1, 0x20202020); // subtract 32
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + 0] = qs0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
#else
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + 0] = qs0;
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI6_0;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_0) {
int i = i0 + threadIdx.y * QI6_0 + threadIdx.x / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
}
const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(WARP_SIZE/QI6_0) + i/QI6_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
@@ -3162,6 +3237,13 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_0> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
//static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
@@ -3967,6 +4049,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
@@ -3997,5 +4080,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
// -------------------------------------------------------------------------------------------------------------------------

View File

@@ -0,0 +1,174 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq_id_common.cuh"
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_k(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
constexpr int qstep = 8;
const int kqsx = threadIdx.x % qstep;
auto values = iq5nl_values;
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
if (need_check) {
i = min(i, i_max);
}
const block_iq5_k * bxi = (const block_iq5_k *)(x + i*stride) + kbx0;
int qh = get_int_b4(bxi->qh, kqsx);
uint16_t extra = bxi->extra >> (kqsx/4);
#pragma unroll
for (int l = 0; l < qstep/2; ++l) {
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) * 0x20202020); // this is very slightly faster
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) * 0x08080808); // then the version below
//aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) ? 0x20202020 : 0);
//aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) ? 0x20202020 : 0);
qh >>= 2;
extra >>= 4;
const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1;
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1;
#endif // INT8_MMA_AVAILABLE
}
const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2);
const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32;
const int ls2 = ((bxi->scales_l[kqsx] >> 4) | ((sh << 2) & 0x30)) - 32;
const float d = bxi->d;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2;
#else
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1;
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2;
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_k_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
if (need_check) {
i = min(i, i_max);
}
int i4 = i/4;
int ir = i%4;
const block_iq5_k_r4 * bxi = (const block_iq5_k_r4 *)(x + 4*i4*stride) + kbx0;
const float d = __half2float(bxi->d[ir]);
int qh = get_int_b4(bxi->qh, 4*kqsx + ir);
#pragma unroll
for (int l = 0; l < 2; ++l) {
auto values_l = iq5nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 5);
const int ql1 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 0);
const int ql2 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 8);
aux32[0] = ((ql1 >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010);
aux32[1] = ((ql1 >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010);
aux32[2] = ((ql2 >> 0) & 0x0f0f0f0f) | ((qh >> 0) & 0x10101010);
aux32[3] = ((ql2 >> 4) & 0x0f0f0f0f) | ((qh >> 1) & 0x10101010);
const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]);
const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]);
const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]);
const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val1;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val2;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val1;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val2;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3;
#endif // INT8_MMA_AVAILABLE
qh >>= 2;
}
int is = 8*kqsx + ir;
float dl1 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
is += 4;
float dl2 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2;
#else
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1;
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2;
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ5_K> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ5_K_R4> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k_r4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq_id_common.cuh"
DECL_MMQ_CASE(GGML_TYPE_Q6_0);