mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
mmq_id: adding iq5_ks, iq5_ks_r4
This commit is contained in:
@@ -234,6 +234,12 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ4_K_R4>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ5_KS>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ5_KS_R4>(ctx, args, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
@@ -471,6 +477,8 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
|
||||
case GGML_TYPE_IQ4_KS_R4:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
mmq_supported = true;
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -92,6 +92,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
||||
case GGML_TYPE_IQ4_KS_R4:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -394,6 +396,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
case GGML_TYPE_IQ4_KS_R4: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ4_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0;
|
||||
default: return tile_x_sizes{0, 0, 0};
|
||||
}
|
||||
}
|
||||
@@ -444,6 +448,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_IQ4_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ4_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@@ -3989,5 +3995,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
146
ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu
Normal file
146
ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq_id_common.cuh"
|
||||
|
||||
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
||||
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
constexpr int qstep = 8;
|
||||
const int kqsx = threadIdx.x % qstep;
|
||||
|
||||
auto values = iq5nl_values;
|
||||
|
||||
uint32_t aux32[2];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
|
||||
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const float * dptr = (const float *)(x + i*stride);
|
||||
const float d = dptr[0];
|
||||
const block_iq5_ks * bxi = (const block_iq5_ks *)(dptr + 1) + kbx0;
|
||||
|
||||
int qh = get_int_b4(bxi->qh, kqsx);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < qstep/2; ++l) {
|
||||
|
||||
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
|
||||
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((bxi->scales[2*l+0] & 1) * 0x20202020);
|
||||
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((bxi->scales[2*l+1] & 1) * 0x20202020);
|
||||
qh >>= 2;
|
||||
|
||||
const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
|
||||
const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 8] = *(const int *)&val1;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127);
|
||||
#else
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
||||
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS_R4, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int kqsx = threadIdx.x/4;
|
||||
|
||||
uint32_t aux32[2];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
|
||||
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
int i4 = i/4;
|
||||
int ir = i%4;
|
||||
|
||||
const float * dptr = (const float *)(x + 4*i4*stride);
|
||||
const block_iq5_ks_r4 * bxi = (const block_iq5_ks_r4 *)(dptr + 4) + kbx0;
|
||||
|
||||
const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127;
|
||||
auto values = iq5nl_values + ((bxi->scales[4*kqsx+ir] & 1) << 5);
|
||||
|
||||
int qh = *((const int *)bxi->qh + 4*kqsx + ir);
|
||||
const int * ql = (const int *)bxi->qs + 16*kqsx + ir;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
aux32[0] = ((ql[4*j] >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010);
|
||||
aux32[1] = ((ql[4*j] >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010);
|
||||
qh >>= 2;
|
||||
const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
|
||||
const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
|
||||
const int k0 = 8*kqsx + 4*(j%2) + j/2;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = *(const int *)&val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = *(const int *)&val1;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = *(const int *)&val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = *(const int *)&val1;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls;
|
||||
#else
|
||||
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ5_KS> {
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ5_KS_R4> {
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks_r4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
|
||||
Reference in New Issue
Block a user