mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
mmq_id: add iq1_s_r4
This commit is contained in:
@@ -195,6 +195,9 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ1_S>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S_R4:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ1_S_R4>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_q_case_id<GGML_TYPE_IQ4_XS>(ctx, args, stream);
|
||||
break;
|
||||
@@ -476,6 +479,7 @@ bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_S_R4:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
|
||||
@@ -79,6 +79,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_S_R4:
|
||||
return MMQ_Q8_1_DS_LAYOUT_DS4;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
@@ -387,9 +388,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ1_S_R4:return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
|
||||
// ================= ik_llama.cpp quants
|
||||
case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
@@ -443,9 +444,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ1_S_R4:return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
// ================= ik_llama.cpp quants
|
||||
case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
@@ -3052,6 +3053,76 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s_r4(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
||||
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + 2*WARP_SIZE);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int kbx = threadIdx.x / 4;
|
||||
const int kqsx = threadIdx.x % 4;
|
||||
|
||||
int32_t grid32[2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
||||
int i = i0 + threadIdx.y;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
const int i4 = i/4;
|
||||
const int ir = i%4;
|
||||
|
||||
const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(x + 4*i4*stride + 4*sizeof(half)) + kbx0 + kbx;
|
||||
|
||||
grid32[0] = iq1s_grid_gpu[bxi->qs[4*kqsx+ir] | (((bxi->qh[ir] >> 3*kqsx) & 7) << 8)];
|
||||
grid32[1] = ((grid32[0] >> 4) & 0x0f0f0f0f) << 3;
|
||||
grid32[0] = (grid32[0] & 0x0f0f0f0f) << 3;
|
||||
const int shift = bxi->qh[ir] & 0x8000 ? 0x09090909 : 0x07070707;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 0] = __vsubss4(grid32[0], shift);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 1] = __vsubss4(grid32[1], shift);
|
||||
#else
|
||||
// TODO
|
||||
//x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
const int blocks_per_tile_x_row = WARP_SIZE / 4;
|
||||
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
|
||||
int i = i0 + threadIdx.y * 4 + threadIdx.x / blocks_per_tile_x_row;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
const int i4 = i/4;
|
||||
const int ir = i%4;
|
||||
|
||||
const half * dptr = (const half *)(x + 4*i4*stride);
|
||||
const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(dptr + 4) + kbx0 + kbxd;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.125f * __half2float(dptr[ir]) * (((bxi->qh[ir] >> 11) & 14) + 1);
|
||||
#else
|
||||
// TODO
|
||||
//x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
@@ -3345,12 +3416,18 @@ struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
|
||||
//static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S_R4> {
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
|
||||
//static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
|
||||
@@ -4066,9 +4143,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
||||
// =================== ik_llama.cpp quants
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K);
|
||||
|
||||
@@ -3,3 +3,4 @@
|
||||
#include "../mmq_id_common.cuh"
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
|
||||
|
||||
Reference in New Issue
Block a user