iq3_kt: MMQ

This commit is contained in:
Iwan Kawrakow
2025-06-16 17:21:17 +03:00
parent 32ff1f956f
commit f6fa5652a3
3 changed files with 98 additions and 1 deletions

View File

@@ -106,6 +106,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ2_KT:
mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream);
break;
case GGML_TYPE_IQ3_KT:
mul_mat_q_case<GGML_TYPE_IQ3_KT>(ctx, args, stream);
break;
case GGML_TYPE_IQ5_KS:
mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream);
break;
@@ -178,8 +181,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ2_KT:
case GGML_TYPE_IQ3_KT:
case GGML_TYPE_IQ4_KT:
mmq_supported = true;
break;
default:

View File

@@ -94,6 +94,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ2_KT:
case GGML_TYPE_IQ3_KT:
case GGML_TYPE_IQ4_KT:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
@@ -205,6 +206,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0;
default : return tile_x_sizes{0, 0, 0};
}
@@ -255,6 +257,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0;
default : return 0;
}
@@ -2939,6 +2942,83 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_kt(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
constexpr uint32_t ka = 0xCBAC1FED;
constexpr uint32_t km = 0x3f3f3f3f;
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const int kqsx = threadIdx.x;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
if (need_check) {
i = min(i, i_max);
}
const block_iq3_kt * bxi = (const block_iq3_kt *)(x + i*stride + sizeof(float)) + kbx0;
int ib32 = kqsx/4;
int j = kqsx%4;
const auto ql = (const uint16_t *)bxi->ql;
const auto qh = (const uint32_t *)bxi->qh;
uint32_t mask = 0x01010101 << ib32;
uint32_t val = ql[4*ib32+j] + 4096;
int2 v = {0, 0};
for (int k = 0; k < 4; ++k) {
val *= ka;
v.x |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k;
}
auto signs = __vcmpne4(qh[2*j+0] & mask, 0);
v.x = __vsub4(v.x ^ signs, signs);
for (int k = 0; k < 4; ++k) {
val *= ka;
v.y |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k;
}
signs = __vcmpne4(qh[2*j+1] & mask, 0);
v.y = __vsub4(v.y ^ signs, signs);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
#endif // INT8_MMA_AVAILABLE
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
if (need_check) {
i = min(i, i_max);
}
const float * dptr = (const float *)(x + i*stride);
const float d = dptr[0] * 1.01f;
const block_iq3_kt * bxi = (const block_iq3_kt *)(dptr + 1) + kbx0;
int ib32 = threadIdx.x % 8;
const int ls = (bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
#else
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -3545,6 +3625,13 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KT> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_KT> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_kt<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>;
@@ -4008,6 +4095,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);
// -------------------------------------------------------------------------------------------------------------------------

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);