mmq_id: add iq2_k, iq2_k_r4

This commit is contained in:
Iwan Kawrakow
2025-08-25 08:52:32 +03:00
parent e9afb0b8fc
commit 16e477a945
3 changed files with 248 additions and 19 deletions

View File

@@ -198,6 +198,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
case GGML_TYPE_IQ4_NL:
mul_mat_q_case_id<GGML_TYPE_IQ4_NL>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_K:
mul_mat_q_case_id<GGML_TYPE_IQ2_K>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_K_R4:
mul_mat_q_case_id<GGML_TYPE_IQ2_K_R4>(ctx, args, stream);
break;
default:
GGML_ABORT("fatal error");
break;
@@ -423,6 +429,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
mmq_supported = true;
break;
default:

View File

@@ -3,6 +3,7 @@
#include "common.cuh"
#include "mma_new.cuh"
#include "vecdotq.cuh"
#include "iqk_cuda_common.h"
#include <vector>
#include <climits>
@@ -79,6 +80,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
GGML_ABORT("fatal error");
@@ -368,6 +371,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
// ================= ik_llama.cpp quants
case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
default: return tile_x_sizes{0, 0, 0};
}
}
@@ -405,6 +411,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
// ================= ik_llama.cpp quants
case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
default: return 0;
}
}
@@ -3071,6 +3080,15 @@ static __device__ __forceinline__ void mmq_write_back_mma_id(
}
}
// ===================================== ik_llama.cpp quants =============================================================
//
// Strictly speaking, we should bite the bullet and change WARP_SIZE to warp_size or MMQ_TILE_NE_K.
// But as we basically don't support anything but Nvidia in the CUDA backend, we alays have
// WARP_SIZE = MMQ_TILE_NE_K = 32
//
// -------------------------------------------------------------------------------------------------------------------------------------
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
@@ -3078,7 +3096,7 @@ struct mmq_type_traits_id;
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3086,7 +3104,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3094,7 +3112,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3102,7 +3120,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3110,7 +3128,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3118,7 +3136,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
//static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3126,7 +3144,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3134,7 +3152,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3142,7 +3160,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3150,7 +3168,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3158,7 +3176,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
//static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3166,7 +3184,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3174,7 +3192,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3182,7 +3200,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3190,7 +3208,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3198,7 +3216,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3206,7 +3224,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3214,7 +3232,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3222,7 +3240,7 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
//static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
@@ -3925,5 +3943,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
// =================== ik_llama.cpp quants
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4);
// -------------------------------------------------------------------------------------------------------------------------

View File

@@ -0,0 +1,200 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq_id_common.cuh"
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_k(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
//constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
constexpr int qstep = 8;
const int kqsx = threadIdx.x % qstep;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
if (need_check) {
i = min(i, i_max);
}
const block_iq2_k * bxi = (const block_iq2_k *)x + i*stride + kbx0;
const float d = bxi->d;
uint16_t extra = bxi->extra >> (kqsx/4);
#ifdef __CUDA_ARCH__
uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 };
#pragma unroll
for (int l = 0; l < qstep/4; ++l) {
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 2) & 0x44444444);
uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 0) & 0x44444444);
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y;
#endif // INT8_MMA_AVAILABLE
}
#else
auto all_values = (const int *)iq2k_table;
#pragma unroll
for (int l = 0; l < qstep/4; ++l) {
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8));
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6));
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4));
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2));
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8));
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6));
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4));
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2));
#endif // INT8_MMA_AVAILABLE
extra >>= 8;
}
#endif // __CUDA_ARCH__
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8);
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * (((bxi->scales[kqsx] >> 4) & 0xf) - 8);
#else
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8);
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * (((bxi->scales[kqsx] >> 4) & 0xf) - 8);
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_k_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
if (need_check) {
i = min(i, i_max);
}
int i4 = i/4;
int ir = i%4;
const block_iq2_k_r4 * bxi = (const block_iq2_k_r4 *)x + 4*i4*stride + kbx0;
const float d = __half2float(bxi->d[ir]);
#ifdef __CUDA_ARCH__
#pragma unroll
for (int l = 0; l < 2; ++l) {
uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404;
extra = extra | (extra << 4);
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
uint32_t val1 = ((ql >> 0) & 0x33333333) | extra;
uint32_t val2 = ((ql >> 2) & 0x33333333) | extra;
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y;
#endif // INT8_MMA_AVAILABLE
}
#else
#pragma unroll
for (int l = 0; l < 2; ++l) {
auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
#endif // INT8_MMA_AVAILABLE
}
#endif // __CUDA_ARCH__
int is = 8*kqsx + ir;
float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8);
is += 4;
float dl2 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8);
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2;
#else
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1;
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2;
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_K> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_K_R4> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k_r4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
DECL_MMQ_CASE(GGML_TYPE_IQ2_K);
DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4);