mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 01:50:01 +00:00
MMQ for iq5_k
This commit is contained in:
@@ -97,6 +97,9 @@ void ggml_cuda_op_mul_mat_q(
|
||||
case GGML_TYPE_IQ4_K:
|
||||
mul_mat_q_case<GGML_TYPE_IQ4_K>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_K:
|
||||
mul_mat_q_case<GGML_TYPE_IQ5_K>(ctx, args, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
@@ -136,6 +139,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ5_K:
|
||||
mmq_supported = true;
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -84,6 +84,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ5_K:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -183,6 +184,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
default : return tile_x_sizes{0, 0, 0};
|
||||
}
|
||||
}
|
||||
@@ -222,6 +224,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
default : return 0;
|
||||
}
|
||||
}
|
||||
@@ -2485,75 +2488,89 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
}
|
||||
|
||||
//template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_k(
|
||||
// const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
//
|
||||
//#ifdef INT8_MMA_AVAILABLE
|
||||
// int * x_qs = (int *) x_tile;
|
||||
// float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
//#else
|
||||
// constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
|
||||
// int * x_qs = (int *) x_tile;
|
||||
// float * x_df = (float *) (x_qs + txs.qs);
|
||||
//#endif // INT8_MMA_AVAILABLE
|
||||
//
|
||||
//#pragma unroll
|
||||
// for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
||||
// int i = i0 + threadIdx.y;
|
||||
//
|
||||
// if (need_check) {
|
||||
// i = min(i, i_max);
|
||||
// }
|
||||
//
|
||||
// const block_iq4_k * bxi = (const block_iq4_k *)(x + i*stride) + kbx0;
|
||||
//
|
||||
// const uint16_t extra = bxi->extra >> 2*(threadIdx.x/4);
|
||||
// auto values0 = iq4k_values + ((extra & 1) << 4);
|
||||
// auto values1 = iq4k_values + ((extra & 2) << 3);
|
||||
// const int q4 = get_int_b4(bxi->qs, threadIdx.x);
|
||||
// const int q40 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
// const int q41 = (q4 >> 4) & 0x0F0F0F0F;
|
||||
//
|
||||
// const int8_t * aux80 = (const int8_t *)&q40;
|
||||
// const char4 val0 = make_char4(values0[aux80[0]], values0[aux80[1]], values0[aux80[2]], values0[aux80[3]]);
|
||||
// const int8_t * aux81 = (const int8_t *)&q41;
|
||||
// const char4 val1 = make_char4(values1[aux80[1]], values1[aux81[1]], values1[aux81[2]], values1[aux81[3]]);
|
||||
//
|
||||
// const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
|
||||
//#ifdef INT8_MMA_AVAILABLE
|
||||
// x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k0 + 0] = *(const int *)&val0;
|
||||
// x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k0 + 4] = *(const int *)&val1;
|
||||
//#else
|
||||
// x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = *(const int *)&val0;
|
||||
// x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = *(const int *)&val1;
|
||||
//#endif // INT8_MMA_AVAILABLE
|
||||
// }
|
||||
//
|
||||
// const int ib32 = threadIdx.x % 8;
|
||||
//#pragma unroll
|
||||
// for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
|
||||
// int i = i0 + threadIdx.y * 4 + threadIdx.x / 8;
|
||||
//
|
||||
// if (need_check) {
|
||||
// i = min(i, i_max);
|
||||
// }
|
||||
//
|
||||
// const block_iq4_k * bxi = (const block_iq4_k *)(x + i*stride) + kbx0;
|
||||
// const uint8_t sh = bxi->scales_h[ib32/2] >> 4*(ib32%2);
|
||||
// const int ls1 = ((bxi->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
// const int ls2 = ((bxi->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
|
||||
//
|
||||
// const float d = bxi->d;
|
||||
//#ifdef INT8_MMA_AVAILABLE
|
||||
// x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*ib32 + 0] = d * ls1;
|
||||
// x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*ib32 + 1] = d * ls2;
|
||||
//#else
|
||||
// // TODO
|
||||
// x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls1;
|
||||
//#endif // INT8_MMA_AVAILABLE
|
||||
// }
|
||||
//
|
||||
//}
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_k(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
constexpr int qstep = 8;
|
||||
const int kqsx = threadIdx.x % qstep;
|
||||
|
||||
auto values = iq5nl_values;
|
||||
|
||||
uint32_t aux32[2];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
|
||||
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_iq5_k * bxi = (const block_iq5_k *)(x + i*stride) + kbx0;
|
||||
|
||||
// kqsx = 0 -> 0,1,2,3 + 8,9,10,11
|
||||
// kqsx = 1 -> 4,5,6,7 + 12,13,14,15
|
||||
// kqsx = 2 -> 16,17,18,19 + 24,25,26,27
|
||||
// kqsx = 3 -> 20,21,22,23 + 28,29,30,31
|
||||
// or is it better
|
||||
// kqsx = 0 -> 0,1 + 8,9 + 16,17 + 24,25
|
||||
// kqsx = 1 -> 2,3 + 10,11 + 18,19 + 26,27, etc.
|
||||
// or perhaps even
|
||||
// kqsx = 0 -> 0, 8, 16, 24, 32, 40, 48, 56
|
||||
// kqsx = 1 -> 1, 9, 17, 25, 33, 41, 49, 57, etc.
|
||||
|
||||
int qh = get_int_b4(bxi->qh, kqsx);
|
||||
uint16_t extra = bxi->extra >> (kqsx/4);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < qstep/2; ++l) {
|
||||
|
||||
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
|
||||
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) << 5);
|
||||
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) << 3);
|
||||
qh >>= 2;
|
||||
extra >>= 4;
|
||||
|
||||
const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
|
||||
const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
// iq4_k: scales_h[ib/8] |= (l_h << 2*(ib%8)); ib: 0...15
|
||||
// iq5_k: scales_h[ib/4] |= (l_h << 2*(ib%4)); ib: 0...15
|
||||
|
||||
const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2);
|
||||
const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
const int ls2 = ((bxi->scales_l[kqsx] >> 4) | ((sh << 2) & 0x30)) - 32;
|
||||
|
||||
const float d = bxi->d;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1;
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2;
|
||||
#else
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1;
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
static __device__ __forceinline__ void mmq_write_back_dp4a(
|
||||
@@ -2779,12 +2796,19 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_K> {
|
||||
static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
|
||||
//static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_K> {
|
||||
static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS> {
|
||||
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
|
||||
@@ -3231,6 +3255,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq.cuh"
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
|
||||
Reference in New Issue
Block a user