mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 02:41:47 +00:00
mmq_id: adding iq6_k
This commit is contained in:
@@ -249,6 +249,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
|
|||||||
case GGML_TYPE_IQ5_K_R4:
|
case GGML_TYPE_IQ5_K_R4:
|
||||||
mul_mat_q_case_id<GGML_TYPE_IQ5_K_R4>(ctx, args, stream);
|
mul_mat_q_case_id<GGML_TYPE_IQ5_K_R4>(ctx, args, stream);
|
||||||
break;
|
break;
|
||||||
|
case GGML_TYPE_IQ6_K:
|
||||||
|
mul_mat_q_case_id<GGML_TYPE_IQ6_K>(ctx, args, stream);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
break;
|
break;
|
||||||
@@ -491,6 +494,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
|
|||||||
case GGML_TYPE_IQ5_KS_R4:
|
case GGML_TYPE_IQ5_KS_R4:
|
||||||
case GGML_TYPE_IQ5_K:
|
case GGML_TYPE_IQ5_K:
|
||||||
case GGML_TYPE_IQ5_K_R4:
|
case GGML_TYPE_IQ5_K_R4:
|
||||||
|
case GGML_TYPE_IQ6_K:
|
||||||
mmq_supported = true;
|
mmq_supported = true;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|||||||
case GGML_TYPE_IQ5_KS_R4:
|
case GGML_TYPE_IQ5_KS_R4:
|
||||||
case GGML_TYPE_IQ5_K:
|
case GGML_TYPE_IQ5_K:
|
||||||
case GGML_TYPE_IQ5_K_R4:
|
case GGML_TYPE_IQ5_K_R4:
|
||||||
|
case GGML_TYPE_IQ6_K:
|
||||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
@@ -405,6 +406,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|||||||
case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0;
|
case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0;
|
||||||
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
|
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||||
case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
|
case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
|
||||||
|
case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||||
default: return tile_x_sizes{0, 0, 0};
|
default: return tile_x_sizes{0, 0, 0};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -460,6 +462,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|||||||
case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0;
|
case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||||
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||||
case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
|
case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||||
|
case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||||
default: return 0;
|
default: return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4082,5 +4085,6 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
|
|||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
|
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
|
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
|
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
|
||||||
|
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------------------
|
// -------------------------------------------------------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
#include "../mmq_id_common.cuh"
|
||||||
|
|
||||||
|
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq6_k(
|
||||||
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
||||||
|
|
||||||
|
constexpr int nwarps = mmq_get_nwarps_device();
|
||||||
|
|
||||||
|
#ifdef INT8_MMA_AVAILABLE
|
||||||
|
int * x_qs = (int *) x_tile;
|
||||||
|
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||||
|
#else
|
||||||
|
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
||||||
|
int * x_qs = (int *) x_tile;
|
||||||
|
float * x_df = (float *) (x_qs + txs.qs);
|
||||||
|
#endif // INT8_MMA_AVAILABLE
|
||||||
|
|
||||||
|
constexpr int qstep = 8;
|
||||||
|
const int kqsx = threadIdx.x % qstep;
|
||||||
|
|
||||||
|
auto values = iq6nl_values;
|
||||||
|
int qh[2];
|
||||||
|
|
||||||
|
uint32_t aux32[2];
|
||||||
|
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
|
||||||
|
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
|
||||||
|
|
||||||
|
if (need_check) {
|
||||||
|
i = min(i, i_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
const block_iq6_k * bxi = (const block_iq6_k *)(x + i*stride) + kbx0;
|
||||||
|
|
||||||
|
const float d = bxi->d;
|
||||||
|
uint16_t extra = bxi->extra >> (kqsx/4);
|
||||||
|
|
||||||
|
qh[0] = get_int_b4(bxi->qh, kqsx+0);
|
||||||
|
qh[1] = get_int_b4(bxi->qh, kqsx+8);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < qstep/2; ++l) {
|
||||||
|
|
||||||
|
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
|
||||||
|
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh[l/2] & 0x03030303) << 4) | ((extra & 1) * 0x40404040);
|
||||||
|
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh[l/2] & 0x0c0c0c0c) << 2) | ((extra & 4) * 0x10101010);
|
||||||
|
qh[l/2] >>= 4;
|
||||||
|
extra >>= 4;
|
||||||
|
|
||||||
|
const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
|
||||||
|
const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
|
||||||
|
|
||||||
|
#ifdef INT8_MMA_AVAILABLE
|
||||||
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0;
|
||||||
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1;
|
||||||
|
#else
|
||||||
|
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0;
|
||||||
|
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1;
|
||||||
|
#endif // INT8_MMA_AVAILABLE
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef INT8_MMA_AVAILABLE
|
||||||
|
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * bxi->scales[2*kqsx+0];
|
||||||
|
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * bxi->scales[2*kqsx+1];
|
||||||
|
#else
|
||||||
|
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * bxi->scales[2*kqsx+0];
|
||||||
|
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * bxi->scales[2*kqsx+1];
|
||||||
|
#endif // INT8_MMA_AVAILABLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int mmq_x, int mmq_y, bool need_check>
|
||||||
|
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ6_K> {
|
||||||
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq6_k<mmq_y, need_check>;
|
||||||
|
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
||||||
|
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
||||||
|
};
|
||||||
|
|
||||||
|
DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
|
||||||
Reference in New Issue
Block a user