mmq_id: adding iq6_k

This commit is contained in:
Iwan Kawrakow
2025-08-25 17:03:16 +03:00
parent 601c6006d9
commit a6fe757cd8
3 changed files with 88 additions and 0 deletions

View File

@@ -249,6 +249,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
case GGML_TYPE_IQ5_K_R4:
mul_mat_q_case_id<GGML_TYPE_IQ5_K_R4>(ctx, args, stream);
break;
case GGML_TYPE_IQ6_K:
mul_mat_q_case_id<GGML_TYPE_IQ6_K>(ctx, args, stream);
break;
default:
GGML_ABORT("fatal error");
break;
@@ -491,6 +494,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ6_K:
mmq_supported = true;
break;
default:

View File

@@ -98,6 +98,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ6_K:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
GGML_ABORT("fatal error");
@@ -405,6 +406,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16;
default: return tile_x_sizes{0, 0, 0};
}
}
@@ -460,6 +462,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K;
default: return 0;
}
}
@@ -4082,5 +4085,6 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
// -------------------------------------------------------------------------------------------------------------------------

View File

@@ -0,0 +1,80 @@
#include "../mmq_id_common.cuh"
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq6_k(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
constexpr int qstep = 8;
const int kqsx = threadIdx.x % qstep;
auto values = iq6nl_values;
int qh[2];
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
if (need_check) {
i = min(i, i_max);
}
const block_iq6_k * bxi = (const block_iq6_k *)(x + i*stride) + kbx0;
const float d = bxi->d;
uint16_t extra = bxi->extra >> (kqsx/4);
qh[0] = get_int_b4(bxi->qh, kqsx+0);
qh[1] = get_int_b4(bxi->qh, kqsx+8);
#pragma unroll
for (int l = 0; l < qstep/2; ++l) {
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh[l/2] & 0x03030303) << 4) | ((extra & 1) * 0x40404040);
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh[l/2] & 0x0c0c0c0c) << 2) | ((extra & 4) * 0x10101010);
qh[l/2] >>= 4;
extra >>= 4;
const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1;
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1;
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * bxi->scales[2*kqsx+0];
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * bxi->scales[2*kqsx+1];
#else
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * bxi->scales[2*kqsx+0];
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * bxi->scales[2*kqsx+1];
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ6_K> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq6_k<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
DECL_MMQ_CASE(GGML_TYPE_IQ6_K);