mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
mmq_id: add iq2_kl
This commit is contained in:
@@ -201,6 +201,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ2_KS>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_KL:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ2_KL>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ2_K>(ctx, args, stream);
|
||||
break;
|
||||
@@ -433,6 +436,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ2_KL:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
mmq_supported = true;
|
||||
|
||||
@@ -81,6 +81,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ2_KL:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
@@ -374,6 +375,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
|
||||
// ================= ik_llama.cpp quants
|
||||
case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
|
||||
default: return tile_x_sizes{0, 0, 0};
|
||||
@@ -415,6 +417,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
// ================= ik_llama.cpp quants
|
||||
case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
default: return 0;
|
||||
@@ -3951,6 +3954,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
||||
// =================== ik_llama.cpp quants
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4);
|
||||
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
#include "../mmq_id_common.cuh"
|
||||
|
||||
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_kl(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
||||
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int kqsx = threadIdx.x/4;
|
||||
|
||||
uint32_t aux32[2];
|
||||
const uint8_t * a8 = (const uint8_t *)aux32;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
|
||||
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const half * dptr = (const half *)(x + i*stride);
|
||||
const float d = *dptr;
|
||||
const block_iq2_kl * bxi = (const block_iq2_kl *)(dptr + 1) + kbx0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
auto ql = get_int_b2(bxi->qs, 4*(kqsx/2) + 2*(kqsx%2) + j);
|
||||
auto qh = get_int_b2(bxi->qh, 2*(kqsx%2) + j) >> 2*(kqsx/2);
|
||||
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010);
|
||||
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
int val1 = iq2kl_values[a8[2*l+0]] | (iq2kl_values[a8[2*l+1]] << 16);
|
||||
int val2 = iq2kl_values[a8[2*l+4]] | (iq2kl_values[a8[2*l+5]] << 16);
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
int ls = int(((bxi->scales_l[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->scales_h >> 2*kqsx) & 3) << 4)) - 32;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls;
|
||||
#else
|
||||
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = d * ls;
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_KL> {
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kl<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);
|
||||
Reference in New Issue
Block a user